diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1ac6315 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,94 @@ +name: CI pipeline + +on: + push: + branches: [ 'dev-*', 'feature-*' ] + pull_request: + branches: [ 'dev-*', 'feature-*', 'main' ] + workflow_dispatch: + +jobs: + test: + runs-on: ubuntu-latest + name: Test & Code Quality + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: 'pip' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install -r requirements-dev.txt + + - name: Code style check (isort + Black, one order — no conflict) + run: | + echo "🔍 Applying isort then black (pyproject.toml: isort profile=black)..." + python -m isort . + python -m black . + echo "🔍 Checking that repo is already formatted (no diff after isort+black)..." + git diff --exit-code || ( + echo "❌ Code style drift. Locally run: isort . && black . && git add -A && git commit -m 'style: isort + black'" + exit 1 + ) + + - name: Linting (flake8) - Critical errors + run: | + echo "🔍 Running flake8 linter (critical errors only)..." + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics || true + + - name: Linting (flake8) - Warnings + run: | + echo "🔍 Running flake8 linter (warnings)..." + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics || true + continue-on-error: true + + - name: Run tests + run: | + echo "🧪 Running tests..." + python -m pytest tests/ -v --tb=short + + - name: Send test success notification + if: success() + uses: appleboy/telegram-action@v1.0.0 + with: + to: ${{ secrets.TELEGRAM_CHAT_ID }} + token: ${{ secrets.TELEGRAM_BOT_TOKEN }} + message: | + ✅ CI Tests Passed + + 📦 Repository: telegram-helper-bot + 🌿 Branch: ${{ github.ref_name }} + 📝 Commit: ${{ github.sha }} + 👤 Author: ${{ github.actor }} + + ✅ All tests passed! Code quality checks completed successfully. + + 🔗 View details: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + continue-on-error: true + + - name: Send test failure notification + if: failure() + uses: appleboy/telegram-action@v1.0.0 + with: + to: ${{ secrets.TELEGRAM_CHAT_ID }} + token: ${{ secrets.TELEGRAM_BOT_TOKEN }} + message: | + ❌ CI Tests Failed + + 📦 Repository: telegram-helper-bot + 🌿 Branch: ${{ github.ref_name }} + 📝 Commit: ${{ github.sha }} + 👤 Author: ${{ github.actor }} + + ❌ Tests failed! Deployment blocked. Please fix the issues and try again. + + 🔗 View details: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + continue-on-error: true \ No newline at end of file diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 0000000..98d39eb --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,357 @@ +name: Deploy to Production + +on: + push: + branches: [ main ] + workflow_dispatch: + inputs: + action: + description: 'Action to perform' + required: true + type: choice + options: + - deploy + - rollback + rollback_commit: + description: 'Commit hash to rollback to (optional, uses last successful if empty)' + required: false + type: string + +jobs: + deploy: + runs-on: ubuntu-latest + name: Deploy to Production + if: | + github.event_name == 'push' || + (github.event_name == 'workflow_dispatch' && github.event.inputs.action == 'deploy') + concurrency: + group: production-deploy-telegram-helper-bot + cancel-in-progress: false + environment: + name: production + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: main + + - name: Deploy to server + uses: appleboy/ssh-action@v1.0.0 + with: + host: ${{ vars.SERVER_HOST || secrets.SERVER_HOST }} + username: ${{ vars.SERVER_USER || secrets.SERVER_USER }} + key: ${{ secrets.SSH_PRIVATE_KEY }} + port: ${{ vars.SSH_PORT || secrets.SSH_PORT || 22 }} + script: | + set -e + export TELEGRAM_BOT_TOKEN="${{ secrets.TELEGRAM_BOT_TOKEN }}" + export TELEGRAM_TEST_BOT_TOKEN="${{ secrets.TELEGRAM_TEST_BOT_TOKEN }}" + + echo "🚀 Starting deployment to production..." + + cd /home/prod + + # Сохраняем информацию о коммите + CURRENT_COMMIT=$(git rev-parse HEAD) + COMMIT_MESSAGE=$(git log -1 --pretty=format:"%s" || echo "Unknown") + COMMIT_AUTHOR=$(git log -1 --pretty=format:"%an" || echo "Unknown") + TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S") + + echo "📝 Current commit: $CURRENT_COMMIT" + echo "📝 Commit message: $COMMIT_MESSAGE" + echo "📝 Author: $COMMIT_AUTHOR" + + # Записываем в историю деплоев + HISTORY_FILE="/home/prod/.deploy_history_telegram_helper_bot.txt" + HISTORY_SIZE="${DEPLOY_HISTORY_SIZE:-10}" + echo "${TIMESTAMP}|${CURRENT_COMMIT}|${COMMIT_MESSAGE}|${COMMIT_AUTHOR}|deploying" >> "$HISTORY_FILE" + tail -n "$HISTORY_SIZE" "$HISTORY_FILE" > "${HISTORY_FILE}.tmp" && mv "${HISTORY_FILE}.tmp" "$HISTORY_FILE" + + # Обновляем код + echo "📥 Pulling latest changes from main..." + sudo chown -R deploy:deploy /home/prod/bots/telegram-helper-bot || true + cd /home/prod/bots/telegram-helper-bot + git fetch origin main + git reset --hard origin/main + sudo chown -R deploy:deploy /home/prod/bots/telegram-helper-bot || true + + NEW_COMMIT=$(git rev-parse HEAD) + echo "✅ Code updated: $CURRENT_COMMIT → $NEW_COMMIT" + + # Применяем миграции БД перед перезапуском контейнера + echo "🔄 Applying database migrations..." + DB_PATH="/home/prod/bots/telegram-helper-bot/database/tg-bot-database.db" + + if [ -f "$DB_PATH" ]; then + cd /home/prod/bots/telegram-helper-bot + python3 scripts/apply_migrations.py --db "$DB_PATH" || { + echo "❌ Ошибка при применении миграций!" + exit 1 + } + echo "✅ Миграции применены успешно" + else + echo "⚠️ База данных не найдена, пропускаем миграции (будет создана при первом запуске)" + fi + + # Валидация docker-compose + echo "🔍 Validating docker-compose configuration..." + cd /home/prod + docker-compose config > /dev/null || exit 1 + echo "✅ docker-compose.yml is valid" + + # Проверка дискового пространства + MIN_FREE_GB=5 + AVAILABLE_SPACE=$(df -BG /home/prod 2>/dev/null | tail -1 | awk '{print $4}' | sed 's/G//' || echo "0") + echo "💾 Available disk space: ${AVAILABLE_SPACE}GB" + + if [ "$AVAILABLE_SPACE" -lt "$MIN_FREE_GB" ]; then + echo "⚠️ Insufficient disk space! Cleaning up Docker resources..." + docker system prune -f --volumes || true + fi + + # Пересобираем и перезапускаем контейнер бота + echo "🔨 Rebuilding and restarting telegram-bot container..." + cd /home/prod + + export TELEGRAM_BOT_TOKEN TELEGRAM_TEST_BOT_TOKEN + docker-compose stop telegram-bot || true + docker-compose build --pull telegram-bot + docker-compose up -d telegram-bot + + echo "✅ Telegram bot container rebuilt and started" + + # Ждем немного и проверяем healthcheck + echo "⏳ Waiting for container to start..." + sleep 10 + + if docker ps | grep -q bots_telegram_bot; then + echo "✅ Container is running" + else + echo "❌ Container failed to start!" + docker logs bots_telegram_bot --tail 50 || true + exit 1 + fi + + - name: Update deploy history + if: always() + uses: appleboy/ssh-action@v1.0.0 + with: + host: ${{ vars.SERVER_HOST || secrets.SERVER_HOST }} + username: ${{ vars.SERVER_USER || secrets.SERVER_USER }} + key: ${{ secrets.SSH_PRIVATE_KEY }} + port: ${{ vars.SSH_PORT || secrets.SSH_PORT || 22 }} + script: | + HISTORY_FILE="/home/prod/.deploy_history_telegram_helper_bot.txt" + + if [ -f "$HISTORY_FILE" ]; then + DEPLOY_STATUS="failed" + if [ "${{ job.status }}" = "success" ]; then + DEPLOY_STATUS="success" + fi + + sed -i '$s/|deploying$/|'"$DEPLOY_STATUS"'/' "$HISTORY_FILE" + echo "✅ Deploy history updated: $DEPLOY_STATUS" + fi + + - name: Send deployment notification + if: always() + uses: appleboy/telegram-action@v1.0.0 + with: + to: ${{ secrets.TELEGRAM_CHAT_ID }} + token: ${{ secrets.TELEGRAM_BOT_TOKEN }} + message: | + ${{ job.status == 'success' && '✅' || '❌' }} Deployment: ${{ job.status }} + + 📦 Repository: telegram-helper-bot + 🌿 Branch: main + 📝 Commit: ${{ github.sha }} + 👤 Author: ${{ github.actor }} + + ${{ job.status == 'success' && '✅ Deployment successful! Container restarted with migrations applied.' || '❌ Deployment failed! Check logs for details.' }} + + 🔗 View details: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + continue-on-error: true + + - name: Get PR body from merged PR + if: job.status == 'success' && github.event_name == 'push' + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "🔍 Searching for merged PR associated with commit ${{ github.sha }}..." + + # Находим последний мерженный PR для main ветки по merge commit SHA + COMMIT_SHA="${{ github.sha }}" + PR_NUMBER=$(gh pr list --state merged --base main --limit 10 --json number,mergeCommit --jq ".[] | select(.mergeCommit.oid == \"$COMMIT_SHA\") | .number" | head -1) + + # Если не нашли по merge commit, ищем последний мерженный PR + if [ -z "$PR_NUMBER" ]; then + echo "⚠️ PR not found by merge commit, trying to get latest merged PR..." + PR_NUMBER=$(gh pr list --state merged --base main --limit 1 --json number --jq '.[0].number') + fi + + if [ -n "$PR_NUMBER" ] && [ "$PR_NUMBER" != "null" ]; then + echo "✅ Found PR #$PR_NUMBER" + PR_BODY=$(gh pr view $PR_NUMBER --json body --jq '.body // ""') + + if [ -n "$PR_BODY" ] && [ "$PR_BODY" != "null" ]; then + echo "PR_BODY<> $GITHUB_ENV + echo "$PR_BODY" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + echo "PR_NUMBER=$PR_NUMBER" >> $GITHUB_ENV + echo "✅ PR body extracted successfully" + else + echo "⚠️ PR body is empty" + fi + else + echo "⚠️ No merged PR found for this commit" + fi + continue-on-error: true + + - name: Send PR body to important logs + if: job.status == 'success' && github.event_name == 'push' && env.PR_BODY != '' + uses: appleboy/telegram-action@v1.0.0 + with: + to: ${{ secrets.IMPORTANT_LOGS_CHAT }} + token: ${{ secrets.TELEGRAM_BOT_TOKEN }} + message: | + 📋 Pull Request Description (PR #${{ env.PR_NUMBER }}): + + ${{ env.PR_BODY }} + + 🔗 PR: ${{ github.server_url }}/${{ github.repository }}/pull/${{ env.PR_NUMBER }} + 📝 Commit: ${{ github.sha }} + continue-on-error: true + + rollback: + runs-on: ubuntu-latest + name: Rollback to Previous Version + if: | + github.event_name == 'workflow_dispatch' && + github.event.inputs.action == 'rollback' + environment: + name: production + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: main + + - name: Rollback on server + uses: appleboy/ssh-action@v1.0.0 + with: + host: ${{ vars.SERVER_HOST || secrets.SERVER_HOST }} + username: ${{ vars.SERVER_USER || secrets.SERVER_USER }} + key: ${{ secrets.SSH_PRIVATE_KEY }} + port: ${{ vars.SSH_PORT || secrets.SSH_PORT || 22 }} + script: | + set -e + export TELEGRAM_BOT_TOKEN="${{ secrets.TELEGRAM_BOT_TOKEN }}" + export TELEGRAM_TEST_BOT_TOKEN="${{ secrets.TELEGRAM_TEST_BOT_TOKEN }}" + + echo "🔄 Starting rollback..." + + cd /home/prod + + # Определяем коммит для отката + ROLLBACK_COMMIT="${{ github.event.inputs.rollback_commit }}" + HISTORY_FILE="/home/prod/.deploy_history_telegram_helper_bot.txt" + + if [ -z "$ROLLBACK_COMMIT" ]; then + echo "📝 No commit specified, finding last successful deploy..." + if [ -f "$HISTORY_FILE" ]; then + ROLLBACK_COMMIT=$(grep "|success$" "$HISTORY_FILE" | tail -1 | cut -d'|' -f2 || echo "") + fi + + if [ -z "$ROLLBACK_COMMIT" ]; then + echo "❌ No successful deploy found in history!" + echo "💡 Please specify commit hash manually or check deploy history" + exit 1 + fi + fi + + echo "📝 Rolling back to commit: $ROLLBACK_COMMIT" + + # Проверяем, что коммит существует + cd /home/prod/bots/telegram-helper-bot + if ! git cat-file -e "$ROLLBACK_COMMIT" 2>/dev/null; then + echo "❌ Commit $ROLLBACK_COMMIT not found!" + exit 1 + fi + + # Сохраняем текущий коммит + CURRENT_COMMIT=$(git rev-parse HEAD) + COMMIT_MESSAGE=$(git log -1 --pretty=format:"%s" "$ROLLBACK_COMMIT" || echo "Rollback") + TIMESTAMP=$(date +"%Y-%m-%d %H:%M:%S") + + echo "📝 Current commit: $CURRENT_COMMIT" + echo "📝 Target commit: $ROLLBACK_COMMIT" + echo "📝 Commit message: $COMMIT_MESSAGE" + + # Исправляем права перед откатом + sudo chown -R deploy:deploy /home/prod/bots/telegram-helper-bot || true + + # Откатываем код + echo "🔄 Rolling back code..." + git fetch origin main + git reset --hard "$ROLLBACK_COMMIT" + + # Исправляем права после отката + sudo chown -R deploy:deploy /home/prod/bots/telegram-helper-bot || true + + echo "✅ Code rolled back: $CURRENT_COMMIT → $ROLLBACK_COMMIT" + + # Валидация docker-compose + echo "🔍 Validating docker-compose configuration..." + cd /home/prod + docker-compose config > /dev/null || exit 1 + echo "✅ docker-compose.yml is valid" + + # Проверка дискового пространства + MIN_FREE_GB=5 + AVAILABLE_SPACE=$(df -BG /home/prod 2>/dev/null | tail -1 | awk '{print $4}' | sed 's/G//' || echo "0") + echo "💾 Available disk space: ${AVAILABLE_SPACE}GB" + + if [ "$AVAILABLE_SPACE" -lt "$MIN_FREE_GB" ]; then + echo "⚠️ Insufficient disk space! Cleaning up Docker resources..." + docker system prune -f --volumes || true + fi + + # Пересобираем и перезапускаем контейнер + echo "🔨 Rebuilding and restarting telegram-bot container..." + cd /home/prod + + export TELEGRAM_BOT_TOKEN TELEGRAM_TEST_BOT_TOKEN + docker-compose stop telegram-bot || true + docker-compose build --pull telegram-bot + docker-compose up -d telegram-bot + + echo "✅ Telegram bot container rebuilt and started" + + # Записываем в историю + echo "${TIMESTAMP}|${ROLLBACK_COMMIT}|Rollback to: ${COMMIT_MESSAGE}|github-actions|rolled_back" >> "$HISTORY_FILE" + HISTORY_SIZE="${DEPLOY_HISTORY_SIZE:-10}" + tail -n "$HISTORY_SIZE" "$HISTORY_FILE" > "${HISTORY_FILE}.tmp" && mv "${HISTORY_FILE}.tmp" "$HISTORY_FILE" + + echo "✅ Rollback completed successfully" + + - name: Send rollback notification + if: always() + uses: appleboy/telegram-action@v1.0.0 + with: + to: ${{ secrets.TELEGRAM_CHAT_ID }} + token: ${{ secrets.TELEGRAM_BOT_TOKEN }} + message: | + ${{ job.status == 'success' && '🔄' || '❌' }} Rollback: ${{ job.status }} + + 📦 Repository: telegram-helper-bot + 🌿 Branch: main + 📝 Rolled back to: ${{ github.event.inputs.rollback_commit || 'Last successful commit' }} + 👤 Triggered by: ${{ github.actor }} + + ${{ job.status == 'success' && '✅ Rollback completed successfully! Services restored to previous version.' || '❌ Rollback failed! Check logs for details.' }} + + 🔗 View details: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + continue-on-error: true + diff --git a/.gitignore b/.gitignore index 6cc702c..8016b04 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,9 @@ database/test.db test.db *.db +# Случайно созданный файл при использовании SQLite :memory: не по назначению +:memory: + # IDE and editor files .vscode/ .idea/ diff --git a/:memory: b/:memory: deleted file mode 100644 index 159e90a..0000000 Binary files a/:memory: and /dev/null differ diff --git a/database/__init__.py b/database/__init__.py index 731bfbd..ae0cdc6 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -11,15 +11,35 @@ from .async_db import AsyncBotDB from .base import DatabaseConnection -from .models import (Admin, AudioListenRecord, AudioMessage, AudioModerate, - BlacklistUser, MessageContentLink, Migration, PostContent, - TelegramPost, User, UserMessage) +from .models import ( + Admin, + AudioListenRecord, + AudioMessage, + AudioModerate, + BlacklistUser, + MessageContentLink, + Migration, + PostContent, + TelegramPost, + User, + UserMessage, +) from .repository_factory import RepositoryFactory # Для обратной совместимости экспортируем старый интерфейс __all__ = [ - 'User', 'BlacklistUser', 'UserMessage', 'TelegramPost', 'PostContent', - 'MessageContentLink', 'Admin', 'Migration', 'AudioMessage', 'AudioListenRecord', 'AudioModerate', - 'RepositoryFactory', 'DatabaseConnection', 'AsyncBotDB' + "User", + "BlacklistUser", + "UserMessage", + "TelegramPost", + "PostContent", + "MessageContentLink", + "Admin", + "Migration", + "AudioMessage", + "AudioListenRecord", + "AudioModerate", + "RepositoryFactory", + "DatabaseConnection", + "AsyncBotDB", ] - diff --git a/database/async_db.py b/database/async_db.py index e5f74d8..39bdf94 100644 --- a/database/async_db.py +++ b/database/async_db.py @@ -2,202 +2,253 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Tuple import aiosqlite -from database.models import (Admin, AudioMessage, BlacklistHistoryRecord, - BlacklistUser, PostContent, TelegramPost, User, - UserMessage) + +from database.models import ( + Admin, + AudioMessage, + BlacklistHistoryRecord, + BlacklistUser, + PostContent, + TelegramPost, + User, + UserMessage, +) from database.repository_factory import RepositoryFactory class AsyncBotDB: """Новый асинхронный класс для работы с базой данных с использованием репозиториев.""" - + def __init__(self, db_path: str): self.factory = RepositoryFactory(db_path) self.logger = self.factory.users.logger - + async def create_tables(self): """Создание всех таблиц в базе данных.""" await self.factory.create_all_tables() self.logger.info("Все таблицы успешно созданы") - + # Методы для работы с пользователями async def user_exists(self, user_id: int) -> bool: """Проверяет, существует ли пользователь в базе данных.""" return await self.factory.users.user_exists(user_id) - + async def add_user(self, user: User): """Добавление нового пользователя.""" await self.factory.users.add_user(user) - + async def get_user_info(self, user_id: int) -> Optional[Dict[str, Any]]: """Получение информации о пользователе.""" user = await self.factory.users.get_user_info(user_id) if user: return { - 'username': user.username, - 'full_name': user.full_name, - 'has_stickers': user.has_stickers, - 'emoji': user.emoji + "username": user.username, + "full_name": user.full_name, + "has_stickers": user.has_stickers, + "emoji": user.emoji, } return None - + async def get_username(self, user_id: int) -> Optional[str]: """Возвращает username пользователя.""" return await self.factory.users.get_username(user_id) - + async def get_user_id_by_username(self, username: str) -> Optional[int]: """Возвращает user_id пользователя по username.""" return await self.factory.users.get_user_id_by_username(username) - + async def get_full_name_by_id(self, user_id: int) -> Optional[str]: """Возвращает full_name пользователя.""" return await self.factory.users.get_full_name_by_id(user_id) - - async def get_username_and_full_name(self, user_id: int) -> tuple[Optional[str], Optional[str]]: + + async def get_username_and_full_name( + self, user_id: int + ) -> tuple[Optional[str], Optional[str]]: """Возвращает username и full_name пользователя.""" username = await self.get_username(user_id) full_name = await self.get_full_name_by_id(user_id) return username, full_name - + async def get_user_by_id(self, user_id: int) -> Optional[User]: """Получение пользователя по ID.""" return await self.factory.users.get_user_by_id(user_id) - + async def get_user_first_name(self, user_id: int) -> Optional[str]: """Возвращает first_name пользователя.""" return await self.factory.users.get_user_first_name(user_id) - + async def get_all_user_id(self) -> List[int]: """Возвращает список всех user_id.""" return await self.factory.users.get_all_user_ids() - + async def get_last_users(self, limit: int = 30) -> List[tuple]: """Получение последних пользователей.""" return await self.factory.users.get_last_users(limit) - + async def update_user_date(self, user_id: int): """Обновление даты последнего изменения пользователя.""" await self.factory.users.update_user_date(user_id) - - async def update_user_info(self, user_id: int, username: str = None, full_name: str = None): + + async def update_user_info( + self, user_id: int, username: str = None, full_name: str = None + ): """Обновление информации о пользователе.""" await self.factory.users.update_user_info(user_id, username, full_name) - + async def update_user_emoji(self, user_id: int, emoji: str): """Обновление эмодзи пользователя.""" await self.factory.users.update_user_emoji(user_id, emoji) - + async def update_stickers_info(self, user_id: int): """Обновление информации о стикерах.""" await self.factory.users.update_stickers_info(user_id) - + async def get_stickers_info(self, user_id: int) -> bool: """Получение информации о стикерах.""" return await self.factory.users.get_stickers_info(user_id) - + async def check_emoji_exists(self, emoji: str) -> bool: """Проверка существования эмодзи.""" return await self.factory.users.check_emoji_exists(emoji) - + async def get_user_emoji(self, user_id: int) -> str: """Получает эмодзи пользователя.""" return await self.factory.users.get_user_emoji(user_id) - + async def check_emoji_for_user(self, user_id: int) -> str: """Проверяет, есть ли уже у пользователя назначенный emoji.""" return await self.factory.users.check_emoji_for_user(user_id) - + # Методы для работы с сообщениями - async def add_message(self, message_text: str, user_id: int, message_id: int, date: int = None): + async def add_message( + self, message_text: str, user_id: int, message_id: int, date: int = None + ): """Добавление сообщения пользователя.""" if date is None: from datetime import datetime + date = int(datetime.now().timestamp()) - + message = UserMessage( message_text=message_text, user_id=user_id, telegram_message_id=message_id, - date=date + date=date, ) await self.factory.messages.add_message(message) - + async def get_user_by_message_id(self, message_id: int) -> Optional[int]: """Получение пользователя по message_id.""" return await self.factory.messages.get_user_by_message_id(message_id) - + # Методы для работы с постами async def add_post(self, post: TelegramPost): """Добавление поста.""" await self.factory.posts.add_post(post) - + async def update_helper_message(self, message_id: int, helper_message_id: int): """Обновление helper сообщения.""" await self.factory.posts.update_helper_message(message_id, helper_message_id) - - async def add_post_content(self, post_id: int, message_id: int, content_name: str, content_type: str): + + async def add_post_content( + self, post_id: int, message_id: int, content_name: str, content_type: str + ): """Добавление контента поста.""" - return await self.factory.posts.add_post_content(post_id, message_id, content_name, content_type) - + return await self.factory.posts.add_post_content( + post_id, message_id, content_name, content_type + ) + async def add_message_link(self, post_id: int, message_id: int) -> bool: """Добавляет связь между post_id и message_id в таблицу message_link_to_content.""" return await self.factory.posts.add_message_link(post_id, message_id) - - async def get_post_content_from_telegram_by_last_id(self, last_post_id: int) -> List[Tuple[str, str]]: + + async def get_post_content_from_telegram_by_last_id( + self, last_post_id: int + ) -> List[Tuple[str, str]]: """Получает контент поста по helper_text_message_id.""" return await self.factory.posts.get_post_content_by_helper_id(last_post_id) - async def get_post_content_by_helper_id(self, helper_message_id: int) -> List[Tuple[str, str]]: + async def get_post_content_by_helper_id( + self, helper_message_id: int + ) -> List[Tuple[str, str]]: """Алиас для get_post_content_from_telegram_by_last_id (используется callback-сервисом).""" return await self.get_post_content_from_telegram_by_last_id(helper_message_id) - - async def get_post_content_by_message_id(self, message_id: int) -> List[Tuple[str, str]]: + + async def get_post_content_by_message_id( + self, message_id: int + ) -> List[Tuple[str, str]]: """Получает контент одиночного поста по message_id.""" return await self.factory.posts.get_post_content_by_message_id(message_id) - - async def update_published_message_id(self, original_message_id: int, published_message_id: int): + + async def update_published_message_id( + self, original_message_id: int, published_message_id: int + ): """Обновляет published_message_id для опубликованного поста.""" - await self.factory.posts.update_published_message_id(original_message_id, published_message_id) - - async def add_published_post_content(self, published_message_id: int, content_path: str, content_type: str): + await self.factory.posts.update_published_message_id( + original_message_id, published_message_id + ) + + async def add_published_post_content( + self, published_message_id: int, content_path: str, content_type: str + ): """Добавляет контент опубликованного поста.""" - return await self.factory.posts.add_published_post_content(published_message_id, content_path, content_type) - - async def get_published_post_content(self, published_message_id: int) -> List[Tuple[str, str]]: + return await self.factory.posts.add_published_post_content( + published_message_id, content_path, content_type + ) + + async def get_published_post_content( + self, published_message_id: int + ) -> List[Tuple[str, str]]: """Получает контент опубликованного поста.""" return await self.factory.posts.get_published_post_content(published_message_id) - - async def get_post_text_from_telegram_by_last_id(self, last_post_id: int) -> Optional[str]: + + async def get_post_text_from_telegram_by_last_id( + self, last_post_id: int + ) -> Optional[str]: """Получает текст поста по helper_text_message_id.""" return await self.factory.posts.get_post_text_by_helper_id(last_post_id) async def get_post_text_by_helper_id(self, helper_message_id: int) -> Optional[str]: """Алиас для get_post_text_from_telegram_by_last_id (используется callback-сервисом).""" return await self.get_post_text_from_telegram_by_last_id(helper_message_id) - - async def get_post_ids_from_telegram_by_last_id(self, last_post_id: int) -> List[int]: + + async def get_post_ids_from_telegram_by_last_id( + self, last_post_id: int + ) -> List[int]: """Получает ID сообщений по helper_text_message_id.""" return await self.factory.posts.get_post_ids_by_helper_id(last_post_id) - + async def get_post_ids_by_helper_id(self, helper_message_id: int) -> List[int]: """Алиас для get_post_ids_from_telegram_by_last_id (используется callback-сервисом).""" return await self.get_post_ids_from_telegram_by_last_id(helper_message_id) - + async def get_author_id_by_message_id(self, message_id: int) -> Optional[int]: """Получает ID автора по message_id.""" return await self.factory.posts.get_author_id_by_message_id(message_id) - - async def get_author_id_by_helper_message_id(self, helper_text_message_id: int) -> Optional[int]: + + async def get_author_id_by_helper_message_id( + self, helper_text_message_id: int + ) -> Optional[int]: """Получает ID автора по helper_text_message_id.""" - return await self.factory.posts.get_author_id_by_helper_message_id(helper_text_message_id) - - async def get_post_text_and_anonymity_by_message_id(self, message_id: int) -> tuple[Optional[str], Optional[bool]]: + return await self.factory.posts.get_author_id_by_helper_message_id( + helper_text_message_id + ) + + async def get_post_text_and_anonymity_by_message_id( + self, message_id: int + ) -> tuple[Optional[str], Optional[bool]]: """Получает текст и is_anonymous поста по message_id.""" - return await self.factory.posts.get_post_text_and_anonymity_by_message_id(message_id) - - async def get_post_text_and_anonymity_by_helper_id(self, helper_message_id: int) -> tuple[Optional[str], Optional[bool]]: + return await self.factory.posts.get_post_text_and_anonymity_by_message_id( + message_id + ) + + async def get_post_text_and_anonymity_by_helper_id( + self, helper_message_id: int + ) -> tuple[Optional[str], Optional[bool]]: """Получает текст и is_anonymous поста по helper_text_message_id.""" - return await self.factory.posts.get_post_text_and_anonymity_by_helper_id(helper_message_id) + return await self.factory.posts.get_post_text_and_anonymity_by_helper_id( + helper_message_id + ) async def update_status_by_message_id(self, message_id: int, status: str) -> int: """Обновление статуса поста по message_id (одиночные посты). Возвращает число обновлённых строк.""" @@ -210,20 +261,20 @@ class AsyncBotDB: return await self.factory.posts.update_status_for_media_group_by_helper_id( helper_message_id, status ) - + # Методы для ML Scoring async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]: """Получает текст поста по message_id.""" return await self.factory.posts.get_post_text_by_message_id(message_id) - + async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool: """Обновляет ML-скоры для поста.""" return await self.factory.posts.update_ml_scores(message_id, ml_scores_json) - + async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]: """Получает тексты одобренных постов для обучения RAG.""" return await self.factory.posts.get_approved_posts_texts(limit) - + async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]: """Получает тексты отклоненных постов для обучения RAG.""" return await self.factory.posts.get_declined_posts_texts(limit) @@ -248,7 +299,7 @@ class AsyncBotDB: ban_author=ban_author, ) await self.factory.blacklist.add_user(blacklist_user) - + # Логируем в историю банов try: date_ban = int(datetime.now().timestamp()) @@ -265,7 +316,7 @@ class AsyncBotDB: self.logger.error( f"Ошибка записи в историю банов для user_id={user_id}: {e}" ) - + async def delete_user_blacklist(self, user_id: int) -> bool: """ Удаляет пользователя из черного списка. @@ -280,174 +331,206 @@ class AsyncBotDB: self.logger.error( f"Ошибка обновления истории при разбане для user_id={user_id}: {e}" ) - + # Удаляем из черного списка (критический путь) return await self.factory.blacklist.remove_user(user_id) - + async def check_user_in_blacklist(self, user_id: int) -> bool: """Проверяет, существует ли запись с данным user_id в blacklist.""" return await self.factory.blacklist.user_exists(user_id) - - async def get_blacklist_users(self, offset: int = 0, limit: int = 10) -> List[tuple]: + + async def get_blacklist_users( + self, offset: int = 0, limit: int = 10 + ) -> List[tuple]: """Получение пользователей из черного списка.""" users = await self.factory.blacklist.get_all_users(offset, limit) - return [(user.user_id, user.message_for_user, user.date_to_unban) for user in users] - + return [ + (user.user_id, user.message_for_user, user.date_to_unban) for user in users + ] + async def get_banned_users_from_db(self) -> List[tuple]: """Возвращает список пользователей в черном списке.""" users = await self.factory.blacklist.get_all_users_no_limit() - return [(user.user_id, user.message_for_user, user.date_to_unban) for user in users] - - async def get_banned_users_from_db_with_limits(self, offset: int, limit: int) -> List[tuple]: + return [ + (user.user_id, user.message_for_user, user.date_to_unban) for user in users + ] + + async def get_banned_users_from_db_with_limits( + self, offset: int, limit: int + ) -> List[tuple]: """Возвращает список пользователей в черном списке с учетом смещения и ограничения.""" users = await self.factory.blacklist.get_all_users(offset, limit) - return [(user.user_id, user.message_for_user, user.date_to_unban) for user in users] - + return [ + (user.user_id, user.message_for_user, user.date_to_unban) for user in users + ] + async def get_blacklist_users_by_id(self, user_id: int) -> Optional[tuple]: """Возвращает информацию о пользователе в черном списке по user_id.""" user = await self.factory.blacklist.get_user(user_id) if user: return (user.user_id, user.message_for_user, user.date_to_unban) return None - + async def get_blacklist_count(self) -> int: """Получение количества пользователей в черном списке.""" return await self.factory.blacklist.get_count() - - async def get_users_for_unblock_today(self, current_timestamp: int) -> Dict[int, int]: + + async def get_users_for_unblock_today( + self, current_timestamp: int + ) -> Dict[int, int]: """Возвращает список пользователей, у которых истек срок блокировки.""" - return await self.factory.blacklist.get_users_for_unblock_today(current_timestamp) - + return await self.factory.blacklist.get_users_for_unblock_today( + current_timestamp + ) + # Методы для работы с администраторами async def add_admin(self, user_id: int, role: str = "admin"): """Добавление администратора.""" admin = Admin(user_id=user_id, role=role) await self.factory.admins.add_admin(admin) - + async def remove_admin(self, user_id: int): """Удаление администратора.""" await self.factory.admins.remove_admin(user_id) - + async def is_admin(self, user_id: int) -> bool: """Проверка, является ли пользователь администратором.""" return await self.factory.admins.is_admin(user_id) - + async def get_all_admins(self) -> list[Admin]: """Получение всех администраторов.""" return await self.factory.admins.get_all_admins() - + # Методы для работы с аудио - async def add_audio_record(self, file_name: str, author_id: int, date_added: str, - listen_count: int, file_id: str): + async def add_audio_record( + self, + file_name: str, + author_id: int, + date_added: str, + listen_count: int, + file_id: str, + ): """Добавляет информацию о войсе пользователя.""" audio = AudioMessage( file_name=file_name, author_id=author_id, date_added=date_added, listen_count=listen_count, - file_id=file_id + file_id=file_id, ) await self.factory.audio.add_audio_record(audio) - - async def add_audio_record_simple(self, file_name: str, user_id: int, date_added) -> None: + + async def add_audio_record_simple( + self, file_name: str, user_id: int, date_added + ) -> None: """Добавляет простую запись об аудио файле.""" await self.factory.audio.add_audio_record_simple(file_name, user_id, date_added) - + async def last_date_audio(self) -> Optional[str]: """Получает дату последнего войса.""" return await self.factory.audio.get_last_date_audio() - + async def get_last_user_audio_record(self, user_id: int) -> bool: """Получает данные о количестве записей пользователя.""" count = await self.factory.audio.get_user_audio_records_count(user_id) return bool(count) - + async def get_path_for_audio_record(self, user_id: int) -> Optional[str]: """Получает данные о названии файла.""" return await self.factory.audio.get_path_for_audio_record(user_id) - + async def check_listen_audio(self, user_id: int) -> List[str]: """Проверяет прослушано ли аудио пользователем.""" return await self.factory.audio.check_listen_audio(user_id) - + async def mark_listened_audio(self, file_name: str, user_id: int): """Отмечает аудио прослушанным для конкретного пользователя.""" await self.factory.audio.mark_listened_audio(file_name, user_id) - + async def get_id_for_audio_record(self, user_id: int) -> int: """Получает следующий номер аудио сообщения пользователя.""" return await self.factory.audio.get_user_audio_records_count(user_id) - + async def get_user_audio_records_count(self, user_id: int) -> int: """Получает количество аудио записей пользователя.""" return await self.factory.audio.get_user_audio_records_count(user_id) - + async def refresh_listen_audio(self, user_id: int): """Очищает всю информацию о прослушанных аудио пользователем.""" await self.factory.audio.refresh_listen_audio(user_id) - + async def delete_listen_count_for_user(self, user_id: int): """Удаляет данные о прослушанных пользователем аудио.""" await self.factory.audio.delete_listen_count_for_user(user_id) - + async def get_user_id_by_file_name(self, file_name: str) -> Optional[int]: """Получает user_id пользователя по имени файла.""" return await self.factory.audio.get_user_id_by_file_name(file_name) - + async def get_date_by_file_name(self, file_name: str) -> Optional[str]: """Получает дату добавления файла.""" return await self.factory.audio.get_date_by_file_name(file_name) - + # Методы для voice bot - async def set_user_id_and_message_id_for_voice_bot(self, message_id: int, user_id: int) -> bool: + async def set_user_id_and_message_id_for_voice_bot( + self, message_id: int, user_id: int + ) -> bool: """Устанавливает связь между message_id и user_id для voice bot.""" - return await self.factory.audio.set_user_id_and_message_id_for_voice_bot(message_id, user_id) - - async def get_user_id_by_message_id_for_voice_bot(self, message_id: int) -> Optional[int]: + return await self.factory.audio.set_user_id_and_message_id_for_voice_bot( + message_id, user_id + ) + + async def get_user_id_by_message_id_for_voice_bot( + self, message_id: int + ) -> Optional[int]: """Получает user_id пользователя по message_id для voice bot.""" - return await self.factory.audio.get_user_id_by_message_id_for_voice_bot(message_id) - + return await self.factory.audio.get_user_id_by_message_id_for_voice_bot( + message_id + ) + async def delete_audio_moderate_record(self, message_id: int) -> None: """Удаляет запись из таблицы audio_moderate по message_id.""" await self.factory.audio.delete_audio_moderate_record(message_id) - + async def get_all_audio_records(self) -> List[Dict[str, Any]]: """Получить все записи аудио сообщений.""" return await self.factory.audio.get_all_audio_records() - + async def delete_audio_record_by_file_name(self, file_name: str) -> None: """Удалить запись аудио сообщения по имени файла.""" await self.factory.audio.delete_audio_record_by_file_name(file_name) - + # Методы для миграций async def create_table(self, sql_script: str): """Создает таблицу в базе. Используется в миграциях.""" await self.factory.migrations.create_table_from_sql(sql_script) - + # Методы для voice bot welcome tracking async def check_voice_bot_welcome_received(self, user_id: int) -> bool: """Проверяет, получал ли пользователь приветственное сообщение от voice_bot.""" return await self.factory.users.check_voice_bot_welcome_received(user_id) - + async def mark_voice_bot_welcome_received(self, user_id: int) -> bool: """Отмечает, что пользователь получил приветственное сообщение от voice_bot.""" return await self.factory.users.mark_voice_bot_welcome_received(user_id) - + # Методы для проверки целостности async def check_database_integrity(self): """Проверяет целостность базы данных и очищает WAL файлы.""" await self.factory.check_database_integrity() - + async def cleanup_wal_files(self): """Очищает WAL файлы и переключает на DELETE режим для предотвращения проблем с I/O.""" await self.factory.cleanup_wal_files() - + async def close(self): """Закрытие соединений.""" # Соединения закрываются в каждом методе pass - - async def fetch_one(self, query: str, params: tuple = ()) -> Optional[Dict[str, Any]]: + + async def fetch_one( + self, query: str, params: tuple = () + ) -> Optional[Dict[str, Any]]: """Выполняет SQL запрос и возвращает один результат.""" try: async with aiosqlite.connect(self.factory.db_path) as conn: diff --git a/database/base.py b/database/base.py index 4ede01a..ca32425 100644 --- a/database/base.py +++ b/database/base.py @@ -2,17 +2,18 @@ import os from typing import Optional import aiosqlite + from logs.custom_logger import logger class DatabaseConnection: """Базовый класс для работы с базой данных.""" - + def __init__(self, db_path: str): self.db_path = os.path.abspath(db_path) self.logger = logger - self.logger.info(f'Инициация базы данных: {self.db_path}') - + self.logger.info(f"Инициация базы данных: {self.db_path}") + async def _get_connection(self): """Получение асинхронного соединения с базой данных.""" try: @@ -28,7 +29,7 @@ class DatabaseConnection: except Exception as e: self.logger.error(f"Ошибка при получении соединения: {e}") raise - + async def _execute_query(self, query: str, params: tuple = ()): """Выполнение запроса с автоматическим закрытием соединения.""" conn = None @@ -43,7 +44,7 @@ class DatabaseConnection: finally: if conn: await conn.close() - + async def _execute_query_with_result(self, query: str, params: tuple = ()): """Выполнение запроса с результатом и автоматическим закрытием соединения.""" conn = None @@ -59,7 +60,7 @@ class DatabaseConnection: finally: if conn: await conn.close() - + async def _execute_transaction(self, queries: list): """Выполнение транзакции с несколькими запросами.""" conn = None @@ -76,7 +77,7 @@ class DatabaseConnection: finally: if conn: await conn.close() - + async def check_database_integrity(self): """Проверяет целостность базы данных и очищает WAL файлы.""" conn = None @@ -84,14 +85,16 @@ class DatabaseConnection: conn = await self._get_connection() result = await conn.execute("PRAGMA integrity_check") integrity_result = await result.fetchone() - + if integrity_result and integrity_result[0] == "ok": self.logger.info("Проверка целостности базы данных прошла успешно") await conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") self.logger.info("WAL файлы очищены") else: - self.logger.warning(f"Проблемы с целостностью базы данных: {integrity_result}") - + self.logger.warning( + f"Проблемы с целостностью базы данных: {integrity_result}" + ) + except Exception as e: self.logger.error(f"Ошибка при проверке целостности базы данных: {e}") raise diff --git a/database/models.py b/database/models.py index bfbb030..7e9c7bd 100644 --- a/database/models.py +++ b/database/models.py @@ -6,6 +6,7 @@ from typing import List, Optional @dataclass class User: """Модель пользователя.""" + user_id: int first_name: str full_name: str @@ -22,6 +23,7 @@ class User: @dataclass class BlacklistUser: """Модель пользователя в черном списке.""" + user_id: int message_for_user: Optional[str] = None date_to_unban: Optional[int] = None @@ -32,6 +34,7 @@ class BlacklistUser: @dataclass class BlacklistHistoryRecord: """Модель записи истории банов/разбанов.""" + user_id: int message_for_user: Optional[str] = None date_ban: int = 0 @@ -45,6 +48,7 @@ class BlacklistHistoryRecord: @dataclass class UserMessage: """Модель сообщения пользователя.""" + message_text: str user_id: int telegram_message_id: int @@ -54,6 +58,7 @@ class UserMessage: @dataclass class TelegramPost: """Модель поста из Telegram.""" + message_id: int text: str author_id: int @@ -66,6 +71,7 @@ class TelegramPost: @dataclass class PostContent: """Модель контента поста.""" + message_id: int content_name: str content_type: str @@ -74,6 +80,7 @@ class PostContent: @dataclass class MessageContentLink: """Модель связи сообщения с контентом.""" + post_id: int message_id: int @@ -81,6 +88,7 @@ class MessageContentLink: @dataclass class Admin: """Модель администратора.""" + user_id: int role: str = "admin" created_at: Optional[str] = None @@ -89,6 +97,7 @@ class Admin: @dataclass class Migration: """Модель миграции.""" + script_name: str applied_at: Optional[str] = None @@ -96,6 +105,7 @@ class Migration: @dataclass class AudioMessage: """Модель аудио сообщения.""" + file_name: str author_id: int date_added: str @@ -106,6 +116,7 @@ class AudioMessage: @dataclass class AudioListenRecord: """Модель записи прослушивания аудио.""" + file_name: str user_id: int is_listen: bool = False @@ -114,5 +125,6 @@ class AudioListenRecord: @dataclass class AudioModerate: """Модель для voice bot.""" + message_id: int user_id: int diff --git a/database/repositories/__init__.py b/database/repositories/__init__.py index 6b165d2..3b57f50 100644 --- a/database/repositories/__init__.py +++ b/database/repositories/__init__.py @@ -22,7 +22,12 @@ from .post_repository import PostRepository from .user_repository import UserRepository __all__ = [ - 'UserRepository', 'BlacklistRepository', 'BlacklistHistoryRepository', - 'MessageRepository', 'PostRepository', 'AdminRepository', 'AudioRepository', - 'MigrationRepository' + "UserRepository", + "BlacklistRepository", + "BlacklistHistoryRepository", + "MessageRepository", + "PostRepository", + "AdminRepository", + "AudioRepository", + "MigrationRepository", ] diff --git a/database/repositories/admin_repository.py b/database/repositories/admin_repository.py index b696ac7..b7de3f8 100644 --- a/database/repositories/admin_repository.py +++ b/database/repositories/admin_repository.py @@ -6,70 +6,68 @@ from database.models import Admin class AdminRepository(DatabaseConnection): """Репозиторий для работы с администраторами.""" - + async def create_tables(self): """Создание таблицы администраторов.""" # Включаем поддержку внешних ключей await self._execute_query("PRAGMA foreign_keys = ON") - - query = ''' + + query = """ CREATE TABLE IF NOT EXISTS admins ( user_id INTEGER NOT NULL PRIMARY KEY, role TEXT DEFAULT 'admin', created_at INTEGER DEFAULT (strftime('%s', 'now')), FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(query) self.logger.info("Таблица администраторов создана") - + async def add_admin(self, admin: Admin) -> None: """Добавление администратора.""" query = "INSERT INTO admins (user_id, role) VALUES (?, ?)" params = (admin.user_id, admin.role) - + await self._execute_query(query, params) - self.logger.info(f"Администратор добавлен: user_id={admin.user_id}, role={admin.role}") - + self.logger.info( + f"Администратор добавлен: user_id={admin.user_id}, role={admin.role}" + ) + async def remove_admin(self, user_id: int) -> None: """Удаление администратора.""" query = "DELETE FROM admins WHERE user_id = ?" await self._execute_query(query, (user_id,)) self.logger.info(f"Администратор удален: user_id={user_id}") - + async def is_admin(self, user_id: int) -> bool: """Проверка, является ли пользователь администратором.""" query = "SELECT 1 FROM admins WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None return bool(row) - + async def get_admin(self, user_id: int) -> Optional[Admin]: """Получение информации об администраторе.""" query = "SELECT user_id, role, created_at FROM admins WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: return Admin( - user_id=row[0], - role=row[1], - created_at=row[2] if len(row) > 2 else None + user_id=row[0], role=row[1], created_at=row[2] if len(row) > 2 else None ) return None - + async def get_all_admins(self) -> list[Admin]: """Получение всех администраторов.""" query = "SELECT user_id, role, created_at FROM admins ORDER BY created_at DESC" rows = await self._execute_query_with_result(query) - + admins = [] for row in rows: admin = Admin( - user_id=row[0], - role=row[1], - created_at=row[2] if len(row) > 2 else None + user_id=row[0], role=row[1], created_at=row[2] if len(row) > 2 else None ) admins.append(admin) - + return admins diff --git a/database/repositories/audio_repository.py b/database/repositories/audio_repository.py index 1c2a301..3d561fb 100644 --- a/database/repositories/audio_repository.py +++ b/database/repositories/audio_repository.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, List, Optional from database.base import DatabaseConnection @@ -7,15 +7,15 @@ from database.models import AudioListenRecord, AudioMessage, AudioModerate class AudioRepository(DatabaseConnection): """Репозиторий для работы с аудио сообщениями.""" - + async def enable_foreign_keys(self): """Включает поддержку внешних ключей.""" await self._execute_query("PRAGMA foreign_keys = ON;") - + async def create_tables(self): """Создание таблиц для аудио.""" # Таблица аудио сообщений - audio_query = ''' + audio_query = """ CREATE TABLE IF NOT EXISTS audio_message_reference ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, file_name TEXT NOT NULL UNIQUE, @@ -23,33 +23,33 @@ class AudioRepository(DatabaseConnection): date_added INTEGER NOT NULL, FOREIGN KEY (author_id) REFERENCES our_users (user_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(audio_query) - + # Таблица прослушивания аудио - listen_query = ''' + listen_query = """ CREATE TABLE IF NOT EXISTS user_audio_listens ( file_name TEXT NOT NULL, user_id INTEGER NOT NULL, PRIMARY KEY (file_name, user_id), FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(listen_query) - + # Таблица для voice bot - voice_query = ''' + voice_query = """ CREATE TABLE IF NOT EXISTS audio_moderate ( user_id INTEGER NOT NULL, message_id INTEGER, PRIMARY KEY (user_id, message_id), FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(voice_query) - + self.logger.info("Таблицы для аудио созданы") - + async def add_audio_record(self, audio: AudioMessage) -> None: """Добавляет информацию о войсе пользователя.""" query = """ @@ -63,13 +63,17 @@ class AudioRepository(DatabaseConnection): date_timestamp = int(audio.date_added.timestamp()) else: date_timestamp = audio.date_added - + params = (audio.file_name, audio.author_id, date_timestamp) - + await self._execute_query(query, params) - self.logger.info(f"Аудио добавлено: file_name={audio.file_name}, author_id={audio.author_id}") - - async def add_audio_record_simple(self, file_name: str, user_id: int, date_added) -> None: + self.logger.info( + f"Аудио добавлено: file_name={audio.file_name}, author_id={audio.author_id}" + ) + + async def add_audio_record_simple( + self, file_name: str, user_id: int, date_added + ) -> None: """Добавляет информацию о войсе пользователя (упрощенная версия).""" query = """ INSERT INTO audio_message_reference (file_name, author_id, date_added) @@ -82,30 +86,30 @@ class AudioRepository(DatabaseConnection): date_timestamp = int(date_added.timestamp()) else: date_timestamp = date_added - + params = (file_name, user_id, date_timestamp) - + await self._execute_query(query, params) self.logger.info(f"Аудио добавлено: file_name={file_name}, user_id={user_id}") - + async def get_last_date_audio(self) -> Optional[int]: """Получает дату последнего войса.""" query = "SELECT date_added FROM audio_message_reference ORDER BY date_added DESC LIMIT 1" rows = await self._execute_query_with_result(query) row = rows[0] if rows else None - + if row: self.logger.info(f"Последняя дата аудио: {row[0]}") return row[0] return None - + async def get_user_audio_records_count(self, user_id: int) -> int: """Получает количество записей пользователя.""" query = "SELECT COUNT(*) FROM audio_message_reference WHERE author_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None return row[0] if row else 0 - + async def get_path_for_audio_record(self, user_id: int) -> Optional[str]: """Получает название последнего файла пользователя.""" query = """ @@ -115,7 +119,7 @@ class AudioRepository(DatabaseConnection): rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None return row[0] if row else None - + async def check_listen_audio(self, user_id: int) -> List[str]: """Проверяет непрослушанные аудио для пользователя.""" query = """ @@ -125,115 +129,129 @@ class AudioRepository(DatabaseConnection): WHERE l.user_id = ? AND l.file_name IS NOT NULL """ listened_files = await self._execute_query_with_result(query, (user_id,)) - + # Получаем все аудио, кроме созданных пользователем - all_audio_query = 'SELECT file_name FROM audio_message_reference WHERE author_id <> ?' + all_audio_query = ( + "SELECT file_name FROM audio_message_reference WHERE author_id <> ?" + ) all_files = await self._execute_query_with_result(all_audio_query, (user_id,)) - + # Находим непрослушанные listened_set = {row[0] for row in listened_files} all_set = {row[0] for row in all_files} new_files = list(all_set - listened_set) - - self.logger.info(f"Найдено {len(new_files)} непрослушанных аудио для пользователя {user_id}") + + self.logger.info( + f"Найдено {len(new_files)} непрослушанных аудио для пользователя {user_id}" + ) return new_files - + async def mark_listened_audio(self, file_name: str, user_id: int) -> None: """Отмечает аудио прослушанным для пользователя.""" query = "INSERT OR IGNORE INTO user_audio_listens (file_name, user_id) VALUES (?, ?)" params = (file_name, user_id) - + await self._execute_query(query, params) - self.logger.info(f"Аудио {file_name} отмечено как прослушанное для пользователя {user_id}") - + self.logger.info( + f"Аудио {file_name} отмечено как прослушанное для пользователя {user_id}" + ) + async def get_user_id_by_file_name(self, file_name: str) -> Optional[int]: """Получает user_id пользователя по имени файла.""" query = "SELECT author_id FROM audio_message_reference WHERE file_name = ?" rows = await self._execute_query_with_result(query, (file_name,)) row = rows[0] if rows else None - + if row: user_id = row[0] self.logger.info(f"Получен user_id {user_id} для файла {file_name}") return user_id return None - + async def get_date_by_file_name(self, file_name: str) -> Optional[str]: """Получает дату добавления файла.""" query = "SELECT date_added FROM audio_message_reference WHERE file_name = ?" rows = await self._execute_query_with_result(query, (file_name,)) row = rows[0] if rows else None - + if row: date_added = row[0] - # Преобразуем UNIX timestamp в читаемую дату - readable_date = datetime.fromtimestamp(date_added).strftime('%d.%m.%Y %H:%M') + # Преобразуем UNIX timestamp в читаемую дату (UTC для одинакового результата везде) + readable_date = datetime.fromtimestamp( + date_added, tz=timezone.utc + ).strftime("%d.%m.%Y %H:%M") self.logger.info(f"Получена дата {readable_date} для файла {file_name}") return readable_date return None - + async def refresh_listen_audio(self, user_id: int) -> None: """Очищает всю информацию о прослушанных аудио пользователем.""" query = "DELETE FROM user_audio_listens WHERE user_id = ?" await self._execute_query(query, (user_id,)) self.logger.info(f"Очищены записи прослушивания для пользователя {user_id}") - + async def delete_listen_count_for_user(self, user_id: int) -> None: """Удаляет данные о прослушанных пользователем аудио.""" query = "DELETE FROM user_audio_listens WHERE user_id = ?" await self._execute_query(query, (user_id,)) self.logger.info(f"Удалены записи прослушивания для пользователя {user_id}") - + # Методы для voice bot - async def set_user_id_and_message_id_for_voice_bot(self, message_id: int, user_id: int) -> bool: + async def set_user_id_and_message_id_for_voice_bot( + self, message_id: int, user_id: int + ) -> bool: """Устанавливает связь между message_id и user_id для voice bot.""" try: query = "INSERT OR IGNORE INTO audio_moderate (user_id, message_id) VALUES (?, ?)" params = (user_id, message_id) - + await self._execute_query(query, params) - self.logger.info(f"Связь установлена: message_id={message_id}, user_id={user_id}") + self.logger.info( + f"Связь установлена: message_id={message_id}, user_id={user_id}" + ) return True except Exception as e: self.logger.error(f"Ошибка установки связи: {e}") return False - - async def get_user_id_by_message_id_for_voice_bot(self, message_id: int) -> Optional[int]: + + async def get_user_id_by_message_id_for_voice_bot( + self, message_id: int + ) -> Optional[int]: """Получает user_id пользователя по message_id для voice bot.""" query = "SELECT user_id FROM audio_moderate WHERE message_id = ?" rows = await self._execute_query_with_result(query, (message_id,)) row = rows[0] if rows else None - + if row: user_id = row[0] self.logger.info(f"Получен user_id {user_id} для message_id {message_id}") return user_id return None - + async def delete_audio_moderate_record(self, message_id: int) -> None: """Удаляет запись из таблицы audio_moderate по message_id.""" query = "DELETE FROM audio_moderate WHERE message_id = ?" await self._execute_query(query, (message_id,)) - self.logger.info(f"Удалена запись из audio_moderate для message_id {message_id}") - + self.logger.info( + f"Удалена запись из audio_moderate для message_id {message_id}" + ) + async def get_all_audio_records(self) -> List[Dict[str, Any]]: """Получить все записи аудио сообщений.""" query = "SELECT file_name, author_id, date_added FROM audio_message_reference" rows = await self._execute_query_with_result(query) - + records = [] for row in rows: - records.append({ - 'file_name': row[0], - 'author_id': row[1], - 'date_added': row[2] - }) - + records.append( + {"file_name": row[0], "author_id": row[1], "date_added": row[2]} + ) + self.logger.info(f"Получено {len(records)} записей аудио сообщений") return records - + async def delete_audio_record_by_file_name(self, file_name: str) -> None: """Удалить запись аудио сообщения по имени файла.""" query = "DELETE FROM audio_message_reference WHERE file_name = ?" await self._execute_query(query, (file_name,)) - self.logger.info(f"Удалена запись аудио сообщения: {file_name}") \ No newline at end of file + self.logger.info(f"Удалена запись аудио сообщения: {file_name}") diff --git a/database/repositories/blacklist_history_repository.py b/database/repositories/blacklist_history_repository.py index 14f95e8..54685ea 100644 --- a/database/repositories/blacklist_history_repository.py +++ b/database/repositories/blacklist_history_repository.py @@ -6,10 +6,10 @@ from database.models import BlacklistHistoryRecord class BlacklistHistoryRepository(DatabaseConnection): """Репозиторий для работы с историей банов/разбанов.""" - + async def create_tables(self): """Создание таблицы истории банов/разбанов.""" - query = ''' + query = """ CREATE TABLE IF NOT EXISTS blacklist_history ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, @@ -22,9 +22,9 @@ class BlacklistHistoryRepository(DatabaseConnection): FOREIGN KEY (user_id) REFERENCES our_users(user_id) ON DELETE CASCADE, FOREIGN KEY (ban_author) REFERENCES our_users(user_id) ON DELETE SET NULL ) - ''' + """ await self._execute_query(query) - + # Создаем индексы await self._execute_query( "CREATE INDEX IF NOT EXISTS idx_blacklist_history_user_id ON blacklist_history(user_id)" @@ -35,9 +35,9 @@ class BlacklistHistoryRepository(DatabaseConnection): await self._execute_query( "CREATE INDEX IF NOT EXISTS idx_blacklist_history_date_unban ON blacklist_history(date_unban)" ) - + self.logger.info("Таблица истории банов/разбанов создана") - + async def add_record_on_ban(self, record: BlacklistHistoryRecord) -> None: """Добавляет запись о бане в историю.""" query = """ @@ -48,8 +48,9 @@ class BlacklistHistoryRepository(DatabaseConnection): """ # Используем текущее время, если не указано from datetime import datetime + current_timestamp = int(datetime.now().timestamp()) - + params = ( record.user_id, record.message_for_user, @@ -59,28 +60,29 @@ class BlacklistHistoryRepository(DatabaseConnection): record.created_at if record.created_at is not None else current_timestamp, record.updated_at if record.updated_at is not None else current_timestamp, ) - + await self._execute_query(query, params) self.logger.info( f"Запись о бане добавлена в историю: user_id={record.user_id}, " f"date_ban={record.date_ban}" ) - + async def set_unban_date(self, user_id: int, date_unban: int) -> bool: """ Обновляет date_unban и updated_at в последней записи (date_unban IS NULL) для пользователя. - + Args: user_id: ID пользователя date_unban: Timestamp даты разбана - + Returns: True если запись обновлена, False если не найдена открытая запись """ try: from datetime import datetime + current_timestamp = int(datetime.now().timestamp()) - + # SQLite не поддерживает ORDER BY в UPDATE, поэтому используем подзапрос # Сначала проверяем, есть ли открытая запись check_query = """ @@ -90,13 +92,13 @@ class BlacklistHistoryRepository(DatabaseConnection): LIMIT 1 """ rows = await self._execute_query_with_result(check_query, (user_id,)) - + if not rows: self.logger.warning( f"Не найдена открытая запись в истории для обновления: user_id={user_id}" ) return False - + # Обновляем найденную запись update_query = """ UPDATE blacklist_history @@ -104,11 +106,11 @@ class BlacklistHistoryRepository(DatabaseConnection): updated_at = ? WHERE id = ? """ - + record_id = rows[0][0] params = (date_unban, current_timestamp, record_id) await self._execute_query(update_query, params) - + self.logger.info( f"Дата разбана обновлена в истории: user_id={user_id}, date_unban={date_unban}" ) diff --git a/database/repositories/blacklist_repository.py b/database/repositories/blacklist_repository.py index 6559645..f8d275e 100644 --- a/database/repositories/blacklist_repository.py +++ b/database/repositories/blacklist_repository.py @@ -6,10 +6,10 @@ from database.models import BlacklistUser class BlacklistRepository(DatabaseConnection): """Репозиторий для работы с черным списком.""" - + async def create_tables(self): """Создание таблицы черного списка.""" - query = ''' + query = """ CREATE TABLE IF NOT EXISTS blacklist ( user_id INTEGER NOT NULL PRIMARY KEY, message_for_user TEXT, @@ -19,10 +19,10 @@ class BlacklistRepository(DatabaseConnection): FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE, FOREIGN KEY (ban_author) REFERENCES our_users (user_id) ON DELETE SET NULL ) - ''' + """ await self._execute_query(query) self.logger.info("Таблица черного списка создана") - + async def add_user(self, blacklist_user: BlacklistUser) -> None: """Добавляет пользователя в черный список.""" query = """ @@ -35,29 +35,35 @@ class BlacklistRepository(DatabaseConnection): blacklist_user.date_to_unban, blacklist_user.ban_author, ) - + await self._execute_query(query, params) - self.logger.info(f"Пользователь добавлен в черный список: user_id={blacklist_user.user_id}") - + self.logger.info( + f"Пользователь добавлен в черный список: user_id={blacklist_user.user_id}" + ) + async def remove_user(self, user_id: int) -> bool: """Удаляет пользователя из черного списка.""" try: query = "DELETE FROM blacklist WHERE user_id = ?" await self._execute_query(query, (user_id,)) - self.logger.info(f"Пользователь с идентификатором {user_id} успешно удален из черного списка.") + self.logger.info( + f"Пользователь с идентификатором {user_id} успешно удален из черного списка." + ) return True except Exception as e: - self.logger.error(f"Ошибка удаления пользователя с идентификатором {user_id} " - f"из таблицы blacklist. Ошибка: {str(e)}") + self.logger.error( + f"Ошибка удаления пользователя с идентификатором {user_id} " + f"из таблицы blacklist. Ошибка: {str(e)}" + ) return False - + async def user_exists(self, user_id: int) -> bool: """Проверяет, существует ли запись с данным user_id в blacklist.""" query = "SELECT 1 FROM blacklist WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) self.logger.info(f"Существует ли пользователь: user_id={user_id} Итог: {rows}") return bool(rows) - + async def get_user(self, user_id: int) -> Optional[BlacklistUser]: """Возвращает информацию о пользователе в черном списке по user_id.""" query = """ @@ -67,7 +73,7 @@ class BlacklistRepository(DatabaseConnection): """ rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: return BlacklistUser( user_id=row[0], @@ -77,8 +83,10 @@ class BlacklistRepository(DatabaseConnection): ban_author=row[4] if len(row) > 4 else None, ) return None - - async def get_all_users(self, offset: int = 0, limit: int = 10) -> List[BlacklistUser]: + + async def get_all_users( + self, offset: int = 0, limit: int = 10 + ) -> List[BlacklistUser]: """Возвращает список пользователей в черном списке.""" query = """ SELECT user_id, message_for_user, date_to_unban, created_at, ban_author @@ -86,7 +94,7 @@ class BlacklistRepository(DatabaseConnection): LIMIT ?, ? """ rows = await self._execute_query_with_result(query, (offset, limit)) - + users = [] for row in rows: users.append( @@ -98,10 +106,12 @@ class BlacklistRepository(DatabaseConnection): ban_author=row[4] if len(row) > 4 else None, ) ) - - self.logger.info(f"Получен список пользователей в черном списке (offset={offset}, limit={limit}): {len(users)}") + + self.logger.info( + f"Получен список пользователей в черном списке (offset={offset}, limit={limit}): {len(users)}" + ) return users - + async def get_all_users_no_limit(self) -> List[BlacklistUser]: """Возвращает список всех пользователей в черном списке без лимитов.""" query = """ @@ -109,7 +119,7 @@ class BlacklistRepository(DatabaseConnection): FROM blacklist """ rows = await self._execute_query_with_result(query) - + users = [] for row in rows: users.append( @@ -121,19 +131,23 @@ class BlacklistRepository(DatabaseConnection): ban_author=row[4] if len(row) > 4 else None, ) ) - - self.logger.info(f"Получен список всех пользователей в черном списке: {len(users)}") + + self.logger.info( + f"Получен список всех пользователей в черном списке: {len(users)}" + ) return users - - async def get_users_for_unblock_today(self, current_timestamp: int) -> Dict[int, int]: + + async def get_users_for_unblock_today( + self, current_timestamp: int + ) -> Dict[int, int]: """Возвращает список пользователей, у которых истек срок блокировки.""" query = "SELECT user_id FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban <= ?" rows = await self._execute_query_with_result(query, (current_timestamp,)) - + users = {user_id: user_id for user_id, in rows} self.logger.info(f"Получен список пользователей для разблокировки: {users}") return users - + async def get_count(self) -> int: """Получение количества пользователей в черном списке.""" query = "SELECT COUNT(*) FROM blacklist" diff --git a/database/repositories/message_repository.py b/database/repositories/message_repository.py index d52a6c4..d24934f 100644 --- a/database/repositories/message_repository.py +++ b/database/repositories/message_repository.py @@ -7,10 +7,10 @@ from database.models import UserMessage class MessageRepository(DatabaseConnection): """Репозиторий для работы с сообщениями пользователей.""" - + async def create_tables(self): """Создание таблицы сообщений пользователей.""" - query = ''' + query = """ CREATE TABLE IF NOT EXISTS user_messages ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, message_text TEXT, @@ -19,24 +19,31 @@ class MessageRepository(DatabaseConnection): date INTEGER NOT NULL, FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(query) self.logger.info("Таблица сообщений пользователей создана") - + async def add_message(self, message: UserMessage) -> None: """Добавление сообщения пользователя.""" if message.date is None: message.date = int(datetime.now().timestamp()) - + query = """ INSERT INTO user_messages (message_text, user_id, telegram_message_id, date) VALUES (?, ?, ?, ?) """ - params = (message.message_text, message.user_id, message.telegram_message_id, message.date) - + params = ( + message.message_text, + message.user_id, + message.telegram_message_id, + message.date, + ) + await self._execute_query(query, params) - self.logger.info(f"Новое сообщение добавлено: telegram_message_id={message.telegram_message_id}") - + self.logger.info( + f"Новое сообщение добавлено: telegram_message_id={message.telegram_message_id}" + ) + async def get_user_by_message_id(self, message_id: int) -> Optional[int]: """Получение пользователя по message_id.""" query = "SELECT user_id FROM user_messages WHERE telegram_message_id = ?" diff --git a/database/repositories/migration_repository.py b/database/repositories/migration_repository.py index 6dcb67d..8c7a02b 100644 --- a/database/repositories/migration_repository.py +++ b/database/repositories/migration_repository.py @@ -1,11 +1,13 @@ """Репозиторий для работы с миграциями базы данных.""" + import aiosqlite + from database.base import DatabaseConnection class MigrationRepository(DatabaseConnection): """Репозиторий для управления миграциями базы данных.""" - + async def create_table(self): """Создает таблицу migrations, если она не существует.""" query = """ @@ -17,13 +19,15 @@ class MigrationRepository(DatabaseConnection): """ await self._execute_query(query) self.logger.info("Таблица migrations создана или уже существует") - + async def get_applied_migrations(self) -> list[str]: """Возвращает список имен примененных скриптов миграций.""" conn = None try: conn = await self._get_connection() - cursor = await conn.execute("SELECT script_name FROM migrations ORDER BY applied_at") + cursor = await conn.execute( + "SELECT script_name FROM migrations ORDER BY applied_at" + ) rows = await cursor.fetchall() await cursor.close() return [row[0] for row in rows] @@ -33,15 +37,14 @@ class MigrationRepository(DatabaseConnection): finally: if conn: await conn.close() - + async def is_migration_applied(self, script_name: str) -> bool: """Проверяет, применена ли миграция.""" conn = None try: conn = await self._get_connection() cursor = await conn.execute( - "SELECT COUNT(*) FROM migrations WHERE script_name = ?", - (script_name,) + "SELECT COUNT(*) FROM migrations WHERE script_name = ?", (script_name,) ) row = await cursor.fetchone() await cursor.close() @@ -52,15 +55,14 @@ class MigrationRepository(DatabaseConnection): finally: if conn: await conn.close() - + async def mark_migration_applied(self, script_name: str) -> None: """Отмечает миграцию как примененную.""" conn = None try: conn = await self._get_connection() await conn.execute( - "INSERT INTO migrations (script_name) VALUES (?)", - (script_name,) + "INSERT INTO migrations (script_name) VALUES (?)", (script_name,) ) await conn.commit() self.logger.info(f"Миграция {script_name} отмечена как примененная") @@ -72,7 +74,7 @@ class MigrationRepository(DatabaseConnection): finally: if conn: await conn.close() - + async def create_table_from_sql(self, sql_script: str) -> None: """Создает таблицу из SQL скрипта. Используется в миграциях.""" await self._execute_query(sql_script) diff --git a/database/repositories/post_repository.py b/database/repositories/post_repository.py index e819cb6..37cdea0 100644 --- a/database/repositories/post_repository.py +++ b/database/repositories/post_repository.py @@ -7,11 +7,11 @@ from database.models import MessageContentLink, PostContent, TelegramPost class PostRepository(DatabaseConnection): """Репозиторий для работы с постами из Telegram.""" - + async def create_tables(self): """Создание таблиц для постов.""" # Таблица постов из Telegram - post_query = ''' + post_query = """ CREATE TABLE IF NOT EXISTS post_from_telegram_suggest ( message_id INTEGER NOT NULL PRIMARY KEY, text TEXT, @@ -23,9 +23,9 @@ class PostRepository(DatabaseConnection): published_message_id INTEGER, FOREIGN KEY (author_id) REFERENCES our_users (user_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(post_query) - + # Добавляем поле published_message_id если его нет (для существующих БД) try: check_column_query = """ @@ -34,19 +34,27 @@ class PostRepository(DatabaseConnection): """ existing_columns = await self._execute_query_with_result(check_column_query) if not existing_columns: - await self._execute_query('ALTER TABLE post_from_telegram_suggest ADD COLUMN published_message_id INTEGER') - self.logger.info("Столбец published_message_id добавлен в post_from_telegram_suggest") + await self._execute_query( + "ALTER TABLE post_from_telegram_suggest ADD COLUMN published_message_id INTEGER" + ) + self.logger.info( + "Столбец published_message_id добавлен в post_from_telegram_suggest" + ) except Exception as e: # Если проверка не удалась, пытаемся добавить столбец (может быть уже существует) try: - await self._execute_query('ALTER TABLE post_from_telegram_suggest ADD COLUMN published_message_id INTEGER') - self.logger.info("Столбец published_message_id добавлен в post_from_telegram_suggest (fallback)") + await self._execute_query( + "ALTER TABLE post_from_telegram_suggest ADD COLUMN published_message_id INTEGER" + ) + self.logger.info( + "Столбец published_message_id добавлен в post_from_telegram_suggest (fallback)" + ) except Exception: # Столбец уже существует, игнорируем ошибку pass - + # Таблица контента постов - content_query = ''' + content_query = """ CREATE TABLE IF NOT EXISTS content_post_from_telegram ( message_id INTEGER NOT NULL, content_name TEXT NOT NULL, @@ -54,22 +62,22 @@ class PostRepository(DatabaseConnection): PRIMARY KEY (message_id, content_name), FOREIGN KEY (message_id) REFERENCES post_from_telegram_suggest (message_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(content_query) - + # Таблица связи сообщений с контентом - link_query = ''' + link_query = """ CREATE TABLE IF NOT EXISTS message_link_to_content ( post_id INTEGER NOT NULL, message_id INTEGER NOT NULL, PRIMARY KEY (post_id, message_id), FOREIGN KEY (post_id) REFERENCES post_from_telegram_suggest (message_id) ON DELETE CASCADE ) - ''' + """ await self._execute_query(link_query) - + # Таблица контента опубликованных постов - published_content_query = ''' + published_content_query = """ CREATE TABLE IF NOT EXISTS published_post_content ( published_message_id INTEGER NOT NULL, content_name TEXT NOT NULL, @@ -77,38 +85,55 @@ class PostRepository(DatabaseConnection): published_at INTEGER NOT NULL, PRIMARY KEY (published_message_id, content_name) ) - ''' + """ await self._execute_query(published_content_query) - + # Создаем индексы try: - await self._execute_query('CREATE INDEX IF NOT EXISTS idx_published_post_content_message_id ON published_post_content(published_message_id)') - await self._execute_query('CREATE INDEX IF NOT EXISTS idx_post_from_telegram_suggest_published ON post_from_telegram_suggest(published_message_id)') + await self._execute_query( + "CREATE INDEX IF NOT EXISTS idx_published_post_content_message_id ON published_post_content(published_message_id)" + ) + await self._execute_query( + "CREATE INDEX IF NOT EXISTS idx_post_from_telegram_suggest_published ON post_from_telegram_suggest(published_message_id)" + ) except Exception: # Индексы уже существуют, игнорируем ошибку pass - + self.logger.info("Таблицы для постов созданы") - + async def add_post(self, post: TelegramPost) -> None: """Добавление поста.""" if not post.created_at: post.created_at = int(datetime.now().timestamp()) status = post.status if post.status else "suggest" # Преобразуем bool в int для SQLite (True -> 1, False -> 0, None -> None) - is_anonymous_int = None if post.is_anonymous is None else (1 if post.is_anonymous else 0) + is_anonymous_int = ( + None if post.is_anonymous is None else (1 if post.is_anonymous else 0) + ) # Используем INSERT OR IGNORE чтобы избежать ошибок при повторном создании query = """ INSERT OR IGNORE INTO post_from_telegram_suggest (message_id, text, author_id, created_at, status, is_anonymous) VALUES (?, ?, ?, ?, ?, ?) """ - params = (post.message_id, post.text, post.author_id, post.created_at, status, is_anonymous_int) + params = ( + post.message_id, + post.text, + post.author_id, + post.created_at, + status, + is_anonymous_int, + ) await self._execute_query(query, params) - self.logger.info(f"Пост добавлен (или уже существует): message_id={post.message_id}, text длина={len(post.text) if post.text else 0}, is_anonymous={is_anonymous_int}") - - async def update_helper_message(self, message_id: int, helper_message_id: int) -> None: + self.logger.info( + f"Пост добавлен (или уже существует): message_id={post.message_id}, text длина={len(post.text) if post.text else 0}, is_anonymous={is_anonymous_int}" + ) + + async def update_helper_message( + self, message_id: int, helper_message_id: int + ) -> None: """Обновление helper сообщения.""" query = "UPDATE post_from_telegram_suggest SET helper_text_message_id = ? WHERE message_id = ?" await self._execute_query(query, (helper_message_id, message_id)) @@ -131,12 +156,16 @@ class PostRepository(DatabaseConnection): f"update_status_by_message_id: 0 строк обновлено для message_id={message_id}, status={status}" ) else: - self.logger.info(f"Статус поста message_id={message_id} обновлён на {status}") + self.logger.info( + f"Статус поста message_id={message_id} обновлён на {status}" + ) return n except Exception as e: if conn: await conn.rollback() - self.logger.error(f"Ошибка при обновлении статуса message_id={message_id}: {e}") + self.logger.error( + f"Ошибка при обновлении статуса message_id={message_id}: {e}" + ) raise finally: if conn: @@ -182,39 +211,53 @@ class PostRepository(DatabaseConnection): if conn: await conn.close() - async def add_post_content(self, post_id: int, message_id: int, content_name: str, content_type: str) -> bool: + async def add_post_content( + self, post_id: int, message_id: int, content_name: str, content_type: str + ) -> bool: """Добавление контента поста.""" try: # Сначала добавляем связь link_query = "INSERT OR IGNORE INTO message_link_to_content (post_id, message_id) VALUES (?, ?)" await self._execute_query(link_query, (post_id, message_id)) - + # Затем добавляем контент content_query = """ INSERT OR IGNORE INTO content_post_from_telegram (message_id, content_name, content_type) VALUES (?, ?, ?) """ - await self._execute_query(content_query, (message_id, content_name, content_type)) - - self.logger.info(f"Контент поста добавлен: post_id={post_id}, message_id={message_id}") + await self._execute_query( + content_query, (message_id, content_name, content_type) + ) + + self.logger.info( + f"Контент поста добавлен: post_id={post_id}, message_id={message_id}" + ) return True except Exception as e: self.logger.error(f"Ошибка при добавлении контента поста: {e}") return False - + async def add_message_link(self, post_id: int, message_id: int) -> bool: """Добавляет связь между post_id и message_id в таблицу message_link_to_content.""" try: - self.logger.info(f"Добавление связи: post_id={post_id}, message_id={message_id}") + self.logger.info( + f"Добавление связи: post_id={post_id}, message_id={message_id}" + ) link_query = "INSERT OR IGNORE INTO message_link_to_content (post_id, message_id) VALUES (?, ?)" await self._execute_query(link_query, (post_id, message_id)) - self.logger.info(f"Связь успешно добавлена: post_id={post_id}, message_id={message_id}") + self.logger.info( + f"Связь успешно добавлена: post_id={post_id}, message_id={message_id}" + ) return True except Exception as e: - self.logger.error(f"Ошибка при добавлении связи post_id={post_id}, message_id={message_id}: {e}") + self.logger.error( + f"Ошибка при добавлении связи post_id={post_id}, message_id={message_id}: {e}" + ) return False - - async def get_post_content_by_helper_id(self, helper_message_id: int) -> List[Tuple[str, str]]: + + async def get_post_content_by_helper_id( + self, helper_message_id: int + ) -> List[Tuple[str, str]]: """Получает контент поста по helper_text_message_id.""" query = """ SELECT cpft.content_name, cpft.content_type @@ -223,12 +266,16 @@ class PostRepository(DatabaseConnection): JOIN content_post_from_telegram cpft ON cpft.message_id = mltc.message_id WHERE pft.helper_text_message_id = ? """ - post_content = await self._execute_query_with_result(query, (helper_message_id,)) - + post_content = await self._execute_query_with_result( + query, (helper_message_id,) + ) + self.logger.info(f"Получен контент поста: {len(post_content)} элементов") return post_content - - async def get_post_content_by_message_id(self, message_id: int) -> List[Tuple[str, str]]: + + async def get_post_content_by_message_id( + self, message_id: int + ) -> List[Tuple[str, str]]: """Получает контент одиночного поста по message_id.""" query = """ SELECT cpft.content_name, cpft.content_type @@ -238,21 +285,25 @@ class PostRepository(DatabaseConnection): WHERE pft.message_id = ? AND pft.helper_text_message_id IS NULL """ post_content = await self._execute_query_with_result(query, (message_id,)) - - self.logger.info(f"Получен контент одиночного поста: {len(post_content)} элементов для message_id={message_id}") + + self.logger.info( + f"Получен контент одиночного поста: {len(post_content)} элементов для message_id={message_id}" + ) return post_content - + async def get_post_text_by_helper_id(self, helper_message_id: int) -> Optional[str]: """Получает текст поста по helper_text_message_id.""" query = "SELECT text FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" rows = await self._execute_query_with_result(query, (helper_message_id,)) row = rows[0] if rows else None - + if row: - self.logger.info(f"Получен текст поста для helper_message_id={helper_message_id}") + self.logger.info( + f"Получен текст поста для helper_message_id={helper_message_id}" + ) return row[0] return None - + async def get_post_ids_by_helper_id(self, helper_message_id: int) -> List[int]: """Получает ID сообщений по helper_text_message_id.""" query = """ @@ -262,114 +313,145 @@ class PostRepository(DatabaseConnection): WHERE pft.helper_text_message_id = ? """ rows = await self._execute_query_with_result(query, (helper_message_id,)) - + post_ids = [row[0] for row in rows] self.logger.info(f"Получены ID сообщений: {len(post_ids)} элементов") return post_ids - + async def get_author_id_by_message_id(self, message_id: int) -> Optional[int]: """Получает ID автора по message_id.""" query = "SELECT author_id FROM post_from_telegram_suggest WHERE message_id = ?" rows = await self._execute_query_with_result(query, (message_id,)) row = rows[0] if rows else None - + if row: author_id = row[0] - self.logger.info(f"Получен author_id: {author_id} для message_id={message_id}") + self.logger.info( + f"Получен author_id: {author_id} для message_id={message_id}" + ) return author_id return None - - async def get_author_id_by_helper_message_id(self, helper_message_id: int) -> Optional[int]: + + async def get_author_id_by_helper_message_id( + self, helper_message_id: int + ) -> Optional[int]: """Получает ID автора по helper_text_message_id.""" query = "SELECT author_id FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" rows = await self._execute_query_with_result(query, (helper_message_id,)) row = rows[0] if rows else None - + if row: author_id = row[0] - self.logger.info(f"Получен author_id: {author_id} для helper_message_id={helper_message_id}") + self.logger.info( + f"Получен author_id: {author_id} для helper_message_id={helper_message_id}" + ) return author_id return None - - async def get_post_text_and_anonymity_by_message_id(self, message_id: int) -> Tuple[Optional[str], Optional[bool]]: + + async def get_post_text_and_anonymity_by_message_id( + self, message_id: int + ) -> Tuple[Optional[str], Optional[bool]]: """Получает текст и is_anonymous поста по message_id.""" query = "SELECT text, is_anonymous FROM post_from_telegram_suggest WHERE message_id = ?" rows = await self._execute_query_with_result(query, (message_id,)) row = rows[0] if rows else None - + if row: text = row[0] or "" is_anonymous_int = row[1] # Преобразуем int в bool (1 -> True, 0 -> False, NULL -> None) is_anonymous = None if is_anonymous_int is None else bool(is_anonymous_int) - self.logger.info(f"Получены текст и is_anonymous для message_id={message_id}") + self.logger.info( + f"Получены текст и is_anonymous для message_id={message_id}" + ) return text, is_anonymous return None, None - - async def get_post_text_and_anonymity_by_helper_id(self, helper_message_id: int) -> Tuple[Optional[str], Optional[bool]]: + + async def get_post_text_and_anonymity_by_helper_id( + self, helper_message_id: int + ) -> Tuple[Optional[str], Optional[bool]]: """Получает текст и is_anonymous поста по helper_text_message_id.""" query = "SELECT text, is_anonymous FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" rows = await self._execute_query_with_result(query, (helper_message_id,)) row = rows[0] if rows else None - + if row: text = row[0] or "" is_anonymous_int = row[1] # Преобразуем int в bool (1 -> True, 0 -> False, NULL -> None) is_anonymous = None if is_anonymous_int is None else bool(is_anonymous_int) - self.logger.info(f"Получены текст и is_anonymous для helper_message_id={helper_message_id}") + self.logger.info( + f"Получены текст и is_anonymous для helper_message_id={helper_message_id}" + ) return text, is_anonymous return None, None - - async def update_published_message_id(self, original_message_id: int, published_message_id: int) -> None: + + async def update_published_message_id( + self, original_message_id: int, published_message_id: int + ) -> None: """Обновляет published_message_id для опубликованного поста.""" query = "UPDATE post_from_telegram_suggest SET published_message_id = ? WHERE message_id = ?" await self._execute_query(query, (published_message_id, original_message_id)) - self.logger.info(f"Обновлен published_message_id: {original_message_id} -> {published_message_id}") - + self.logger.info( + f"Обновлен published_message_id: {original_message_id} -> {published_message_id}" + ) + async def add_published_post_content( self, published_message_id: int, content_path: str, content_type: str ) -> bool: """Добавляет контент опубликованного поста.""" try: from datetime import datetime + published_at = int(datetime.now().timestamp()) - + query = """ INSERT OR IGNORE INTO published_post_content (published_message_id, content_name, content_type, published_at) VALUES (?, ?, ?, ?) """ - await self._execute_query(query, (published_message_id, content_path, content_type, published_at)) - self.logger.info(f"Добавлен контент опубликованного поста: published_message_id={published_message_id}, type={content_type}") + await self._execute_query( + query, (published_message_id, content_path, content_type, published_at) + ) + self.logger.info( + f"Добавлен контент опубликованного поста: published_message_id={published_message_id}, type={content_type}" + ) return True except Exception as e: - self.logger.error(f"Ошибка при добавлении контента опубликованного поста: {e}") + self.logger.error( + f"Ошибка при добавлении контента опубликованного поста: {e}" + ) return False - - async def get_published_post_content(self, published_message_id: int) -> List[Tuple[str, str]]: + + async def get_published_post_content( + self, published_message_id: int + ) -> List[Tuple[str, str]]: """Получает контент опубликованного поста.""" query = """ SELECT content_name, content_type FROM published_post_content WHERE published_message_id = ? """ - post_content = await self._execute_query_with_result(query, (published_message_id,)) - self.logger.info(f"Получен контент опубликованного поста: {len(post_content)} элементов для published_message_id={published_message_id}") + post_content = await self._execute_query_with_result( + query, (published_message_id,) + ) + self.logger.info( + f"Получен контент опубликованного поста: {len(post_content)} элементов для published_message_id={published_message_id}" + ) return post_content - + # ============================================ # Методы для работы с ML-скорингом # ============================================ - + async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool: """ Обновляет ML-скоры для поста. - + Args: message_id: ID сообщения в группе модерации ml_scores_json: JSON строка со скорами - + Returns: True если обновлено успешно """ @@ -379,16 +461,18 @@ class PostRepository(DatabaseConnection): self.logger.info(f"ML-скоры обновлены для message_id={message_id}") return True except Exception as e: - self.logger.error(f"Ошибка обновления ML-скоров для message_id={message_id}: {e}") + self.logger.error( + f"Ошибка обновления ML-скоров для message_id={message_id}: {e}" + ) return False - + async def get_ml_scores_by_message_id(self, message_id: int) -> Optional[str]: """ Получает ML-скоры для поста. - + Args: message_id: ID сообщения - + Returns: JSON строка со скорами или None """ @@ -397,14 +481,14 @@ class PostRepository(DatabaseConnection): if rows and rows[0][0]: return rows[0][0] return None - + async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]: """ Получает текст поста по message_id. - + Args: message_id: ID сообщения - + Returns: Текст поста или None """ @@ -413,14 +497,14 @@ class PostRepository(DatabaseConnection): if rows and rows[0][0]: return rows[0][0] return None - + async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]: """ Получает тексты опубликованных постов для обучения RAG. - + Args: limit: Максимальное количество постов - + Returns: Список текстов """ @@ -437,14 +521,14 @@ class PostRepository(DatabaseConnection): texts = [row[0] for row in rows if row[0]] self.logger.info(f"Получено {len(texts)} опубликованных постов для обучения") return texts - + async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]: """ Получает тексты отклоненных постов для обучения RAG. - + Args: limit: Максимальное количество постов - + Returns: Список текстов """ @@ -461,4 +545,3 @@ class PostRepository(DatabaseConnection): texts = [row[0] for row in rows if row[0]] self.logger.info(f"Получено {len(texts)} отклоненных постов для обучения") return texts - diff --git a/database/repositories/user_repository.py b/database/repositories/user_repository.py index b87ee02..1e6cfd4 100644 --- a/database/repositories/user_repository.py +++ b/database/repositories/user_repository.py @@ -7,10 +7,10 @@ from database.models import User class UserRepository(DatabaseConnection): """Репозиторий для работы с пользователями.""" - + async def create_tables(self): """Создание таблицы пользователей.""" - query = ''' + query = """ CREATE TABLE IF NOT EXISTS our_users ( user_id INTEGER NOT NULL PRIMARY KEY, first_name TEXT, @@ -24,42 +24,56 @@ class UserRepository(DatabaseConnection): date_changed INTEGER NOT NULL, voice_bot_welcome_received BOOLEAN DEFAULT 0 ) - ''' + """ await self._execute_query(query) self.logger.info("Таблица пользователей создана") - + async def user_exists(self, user_id: int) -> bool: """Проверяет, существует ли пользователь в базе данных.""" query = "SELECT user_id FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) - self.logger.info(f"Проверка существования пользователя: user_id={user_id}, результат={rows}") + self.logger.info( + f"Проверка существования пользователя: user_id={user_id}, результат={rows}" + ) return bool(len(rows)) - + async def add_user(self, user: User) -> None: """Добавление нового пользователя с защитой от дублирования.""" if not user.date_added: user.date_added = int(datetime.now().timestamp()) if not user.date_changed: user.date_changed = int(datetime.now().timestamp()) - + query = """ INSERT OR IGNORE INTO our_users (user_id, first_name, full_name, username, is_bot, language_code, emoji, has_stickers, date_added, date_changed, voice_bot_welcome_received) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ - params = (user.user_id, user.first_name, user.full_name, user.username, - user.is_bot, user.language_code, user.emoji, user.has_stickers, - user.date_added, user.date_changed, user.voice_bot_welcome_received) - + params = ( + user.user_id, + user.first_name, + user.full_name, + user.username, + user.is_bot, + user.language_code, + user.emoji, + user.has_stickers, + user.date_added, + user.date_changed, + user.voice_bot_welcome_received, + ) + await self._execute_query(query, params) - self.logger.info(f"Пользователь обработан (создан или уже существует): {user.user_id}") - + self.logger.info( + f"Пользователь обработан (создан или уже существует): {user.user_id}" + ) + async def get_user_info(self, user_id: int) -> Optional[User]: """Получение информации о пользователе.""" query = "SELECT username, full_name, has_stickers, emoji FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: return User( user_id=user_id, @@ -67,16 +81,16 @@ class UserRepository(DatabaseConnection): full_name=row[1], username=row[0], has_stickers=bool(row[2]) if row[2] is not None else False, - emoji=row[3] + emoji=row[3], ) return None - + async def get_user_by_id(self, user_id: int) -> Optional[User]: """Получение пользователя по ID.""" query = "SELECT * FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: return User( user_id=row[0], @@ -89,58 +103,66 @@ class UserRepository(DatabaseConnection): emoji=row[7], date_added=row[8], date_changed=row[9], - voice_bot_welcome_received=bool(row[10]) if len(row) > 10 else False + voice_bot_welcome_received=bool(row[10]) if len(row) > 10 else False, ) return None - + async def get_username(self, user_id: int) -> Optional[str]: """Возвращает username пользователя.""" query = "SELECT username FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: username = row[0] - self.logger.info(f"Username пользователя найден: user_id={user_id}, username={username}") + self.logger.info( + f"Username пользователя найден: user_id={user_id}, username={username}" + ) return username return None - + async def get_user_id_by_username(self, username: str) -> Optional[int]: """Возвращает user_id пользователя по username.""" query = "SELECT user_id FROM our_users WHERE username = ?" rows = await self._execute_query_with_result(query, (username,)) row = rows[0] if rows else None - + if row: user_id = row[0] - self.logger.info(f"User_id пользователя найден: username={username}, user_id={user_id}") + self.logger.info( + f"User_id пользователя найден: username={username}, user_id={user_id}" + ) return user_id return None - + async def get_full_name_by_id(self, user_id: int) -> Optional[str]: """Возвращает full_name пользователя.""" query = "SELECT full_name FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: full_name = row[0] - self.logger.info(f"Full_name пользователя найден: user_id={user_id}, full_name={full_name}") + self.logger.info( + f"Full_name пользователя найден: user_id={user_id}, full_name={full_name}" + ) return full_name return None - + async def get_user_first_name(self, user_id: int) -> Optional[str]: """Возвращает first_name пользователя.""" query = "SELECT first_name FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: first_name = row[0] - self.logger.info(f"First_name пользователя найден: user_id={user_id}, first_name={first_name}") + self.logger.info( + f"First_name пользователя найден: user_id={user_id}, first_name={first_name}" + ) return first_name return None - + async def get_all_user_ids(self) -> List[int]: """Возвращает список всех user_id.""" query = "SELECT user_id FROM our_users" @@ -148,20 +170,22 @@ class UserRepository(DatabaseConnection): user_ids = [row[0] for row in rows] self.logger.info(f"Получен список всех user_id: {user_ids}") return user_ids - + async def get_last_users(self, limit: int = 30) -> List[tuple]: """Получение последних пользователей.""" query = "SELECT full_name, user_id FROM our_users ORDER BY date_changed DESC LIMIT ?" rows = await self._execute_query_with_result(query, (limit,)) return rows - + async def update_user_date(self, user_id: int) -> None: """Обновление даты последнего изменения пользователя.""" date_changed = int(datetime.now().timestamp()) query = "UPDATE our_users SET date_changed = ? WHERE user_id = ?" await self._execute_query(query, (date_changed, user_id)) - - async def update_user_info(self, user_id: int, username: str = None, full_name: str = None) -> None: + + async def update_user_info( + self, user_id: int, username: str = None, full_name: str = None + ) -> None: """Обновление информации о пользователе.""" if username and full_name: query = "UPDATE our_users SET username = ?, full_name = ? WHERE user_id = ?" @@ -174,85 +198,93 @@ class UserRepository(DatabaseConnection): params = (full_name, user_id) else: return - + await self._execute_query(query, params) - + async def update_user_emoji(self, user_id: int, emoji: str) -> None: """Обновление эмодзи пользователя.""" query = "UPDATE our_users SET emoji = ? WHERE user_id = ?" await self._execute_query(query, (emoji, user_id)) - + async def update_stickers_info(self, user_id: int) -> None: """Обновление информации о стикерах.""" query = "UPDATE our_users SET has_stickers = 1 WHERE user_id = ?" await self._execute_query(query, (user_id,)) - + async def get_stickers_info(self, user_id: int) -> bool: """Получение информации о стикерах.""" query = "SELECT has_stickers FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None return bool(row[0]) if row and row[0] is not None else False - + async def check_emoji_exists(self, emoji: str) -> bool: """Проверка существования эмодзи.""" query = "SELECT 1 FROM our_users WHERE emoji = ?" rows = await self._execute_query_with_result(query, (emoji,)) row = rows[0] if rows else None return bool(row) - + async def get_user_emoji(self, user_id: int) -> str: """ Получает эмодзи пользователя. - + Args: user_id: ID пользователя. - + Returns: str: Эмодзи пользователя или "Смайл еще не определен" если не установлен. """ query = "SELECT emoji FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row and row[0]: emoji = row[0] - self.logger.info(f"Эмодзи пользователя найден: user_id={user_id}, emoji={emoji}") + self.logger.info( + f"Эмодзи пользователя найден: user_id={user_id}, emoji={emoji}" + ) return str(emoji) else: self.logger.info(f"Эмодзи пользователя не найден: user_id={user_id}") return "Смайл еще не определен" - + async def check_emoji_for_user(self, user_id: int) -> str: """ Проверяет, есть ли уже у пользователя назначенный emoji. - + Args: user_id: ID пользователя. - + Returns: str: Эмодзи пользователя или "Смайл еще не определен" если не установлен. """ return await self.get_user_emoji(user_id) - + async def check_voice_bot_welcome_received(self, user_id: int) -> bool: """Проверяет, получал ли пользователь приветственное сообщение от voice_bot.""" query = "SELECT voice_bot_welcome_received FROM our_users WHERE user_id = ?" rows = await self._execute_query_with_result(query, (user_id,)) row = rows[0] if rows else None - + if row: welcome_received = bool(row[0]) - self.logger.info(f"Пользователь {user_id} получал приветствие: {welcome_received}") + self.logger.info( + f"Пользователь {user_id} получал приветствие: {welcome_received}" + ) return welcome_received return False - + async def mark_voice_bot_welcome_received(self, user_id: int) -> bool: """Отмечает, что пользователь получил приветственное сообщение от voice_bot.""" try: - query = "UPDATE our_users SET voice_bot_welcome_received = 1 WHERE user_id = ?" + query = ( + "UPDATE our_users SET voice_bot_welcome_received = 1 WHERE user_id = ?" + ) await self._execute_query(query, (user_id,)) - self.logger.info(f"Пользователь {user_id} отмечен как получивший приветствие") + self.logger.info( + f"Пользователь {user_id} отмечен как получивший приветствие" + ) return True except Exception as e: self.logger.error(f"Ошибка при отметке получения приветствия: {e}") diff --git a/database/repository_factory.py b/database/repository_factory.py index 3d08e77..d218f21 100644 --- a/database/repository_factory.py +++ b/database/repository_factory.py @@ -2,8 +2,9 @@ from typing import Optional from database.repositories.admin_repository import AdminRepository from database.repositories.audio_repository import AudioRepository -from database.repositories.blacklist_history_repository import \ - BlacklistHistoryRepository +from database.repositories.blacklist_history_repository import ( + BlacklistHistoryRepository, +) from database.repositories.blacklist_repository import BlacklistRepository from database.repositories.message_repository import MessageRepository from database.repositories.migration_repository import MigrationRepository @@ -13,7 +14,7 @@ from database.repositories.user_repository import UserRepository class RepositoryFactory: """Фабрика для создания репозиториев.""" - + def __init__(self, db_path: str): self.db_path = db_path self._user_repo: Optional[UserRepository] = None @@ -24,63 +25,63 @@ class RepositoryFactory: self._admin_repo: Optional[AdminRepository] = None self._audio_repo: Optional[AudioRepository] = None self._migration_repo: Optional[MigrationRepository] = None - + @property def users(self) -> UserRepository: """Возвращает репозиторий пользователей.""" if self._user_repo is None: self._user_repo = UserRepository(self.db_path) return self._user_repo - + @property def blacklist(self) -> BlacklistRepository: """Возвращает репозиторий черного списка.""" if self._blacklist_repo is None: self._blacklist_repo = BlacklistRepository(self.db_path) return self._blacklist_repo - + @property def blacklist_history(self) -> BlacklistHistoryRepository: """Возвращает репозиторий истории банов/разбанов.""" if self._blacklist_history_repo is None: self._blacklist_history_repo = BlacklistHistoryRepository(self.db_path) return self._blacklist_history_repo - + @property def messages(self) -> MessageRepository: """Возвращает репозиторий сообщений.""" if self._message_repo is None: self._message_repo = MessageRepository(self.db_path) return self._message_repo - + @property def posts(self) -> PostRepository: """Возвращает репозиторий постов.""" if self._post_repo is None: self._post_repo = PostRepository(self.db_path) return self._post_repo - + @property def admins(self) -> AdminRepository: """Возвращает репозиторий администраторов.""" if self._admin_repo is None: self._admin_repo = AdminRepository(self.db_path) return self._admin_repo - + @property def audio(self) -> AudioRepository: """Возвращает репозиторий аудио.""" if self._audio_repo is None: self._audio_repo = AudioRepository(self.db_path) return self._audio_repo - + @property def migrations(self) -> MigrationRepository: """Возвращает репозиторий миграций.""" if self._migration_repo is None: self._migration_repo = MigrationRepository(self.db_path) return self._migration_repo - + async def create_all_tables(self): """Создает все таблицы в базе данных.""" await self.migrations.create_table() # Сначала создаем таблицу миграций @@ -91,11 +92,11 @@ class RepositoryFactory: await self.posts.create_tables() await self.admins.create_tables() await self.audio.create_tables() - + async def check_database_integrity(self): """Проверяет целостность базы данных.""" await self.users.check_database_integrity() - + async def cleanup_wal_files(self): """Очищает WAL файлы.""" await self.users.cleanup_wal_files() diff --git a/helper_bot/config/rate_limit_config.py b/helper_bot/config/rate_limit_config.py index 92e5ea7..b856436 100644 --- a/helper_bot/config/rate_limit_config.py +++ b/helper_bot/config/rate_limit_config.py @@ -1,6 +1,7 @@ """ Конфигурация для rate limiting """ + from dataclasses import dataclass from typing import Optional @@ -8,26 +9,28 @@ from typing import Optional @dataclass class RateLimitSettings: """Настройки rate limiting для разных типов сообщений""" - + # Основные настройки messages_per_second: float = 0.5 # Максимум 0.5 сообщений в секунду на чат burst_limit: int = 2 # Максимум 2 сообщения подряд retry_after_multiplier: float = 1.5 # Множитель для увеличения задержки при retry max_retry_delay: float = 30.0 # Максимальная задержка между попытками max_retries: int = 3 # Максимальное количество повторных попыток - + # Специальные настройки для разных типов сообщений voice_message_delay: float = 2.0 # Дополнительная задержка для голосовых сообщений media_message_delay: float = 1.5 # Дополнительная задержка для медиа сообщений - text_message_delay: float = 1.0 # Дополнительная задержка для текстовых сообщений - + text_message_delay: float = 1.0 # Дополнительная задержка для текстовых сообщений + # Настройки для разных типов чатов private_chat_multiplier: float = 1.0 # Множитель для приватных чатов - group_chat_multiplier: float = 0.8 # Множитель для групповых чатов - channel_multiplier: float = 0.6 # Множитель для каналов - + group_chat_multiplier: float = 0.8 # Множитель для групповых чатов + channel_multiplier: float = 0.6 # Множитель для каналов + # Глобальные ограничения - global_messages_per_second: float = 10.0 # Максимум 10 сообщений в секунду глобально + global_messages_per_second: float = ( + 10.0 # Максимум 10 сообщений в секунду глобально + ) global_burst_limit: int = 20 # Максимум 20 сообщений подряд глобально @@ -37,7 +40,7 @@ DEVELOPMENT_CONFIG = RateLimitSettings( burst_limit=3, retry_after_multiplier=1.2, max_retry_delay=15.0, - max_retries=2 + max_retries=2, ) PRODUCTION_CONFIG = RateLimitSettings( @@ -48,7 +51,7 @@ PRODUCTION_CONFIG = RateLimitSettings( max_retries=3, voice_message_delay=2.5, media_message_delay=2.0, - text_message_delay=1.5 + text_message_delay=1.5, ) STRICT_CONFIG = RateLimitSettings( @@ -59,46 +62,45 @@ STRICT_CONFIG = RateLimitSettings( max_retries=5, voice_message_delay=3.0, media_message_delay=2.5, - text_message_delay=2.0 + text_message_delay=2.0, ) def get_rate_limit_config(environment: str = "production") -> RateLimitSettings: """ Получает конфигурацию rate limiting в зависимости от окружения - + Args: environment: Окружение ('development', 'production', 'strict') - + Returns: RateLimitSettings: Конфигурация для указанного окружения """ configs = { "development": DEVELOPMENT_CONFIG, "production": PRODUCTION_CONFIG, - "strict": STRICT_CONFIG + "strict": STRICT_CONFIG, } - + return configs.get(environment, PRODUCTION_CONFIG) def get_adaptive_config( - current_error_rate: float, - base_config: Optional[RateLimitSettings] = None + current_error_rate: float, base_config: Optional[RateLimitSettings] = None ) -> RateLimitSettings: """ Получает адаптивную конфигурацию на основе текущего уровня ошибок - + Args: current_error_rate: Текущий уровень ошибок (0.0 - 1.0) base_config: Базовая конфигурация - + Returns: RateLimitSettings: Адаптированная конфигурация """ if base_config is None: base_config = PRODUCTION_CONFIG - + # Если уровень ошибок высокий, ужесточаем ограничения if current_error_rate > 0.1: # Более 10% ошибок return RateLimitSettings( @@ -109,9 +111,9 @@ def get_adaptive_config( max_retries=base_config.max_retries + 1, voice_message_delay=base_config.voice_message_delay * 1.5, media_message_delay=base_config.media_message_delay * 1.3, - text_message_delay=base_config.text_message_delay * 1.2 + text_message_delay=base_config.text_message_delay * 1.2, ) - + # Если уровень ошибок низкий, можно немного ослабить ограничения elif current_error_rate < 0.01: # Менее 1% ошибок return RateLimitSettings( @@ -122,8 +124,8 @@ def get_adaptive_config( max_retries=max(1, base_config.max_retries - 1), voice_message_delay=base_config.voice_message_delay * 0.8, media_message_delay=base_config.media_message_delay * 0.9, - text_message_delay=base_config.text_message_delay * 0.9 + text_message_delay=base_config.text_message_delay * 0.9, ) - + # Возвращаем базовую конфигурацию return base_config diff --git a/helper_bot/filters/main.py b/helper_bot/filters/main.py index 021e287..691a6ee 100644 --- a/helper_bot/filters/main.py +++ b/helper_bot/filters/main.py @@ -5,7 +5,7 @@ from aiogram.types import Message class ChatTypeFilter(BaseFilter): # [1] - def __init__(self, chat_type: Union[str, list]): # [2] + def __init__(self, chat_type: Union[str, list]): # [2] self.chat_type = chat_type async def __call__(self, message: Message) -> bool: # [3] diff --git a/helper_bot/handlers/admin/__init__.py b/helper_bot/handlers/admin/__init__.py index af092ad..4b9d6f2 100644 --- a/helper_bot/handlers/admin/__init__.py +++ b/helper_bot/handlers/admin/__init__.py @@ -1,27 +1,37 @@ from .admin_handlers import admin_router from .dependencies import AdminAccessMiddleware, BotDB, Settings -from .exceptions import (AdminAccessDeniedError, AdminError, InvalidInputError, - UserAlreadyBannedError, UserNotFoundError) +from .exceptions import ( + AdminAccessDeniedError, + AdminError, + InvalidInputError, + UserAlreadyBannedError, + UserNotFoundError, +) from .services import AdminService, BannedUser, User -from .utils import (escape_html, format_ban_confirmation, format_user_info, - handle_admin_error, return_to_admin_menu) +from .utils import ( + escape_html, + format_ban_confirmation, + format_user_info, + handle_admin_error, + return_to_admin_menu, +) __all__ = [ - 'admin_router', - 'AdminAccessMiddleware', - 'BotDB', - 'Settings', - 'AdminService', - 'User', - 'BannedUser', - 'AdminError', - 'AdminAccessDeniedError', - 'UserNotFoundError', - 'InvalidInputError', - 'UserAlreadyBannedError', - 'return_to_admin_menu', - 'handle_admin_error', - 'format_user_info', - 'format_ban_confirmation', - 'escape_html' -] \ No newline at end of file + "admin_router", + "AdminAccessMiddleware", + "BotDB", + "Settings", + "AdminService", + "User", + "BannedUser", + "AdminError", + "AdminAccessDeniedError", + "UserNotFoundError", + "InvalidInputError", + "UserAlreadyBannedError", + "return_to_admin_menu", + "handle_admin_error", + "format_user_info", + "format_ban_confirmation", + "escape_html", +] diff --git a/helper_bot/handlers/admin/admin_handlers.py b/helper_bot/handlers/admin/admin_handlers.py index a0aeb03..66d0519 100644 --- a/helper_bot/handlers/admin/admin_handlers.py +++ b/helper_bot/handlers/admin/admin_handlers.py @@ -1,22 +1,30 @@ from aiogram import F, Router, types from aiogram.filters import Command, MagicData, StateFilter from aiogram.fsm.context import FSMContext + from helper_bot.filters.main import ChatTypeFilter from helper_bot.handlers.admin.dependencies import AdminAccessMiddleware -from helper_bot.handlers.admin.exceptions import (InvalidInputError, - UserAlreadyBannedError) +from helper_bot.handlers.admin.exceptions import ( + InvalidInputError, + UserAlreadyBannedError, +) from helper_bot.handlers.admin.services import AdminService -from helper_bot.handlers.admin.utils import (escape_html, - format_ban_confirmation, - format_user_info, - handle_admin_error, - return_to_admin_menu) -from helper_bot.keyboards.keyboards import (create_keyboard_for_approve_ban, - create_keyboard_for_ban_days, - create_keyboard_for_ban_reason, - create_keyboard_with_pagination, - get_reply_keyboard_admin) +from helper_bot.handlers.admin.utils import ( + escape_html, + format_ban_confirmation, + format_user_info, + handle_admin_error, + return_to_admin_menu, +) +from helper_bot.keyboards.keyboards import ( + create_keyboard_for_approve_ban, + create_keyboard_for_ban_days, + create_keyboard_for_ban_reason, + create_keyboard_with_pagination, + get_reply_keyboard_admin, +) from helper_bot.utils.base_dependency_factory import get_global_instance + # Local imports - metrics from helper_bot.utils.metrics import db_query_time, track_errors, track_time from logs.custom_logger import logger @@ -30,23 +38,19 @@ admin_router.message.middleware(AdminAccessMiddleware()) # ХЕНДЛЕРЫ МЕНЮ # ============================================================================ -@admin_router.message( - ChatTypeFilter(chat_type=["private"]), - Command('admin') -) + +@admin_router.message(ChatTypeFilter(chat_type=["private"]), Command("admin")) @track_time("admin_panel", "admin_handlers") @track_errors("admin_handlers", "admin_panel") -async def admin_panel( - message: types.Message, - state: FSMContext, - **kwargs -): +async def admin_panel(message: types.Message, state: FSMContext, **kwargs): """Главное меню администратора""" try: await state.set_state("ADMIN") logger.info(f"Запуск админ панели для пользователя: {message.from_user.id}") markup = get_reply_keyboard_admin() - await message.answer("Добро пожаловать в админку. Выбери что хочешь:", reply_markup=markup) + await message.answer( + "Добро пожаловать в админку. Выбери что хочешь:", reply_markup=markup + ) except Exception as e: await handle_admin_error(message, e, state, "admin_panel") @@ -55,18 +59,20 @@ async def admin_panel( # ХЕНДЛЕР ОТМЕНЫ # ============================================================================ + @admin_router.message( ChatTypeFilter(chat_type=["private"]), - StateFilter("AWAIT_BAN_TARGET", "AWAIT_BAN_DETAILS", "AWAIT_BAN_DURATION", "BAN_CONFIRMATION"), - F.text == 'Отменить' + StateFilter( + "AWAIT_BAN_TARGET", + "AWAIT_BAN_DETAILS", + "AWAIT_BAN_DURATION", + "BAN_CONFIRMATION", + ), + F.text == "Отменить", ) @track_time("cancel_ban_process", "admin_handlers") @track_errors("admin_handlers", "cancel_ban_process") -async def cancel_ban_process( - message: types.Message, - state: FSMContext, - **kwargs - ): +async def cancel_ban_process(message: types.Message, state: FSMContext, **kwargs): """Отмена процесса блокировки""" try: current_state = await state.get_state() @@ -79,32 +85,31 @@ async def cancel_ban_process( @admin_router.message( ChatTypeFilter(chat_type=["private"]), StateFilter("ADMIN"), - F.text == 'Бан (Список)' + F.text == "Бан (Список)", ) @track_time("get_last_users", "admin_handlers") @track_errors("admin_handlers", "get_last_users") @db_query_time("get_last_users", "users", "select") async def get_last_users( - message: types.Message, - state: FSMContext, - bot_db: MagicData("bot_db") - ): + message: types.Message, state: FSMContext, bot_db: MagicData("bot_db") +): """Получение списка последних пользователей""" try: - logger.info(f"Получение списка последних пользователей. Пользователь: {message.from_user.full_name}") + logger.info( + f"Получение списка последних пользователей. Пользователь: {message.from_user.full_name}" + ) admin_service = AdminService(bot_db) users = await admin_service.get_last_users() - + # Преобразуем в формат для клавиатуры (кортежи как ожидает create_keyboard_with_pagination) - users_data = [ - (user.full_name, user.user_id) - for user in users - ] - - keyboard = create_keyboard_with_pagination(1, len(users_data), users_data, 'ban') + users_data = [(user.full_name, user.user_id) for user in users] + + keyboard = create_keyboard_with_pagination( + 1, len(users_data), users_data, "ban" + ) await message.answer( text="Список пользователей которые последними обращались к боту", - reply_markup=keyboard + reply_markup=keyboard, ) except Exception as e: await handle_admin_error(message, e, state, "get_last_users") @@ -113,27 +118,31 @@ async def get_last_users( @admin_router.message( ChatTypeFilter(chat_type=["private"]), StateFilter("ADMIN"), - F.text == 'Разбан (список)' + F.text == "Разбан (список)", ) @track_time("get_banned_users", "admin_handlers") @track_errors("admin_handlers", "get_banned_users") @db_query_time("get_banned_users", "users", "select") async def get_banned_users( - message: types.Message, - state: FSMContext, - bot_db: MagicData("bot_db") - ): + message: types.Message, state: FSMContext, bot_db: MagicData("bot_db") +): """Получение списка заблокированных пользователей""" try: - logger.info(f"Получение списка заблокированных пользователей. Пользователь: {message.from_user.full_name}") + logger.info( + f"Получение списка заблокированных пользователей. Пользователь: {message.from_user.full_name}" + ) admin_service = AdminService(bot_db) message_text, buttons_list = await admin_service.get_banned_users_for_display(0) - + if buttons_list: - keyboard = create_keyboard_with_pagination(1, len(buttons_list), buttons_list, 'unlock') + keyboard = create_keyboard_with_pagination( + 1, len(buttons_list), buttons_list, "unlock" + ) await message.answer(text=message_text, reply_markup=keyboard) else: - await message.answer(text="В списке заблокированных пользователей никого нет") + await message.answer( + text="В списке заблокированных пользователей никого нет" + ) except Exception as e: await handle_admin_error(message, e, state, "get_banned_users") @@ -141,85 +150,95 @@ async def get_banned_users( @admin_router.message( ChatTypeFilter(chat_type=["private"]), StateFilter("ADMIN"), - F.text == '📊 ML Статистика' + F.text == "📊 ML Статистика", ) @track_time("get_ml_stats", "admin_handlers") @track_errors("admin_handlers", "get_ml_stats") -async def get_ml_stats( - message: types.Message, - state: FSMContext, - **kwargs - ): +async def get_ml_stats(message: types.Message, state: FSMContext, **kwargs): """Получение статистики ML-скоринга""" try: - logger.info(f"Запрос ML статистики от пользователя: {message.from_user.full_name}") - + logger.info( + f"Запрос ML статистики от пользователя: {message.from_user.full_name}" + ) + bdf = get_global_instance() scoring_manager = bdf.get_scoring_manager() - + if not scoring_manager: - await message.answer("📊 ML Scoring отключен\n\nДля включения установите RAG_ENABLED=true или DEEPSEEK_ENABLED=true в .env") + await message.answer( + "📊 ML Scoring отключен\n\nДля включения установите RAG_ENABLED=true или DEEPSEEK_ENABLED=true в .env" + ) return - + stats = await scoring_manager.get_stats() - + # Формируем текст статистики lines = ["📊 ML Scoring Статистика\n"] - + # RAG статистика if "rag" in stats: rag = stats["rag"] lines.append("🤖 RAG API:") - + # Проверяем, есть ли данные из API (новый контракт содержит model_loaded и vector_store) if "model_loaded" in rag or "vector_store" in rag: # Данные из API /stats if "model_loaded" in rag: - model_loaded = rag.get('model_loaded', False) - lines.append(f" • Модель загружена: {'✅' if model_loaded else '❌'}") + model_loaded = rag.get("model_loaded", False) + lines.append( + f" • Модель загружена: {'✅' if model_loaded else '❌'}" + ) if "model_name" in rag: lines.append(f" • Модель: {rag.get('model_name', 'N/A')}") if "device" in rag: lines.append(f" • Устройство: {rag.get('device', 'N/A')}") - + # Статистика из vector_store if "vector_store" in rag: vector_store = rag["vector_store"] positive_count = vector_store.get("positive_count", 0) negative_count = vector_store.get("negative_count", 0) total_count = vector_store.get("total_count", 0) - + lines.append(f" • Положительных примеров: {positive_count}") lines.append(f" • Отрицательных примеров: {negative_count}") lines.append(f" • Всего примеров: {total_count}") - + if "vector_dim" in vector_store: - lines.append(f" • Размерность векторов: {vector_store.get('vector_dim', 'N/A')}") + lines.append( + f" • Размерность векторов: {vector_store.get('vector_dim', 'N/A')}" + ) if "max_examples" in vector_store: - lines.append(f" • Макс. примеров: {vector_store.get('max_examples', 'N/A')}") + lines.append( + f" • Макс. примеров: {vector_store.get('max_examples', 'N/A')}" + ) else: # Fallback на синхронные данные (если API недоступен) lines.append(f" • API URL: {rag.get('api_url', 'N/A')}") if "enabled" in rag: - lines.append(f" • Статус: {'✅ Включен' if rag.get('enabled') else '❌ Отключен'}") - + lines.append( + f" • Статус: {'✅ Включен' if rag.get('enabled') else '❌ Отключен'}" + ) + lines.append("") - + # DeepSeek статистика if "deepseek" in stats: ds = stats["deepseek"] lines.append("🔮 DeepSeek API:") - lines.append(f" • Статус: {'✅ Включен' if ds.get('enabled') else '❌ Отключен'}") + lines.append( + f" • Статус: {'✅ Включен' if ds.get('enabled') else '❌ Отключен'}" + ) lines.append(f" • Модель: {ds.get('model', 'N/A')}") lines.append(f" • Таймаут: {ds.get('timeout', 'N/A')}с") lines.append("") - + # Если ничего не включено if "rag" not in stats and "deepseek" not in stats: lines.append("⚠️ Ни один сервис не настроен") - + await message.answer("\n".join(lines), parse_mode="HTML") - + except Exception as e: logger.error(f"Ошибка получения ML статистики: {e}") await message.answer(f"❌ Ошибка получения статистики: {str(e)}") @@ -229,68 +248,80 @@ async def get_ml_stats( # ХЕНДЛЕРЫ ПРОЦЕССА БАНА # ============================================================================ + @admin_router.message( ChatTypeFilter(chat_type=["private"]), StateFilter("ADMIN"), - F.text.in_(['Бан по нику', 'Бан по ID']) + F.text.in_(["Бан по нику", "Бан по ID"]), ) @track_time("start_ban_process", "admin_handlers") @track_errors("admin_handlers", "start_ban_process") -async def start_ban_process( - message: types.Message, - state: FSMContext, - **kwargs - ): +async def start_ban_process(message: types.Message, state: FSMContext, **kwargs): """Начало процесса блокировки пользователя""" try: - ban_type = "username" if message.text == 'Бан по нику' else "id" + ban_type = "username" if message.text == "Бан по нику" else "id" await state.update_data(ban_type=ban_type) - - prompt_text = "Пришли мне username блокируемого пользователя" if ban_type == "username" else "Пришли мне ID блокируемого пользователя" + + prompt_text = ( + "Пришли мне username блокируемого пользователя" + if ban_type == "username" + else "Пришли мне ID блокируемого пользователя" + ) await message.answer(prompt_text) - await state.set_state('AWAIT_BAN_TARGET') + await state.set_state("AWAIT_BAN_TARGET") except Exception as e: await handle_admin_error(message, e, state, "start_ban_process") @admin_router.message( - ChatTypeFilter(chat_type=["private"]), - StateFilter("AWAIT_BAN_TARGET") + ChatTypeFilter(chat_type=["private"]), StateFilter("AWAIT_BAN_TARGET") ) @track_time("process_ban_target", "admin_handlers") @track_errors("admin_handlers", "process_ban_target") async def process_ban_target( - message: types.Message, - state: FSMContext, - bot_db: MagicData("bot_db") - ): + message: types.Message, state: FSMContext, bot_db: MagicData("bot_db") +): """Обработка введенного username/ID для блокировки""" - logger.info(f"process_ban_target: === НАЧАЛО ОБРАБОТКИ === Получено сообщение от {message.from_user.id}: {message.text}") - + logger.info( + f"process_ban_target: === НАЧАЛО ОБРАБОТКИ === Получено сообщение от {message.from_user.id}: {message.text}" + ) + try: user_data = await state.get_data() - ban_type = user_data.get('ban_type') + ban_type = user_data.get("ban_type") admin_service = AdminService(bot_db) - + logger.info(f"process_ban_target: ban_type={ban_type}, user_data={user_data}") # Определяем пользователя if ban_type == "username": - logger.info(f"process_ban_target: Поиск пользователя по username: {message.text}") + logger.info( + f"process_ban_target: Поиск пользователя по username: {message.text}" + ) user = await admin_service.get_user_by_username(message.text) if not user: - logger.warning(f"process_ban_target: Пользователь с username '{message.text}' не найден") - await message.answer(f"Пользователь с username '{escape_html(message.text)}' не найден.") + logger.warning( + f"process_ban_target: Пользователь с username '{message.text}' не найден" + ) + await message.answer( + f"Пользователь с username '{escape_html(message.text)}' не найден." + ) await return_to_admin_menu(message, state) return else: # ban_type == "id" try: - logger.info(f"process_ban_target: Валидация и поиск пользователя по ID: {message.text}") + logger.info( + f"process_ban_target: Валидация и поиск пользователя по ID: {message.text}" + ) user_id = await admin_service.validate_user_input(message.text) user = await admin_service.get_user_by_id(user_id) if not user: - logger.warning(f"process_ban_target: Пользователь с ID {user_id} не найден в базе данных") - await message.answer(f"Пользователь с ID {user_id} не найден в базе данных.") + logger.warning( + f"process_ban_target: Пользователь с ID {user_id} не найден в базе данных" + ) + await message.answer( + f"Пользователь с ID {user_id} не найден в базе данных." + ) await return_to_admin_menu(message, state) return except InvalidInputError as e: @@ -298,115 +329,117 @@ async def process_ban_target( await message.answer(str(e)) await return_to_admin_menu(message, state) return - - logger.info(f"process_ban_target: Найден пользователь: {user.user_id}, {user.username}, {user.full_name}") - + + logger.info( + f"process_ban_target: Найден пользователь: {user.user_id}, {user.username}, {user.full_name}" + ) + # Сохраняем данные пользователя await state.update_data( target_user_id=user.user_id, target_username=user.username, - target_full_name=user.full_name + target_full_name=user.full_name, ) - + # Показываем информацию о пользователе и запрашиваем причину user_info = format_user_info(user.user_id, user.username, user.full_name) markup = create_keyboard_for_ban_reason() - logger.info(f"process_ban_target: Отправка сообщения с причиной бана, user_info: {user_info}") - + logger.info( + f"process_ban_target: Отправка сообщения с причиной бана, user_info: {user_info}" + ) + await message.answer( text=f"{user_info}\n\nВыбери причину бана из списка или напиши ее в чат", - reply_markup=markup + reply_markup=markup, ) - await state.set_state('AWAIT_BAN_DETAILS') + await state.set_state("AWAIT_BAN_DETAILS") logger.info("process_ban_target: Состояние изменено на AWAIT_BAN_DETAILS") - + except Exception as e: logger.error(f"process_ban_target: Неожиданная ошибка: {e}", exc_info=True) await handle_admin_error(message, e, state, "process_ban_target") @admin_router.message( - ChatTypeFilter(chat_type=["private"]), - StateFilter("AWAIT_BAN_DETAILS") + ChatTypeFilter(chat_type=["private"]), StateFilter("AWAIT_BAN_DETAILS") ) @track_time("process_ban_reason", "admin_handlers") @track_errors("admin_handlers", "process_ban_reason") -async def process_ban_reason( - message: types.Message, - state: FSMContext, - **kwargs - ): +async def process_ban_reason(message: types.Message, state: FSMContext, **kwargs): """Обработка причины блокировки""" - logger.info(f"process_ban_reason: === НАЧАЛО ОБРАБОТКИ === Получено сообщение от {message.from_user.id}: {message.text}") - + logger.info( + f"process_ban_reason: === НАЧАЛО ОБРАБОТКИ === Получено сообщение от {message.from_user.id}: {message.text}" + ) + try: # Проверяем текущее состояние current_state = await state.get_state() logger.info(f"process_ban_reason: Текущее состояние: {current_state}") - + # Проверяем данные состояния state_data = await state.get_data() logger.info(f"process_ban_reason: Данные состояния: {state_data}") - - logger.info(f"process_ban_reason: Обновление данных состояния с причиной: {message.text}") + + logger.info( + f"process_ban_reason: Обновление данных состояния с причиной: {message.text}" + ) await state.update_data(ban_reason=message.text) - + markup = create_keyboard_for_ban_days() safe_reason = escape_html(message.text) - logger.info(f"process_ban_reason: Отправка сообщения с выбором срока бана, причина: {safe_reason}") - + logger.info( + f"process_ban_reason: Отправка сообщения с выбором срока бана, причина: {safe_reason}" + ) + await message.answer( f"Выбрана причина: {safe_reason}. Выбери срок бана в днях или напиши его в чат", - reply_markup=markup + reply_markup=markup, ) - await state.set_state('AWAIT_BAN_DURATION') + await state.set_state("AWAIT_BAN_DURATION") logger.info("process_ban_reason: Состояние изменено на AWAIT_BAN_DURATION") - + except Exception as e: logger.error(f"process_ban_reason: Неожиданная ошибка: {e}", exc_info=True) await handle_admin_error(message, e, state, "process_ban_reason") @admin_router.message( - ChatTypeFilter(chat_type=["private"]), - StateFilter("AWAIT_BAN_DURATION") + ChatTypeFilter(chat_type=["private"]), StateFilter("AWAIT_BAN_DURATION") ) @track_time("process_ban_duration", "admin_handlers") @track_errors("admin_handlers", "process_ban_duration") -async def process_ban_duration( - message: types.Message, - state: FSMContext, - **kwargs - ): +async def process_ban_duration(message: types.Message, state: FSMContext, **kwargs): """Обработка срока блокировки""" try: user_data = await state.get_data() - + # Определяем срок блокировки - if message.text == 'Навсегда': + if message.text == "Навсегда": ban_days = None else: try: ban_days = int(message.text) if ban_days <= 0: - await message.answer("Срок блокировки должен быть положительным числом.") + await message.answer( + "Срок блокировки должен быть положительным числом." + ) return except ValueError: - await message.answer("Пожалуйста, введите корректное число дней или выберите 'Навсегда'.") + await message.answer( + "Пожалуйста, введите корректное число дней или выберите 'Навсегда'." + ) return - + await state.update_data(ban_days=ban_days) - + # Показываем подтверждение confirmation_text = format_ban_confirmation( - user_data['target_user_id'], - user_data['ban_reason'], - ban_days + user_data["target_user_id"], user_data["ban_reason"], ban_days ) markup = create_keyboard_for_approve_ban() await message.answer(confirmation_text, reply_markup=markup) - await state.set_state('BAN_CONFIRMATION') - + await state.set_state("BAN_CONFIRMATION") + except Exception as e: await handle_admin_error(message, e, state, "process_ban_duration") @@ -414,35 +447,31 @@ async def process_ban_duration( @admin_router.message( ChatTypeFilter(chat_type=["private"]), StateFilter("BAN_CONFIRMATION"), - F.text == 'Подтвердить' + F.text == "Подтвердить", ) @track_time("confirm_ban", "admin_handlers") @track_errors("admin_handlers", "confirm_ban") async def confirm_ban( - message: types.Message, - state: FSMContext, - bot_db: MagicData("bot_db"), - **kwargs - ): + message: types.Message, state: FSMContext, bot_db: MagicData("bot_db"), **kwargs +): """Подтверждение блокировки пользователя""" try: user_data = await state.get_data() admin_service = AdminService(bot_db) - # Выполняем блокировку await admin_service.ban_user( - user_id=user_data['target_user_id'], - username=user_data['target_username'], - reason=user_data['ban_reason'], - ban_days=user_data['ban_days'], + user_id=user_data["target_user_id"], + username=user_data["target_username"], + reason=user_data["ban_reason"], + ban_days=user_data["ban_days"], ban_author_id=message.from_user.id, ) - - safe_username = escape_html(user_data['target_username']) + + safe_username = escape_html(user_data["target_username"]) await message.reply(f"Пользователь {safe_username} успешно заблокирован.") await return_to_admin_menu(message, state) - + except UserAlreadyBannedError as e: await message.reply(str(e)) await return_to_admin_menu(message, state) diff --git a/helper_bot/handlers/admin/constants.py b/helper_bot/handlers/admin/constants.py index 490caff..ff1220c 100644 --- a/helper_bot/handlers/admin/constants.py +++ b/helper_bot/handlers/admin/constants.py @@ -9,7 +9,7 @@ ADMIN_BUTTON_TEXTS: Final[Dict[str, str]] = { "BAN_BY_ID": "Бан по ID", "UNBAN_LIST": "Разбан (список)", "RETURN_TO_BOT": "Вернуться в бота", - "CANCEL": "Отменить" + "CANCEL": "Отменить", } # Admin button to command mapping for metrics @@ -19,11 +19,11 @@ ADMIN_BUTTON_COMMAND_MAPPING: Final[Dict[str, str]] = { "Бан по ID": "admin_ban_by_id", "Разбан (список)": "admin_unban_list", "Вернуться в бота": "admin_return_to_bot", - "Отменить": "admin_cancel" + "Отменить": "admin_cancel", } # Admin commands ADMIN_COMMANDS: Final[Dict[str, str]] = { "ADMIN": "admin", - "TEST_METRICS": "test_metrics" + "TEST_METRICS": "test_metrics", } diff --git a/helper_bot/handlers/admin/dependencies.py b/helper_bot/handlers/admin/dependencies.py index 0a4cd9e..89a486f 100644 --- a/helper_bot/handlers/admin/dependencies.py +++ b/helper_bot/handlers/admin/dependencies.py @@ -7,6 +7,7 @@ except ImportError: from aiogram import BaseMiddleware from aiogram.types import TelegramObject + from helper_bot.utils.base_dependency_factory import get_global_instance from helper_bot.utils.helper_func import check_access from logs.custom_logger import logger @@ -14,36 +15,46 @@ from logs.custom_logger import logger class AdminAccessMiddleware(BaseMiddleware): """Middleware для проверки административного доступа""" - - async def __call__(self, handler, event: TelegramObject, data: Dict[str, Any]) -> Any: - if hasattr(event, 'from_user'): + + async def __call__( + self, handler, event: TelegramObject, data: Dict[str, Any] + ) -> Any: + if hasattr(event, "from_user"): user_id = event.from_user.id - username = getattr(event.from_user, 'username', 'Unknown') - - logger.info(f"AdminAccessMiddleware: проверка доступа для пользователя {username} (ID: {user_id})") - + username = getattr(event.from_user, "username", "Unknown") + + logger.info( + f"AdminAccessMiddleware: проверка доступа для пользователя {username} (ID: {user_id})" + ) + # Получаем bot_db из data (внедренного DependenciesMiddleware) - bot_db = data.get('bot_db') + bot_db = data.get("bot_db") if not bot_db: # Fallback: получаем напрямую если middleware не сработала bdf = get_global_instance() bot_db = bdf.get_db() - + is_admin_result = await check_access(user_id, bot_db) - logger.info(f"AdminAccessMiddleware: результат проверки для {username}: {is_admin_result}") - + logger.info( + f"AdminAccessMiddleware: результат проверки для {username}: {is_admin_result}" + ) + if not is_admin_result: - logger.warning(f"AdminAccessMiddleware: доступ запрещен для пользователя {username} (ID: {user_id})") - if hasattr(event, 'answer'): - await event.answer('Доступ запрещен!') + logger.warning( + f"AdminAccessMiddleware: доступ запрещен для пользователя {username} (ID: {user_id})" + ) + if hasattr(event, "answer"): + await event.answer("Доступ запрещен!") return - + try: # Вызываем хендлер с data return await handler(event, data) except TypeError as e: if "missing 1 required positional argument: 'data'" in str(e): - logger.error(f"Ошибка в AdminAccessMiddleware: {e}. Хендлер не принимает параметр 'data'") + logger.error( + f"Ошибка в AdminAccessMiddleware: {e}. Хендлер не принимает параметр 'data'" + ) # Пытаемся вызвать хендлер без data (для совместимости с MagicData) return await handler(event) else: diff --git a/helper_bot/handlers/admin/exceptions.py b/helper_bot/handlers/admin/exceptions.py index 8ad1fed..3cf290d 100644 --- a/helper_bot/handlers/admin/exceptions.py +++ b/helper_bot/handlers/admin/exceptions.py @@ -1,23 +1,28 @@ class AdminError(Exception): """Базовое исключение для административных операций""" + pass class AdminAccessDeniedError(AdminError): """Исключение при отказе в административном доступе""" + pass class UserNotFoundError(AdminError): """Исключение при отсутствии пользователя""" + pass class InvalidInputError(AdminError): """Исключение при некорректном вводе данных""" + pass class UserAlreadyBannedError(AdminError): """Исключение при попытке забанить уже заблокированного пользователя""" + pass diff --git a/helper_bot/handlers/admin/rate_limit_handlers.py b/helper_bot/handlers/admin/rate_limit_handlers.py index 3c73c6a..2837121 100644 --- a/helper_bot/handlers/admin/rate_limit_handlers.py +++ b/helper_bot/handlers/admin/rate_limit_handlers.py @@ -1,25 +1,31 @@ """ Обработчики команд для мониторинга rate limiting """ + from aiogram import F, Router, types from aiogram.filters import Command, MagicData from aiogram.fsm.context import FSMContext from aiogram.types import FSInputFile + from helper_bot.filters.main import ChatTypeFilter -from helper_bot.middlewares.dependencies_middleware import \ - DependenciesMiddleware +from helper_bot.middlewares.dependencies_middleware import DependenciesMiddleware + # Local imports - metrics from helper_bot.utils.metrics import track_errors, track_time from helper_bot.utils.rate_limit_metrics import ( - get_rate_limit_metrics_summary, update_rate_limit_gauges) -from helper_bot.utils.rate_limit_monitor import (get_rate_limit_summary, - rate_limit_monitor) + get_rate_limit_metrics_summary, + update_rate_limit_gauges, +) +from helper_bot.utils.rate_limit_monitor import ( + get_rate_limit_summary, + rate_limit_monitor, +) from logs.custom_logger import logger class RateLimitHandlers: def __init__(self, db, settings): - self.db = db.get_db() if hasattr(db, 'get_db') else db + self.db = db.get_db() if hasattr(db, "get_db") else db self.settings = settings self.router = Router() self._setup_handlers() @@ -33,38 +39,38 @@ class RateLimitHandlers: self.router.message.register( self.rate_limit_stats_handler, ChatTypeFilter(chat_type=["private"]), - Command("ratelimit_stats") + Command("ratelimit_stats"), ) - + # Команда для сброса статистики rate limiting self.router.message.register( self.reset_rate_limit_stats_handler, ChatTypeFilter(chat_type=["private"]), - Command("reset_ratelimit_stats") + Command("reset_ratelimit_stats"), ) - + # Команда для просмотра ошибок rate limiting self.router.message.register( self.rate_limit_errors_handler, ChatTypeFilter(chat_type=["private"]), - Command("ratelimit_errors") + Command("ratelimit_errors"), ) - + # Команда для просмотра Prometheus метрик self.router.message.register( self.rate_limit_prometheus_handler, ChatTypeFilter(chat_type=["private"]), - Command("ratelimit_prometheus") + Command("ratelimit_prometheus"), ) @track_time("rate_limit_stats_handler", "rate_limit_handlers") @track_errors("rate_limit_handlers", "rate_limit_stats_handler") async def rate_limit_stats_handler( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") + settings: MagicData("settings"), ): """Показывает статистику rate limiting""" try: @@ -72,11 +78,11 @@ class RateLimitHandlers: if not await bot_db.is_admin(message.from_user.id): await message.answer("У вас нет прав для выполнения этой команды.") return - + # Получаем сводку summary = get_rate_limit_summary() global_stats = rate_limit_monitor.get_global_stats() - + # Формируем сообщение со статистикой stats_text = ( f"📊 Статистика Rate Limiting\n\n" @@ -89,15 +95,17 @@ class RateLimitHandlers: f"• Активных чатов: {summary['active_chats']}\n" f"• Ошибок за час: {summary['recent_errors_count']}\n\n" ) - + # Добавляем детальную статистику stats_text += f"🔍 Детальная статистика:\n" stats_text += f"• Успешных запросов: {global_stats.successful_requests}\n" stats_text += f"• Неудачных запросов: {global_stats.failed_requests}\n" stats_text += f"• RetryAfter ошибок: {global_stats.retry_after_errors}\n" stats_text += f"• Других ошибок: {global_stats.other_errors}\n" - stats_text += f"• Общее время ожидания: {global_stats.total_wait_time:.2f}с\n\n" - + stats_text += ( + f"• Общее время ожидания: {global_stats.total_wait_time:.2f}с\n\n" + ) + # Добавляем топ чатов по запросам top_chats = rate_limit_monitor.get_top_chats_by_requests(5) if top_chats: @@ -105,16 +113,16 @@ class RateLimitHandlers: for i, (chat_id, chat_stats) in enumerate(top_chats, 1): stats_text += f"{i}. Chat {chat_id}: {chat_stats.total_requests} запросов ({chat_stats.success_rate:.1%} успех)\n" stats_text += "\n" - + # Добавляем чаты с высоким процентом ошибок high_error_chats = rate_limit_monitor.get_chats_with_high_error_rate(0.1) if high_error_chats: stats_text += f"⚠️ Чаты с высоким процентом ошибок (>10%):\n" for chat_id, chat_stats in high_error_chats[:3]: stats_text += f"• Chat {chat_id}: {chat_stats.error_rate:.1%} ошибок ({chat_stats.failed_requests}/{chat_stats.total_requests})\n" - - await message.answer(stats_text, parse_mode='HTML') - + + await message.answer(stats_text, parse_mode="HTML") + except Exception as e: logger.error(f"Ошибка при получении статистики rate limiting: {e}") await message.answer("Произошла ошибка при получении статистики.") @@ -123,10 +131,10 @@ class RateLimitHandlers: @track_errors("rate_limit_handlers", "reset_rate_limit_stats_handler") async def reset_rate_limit_stats_handler( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") + settings: MagicData("settings"), ): """Сбрасывает статистику rate limiting""" try: @@ -134,12 +142,12 @@ class RateLimitHandlers: if not await bot_db.is_admin(message.from_user.id): await message.answer("У вас нет прав для выполнения этой команды.") return - + # Сбрасываем статистику rate_limit_monitor.reset_stats() - + await message.answer("✅ Статистика rate limiting сброшена.") - + except Exception as e: logger.error(f"Ошибка при сбросе статистики rate limiting: {e}") await message.answer("Произошла ошибка при сбросе статистики.") @@ -148,10 +156,10 @@ class RateLimitHandlers: @track_errors("rate_limit_handlers", "rate_limit_errors_handler") async def rate_limit_errors_handler( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") + settings: MagicData("settings"), ): """Показывает недавние ошибки rate limiting""" try: @@ -159,29 +167,34 @@ class RateLimitHandlers: if not await bot_db.is_admin(message.from_user.id): await message.answer("У вас нет прав для выполнения этой команды.") return - + # Получаем ошибки за последний час recent_errors = rate_limit_monitor.get_recent_errors(60) error_summary = rate_limit_monitor.get_error_summary(60) - + if not recent_errors: - await message.answer("✅ Ошибок rate limiting за последний час не было.") + await message.answer( + "✅ Ошибок rate limiting за последний час не было." + ) return - + # Формируем сообщение с ошибками errors_text = f"🚨 Ошибки Rate Limiting (последний час)\n\n" errors_text += f"📊 Сводка ошибок:\n" for error_type, count in error_summary.items(): errors_text += f"• {error_type}: {count}\n" errors_text += f"\nВсего ошибок: {len(recent_errors)}\n\n" - + # Показываем последние 10 ошибок errors_text += f"🔍 Последние ошибки:\n" for i, error in enumerate(recent_errors[-10:], 1): from datetime import datetime - timestamp = datetime.fromtimestamp(error['timestamp']).strftime("%H:%M:%S") + + timestamp = datetime.fromtimestamp(error["timestamp"]).strftime( + "%H:%M:%S" + ) errors_text += f"{i}. {timestamp} - Chat {error['chat_id']} - {error['error_type']}\n" - + # Если сообщение слишком длинное, разбиваем на части if len(errors_text) > 4000: # Отправляем сводку @@ -190,32 +203,37 @@ class RateLimitHandlers: for error_type, count in error_summary.items(): summary_text += f"• {error_type}: {count}\n" summary_text += f"\nВсего ошибок: {len(recent_errors)}" - - await message.answer(summary_text, parse_mode='HTML') - + + await message.answer(summary_text, parse_mode="HTML") + # Отправляем детали отдельным сообщением details_text = f"🔍 Последние ошибки:\n" for i, error in enumerate(recent_errors[-10:], 1): from datetime import datetime - timestamp = datetime.fromtimestamp(error['timestamp']).strftime("%H:%M:%S") + + timestamp = datetime.fromtimestamp(error["timestamp"]).strftime( + "%H:%M:%S" + ) details_text += f"{i}. {timestamp} - Chat {error['chat_id']} - {error['error_type']}\n" - - await message.answer(details_text, parse_mode='HTML') + + await message.answer(details_text, parse_mode="HTML") else: - await message.answer(errors_text, parse_mode='HTML') - + await message.answer(errors_text, parse_mode="HTML") + except Exception as e: logger.error(f"Ошибка при получении ошибок rate limiting: {e}") - await message.answer("Произошла ошибка при получении информации об ошибках.") + await message.answer( + "Произошла ошибка при получении информации об ошибках." + ) @track_time("rate_limit_prometheus_handler", "rate_limit_handlers") @track_errors("rate_limit_handlers", "rate_limit_prometheus_handler") async def rate_limit_prometheus_handler( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") + settings: MagicData("settings"), ): """Показывает Prometheus метрики rate limiting""" try: @@ -223,13 +241,13 @@ class RateLimitHandlers: if not await bot_db.is_admin(message.from_user.id): await message.answer("У вас нет прав для выполнения этой команды.") return - + # Обновляем gauge метрики update_rate_limit_gauges() - + # Получаем сводку метрик metrics_summary = get_rate_limit_metrics_summary() - + # Формируем сообщение с метриками metrics_text = ( f"📊 Prometheus метрики Rate Limiting\n\n" @@ -241,30 +259,40 @@ class RateLimitHandlers: f"• rate_limit_avg_wait_time: {metrics_summary['average_wait_time']:.3f}s\n" f"• rate_limit_active_chats: {metrics_summary['active_chats']}\n\n" ) - + # Добавляем детальные метрики metrics_text += f"🔍 Детальные метрики:\n" - metrics_text += f"• Успешных запросов: {metrics_summary['successful_requests']}\n" - metrics_text += f"• Неудачных запросов: {metrics_summary['failed_requests']}\n" - metrics_text += f"• RetryAfter ошибок: {metrics_summary['retry_after_errors']}\n" + metrics_text += ( + f"• Успешных запросов: {metrics_summary['successful_requests']}\n" + ) + metrics_text += ( + f"• Неудачных запросов: {metrics_summary['failed_requests']}\n" + ) + metrics_text += ( + f"• RetryAfter ошибок: {metrics_summary['retry_after_errors']}\n" + ) metrics_text += f"• Других ошибок: {metrics_summary['other_errors']}\n" - metrics_text += f"• Общее время ожидания: {metrics_summary['total_wait_time']:.2f}s\n\n" - + metrics_text += ( + f"• Общее время ожидания: {metrics_summary['total_wait_time']:.2f}s\n\n" + ) + # Добавляем информацию о доступных метриках metrics_text += f"📈 Доступные Prometheus метрики:\n" metrics_text += f"• rate_limit_requests_total - общее количество запросов\n" metrics_text += f"• rate_limit_errors_total - количество ошибок по типам\n" metrics_text += f"• rate_limit_wait_duration_seconds - время ожидания\n" - metrics_text += f"• rate_limit_request_interval_seconds - интервалы между запросами\n" + metrics_text += ( + f"• rate_limit_request_interval_seconds - интервалы между запросами\n" + ) metrics_text += f"• rate_limit_active_chats - количество активных чатов\n" metrics_text += f"• rate_limit_success_rate - процент успеха по чатам\n" metrics_text += f"• rate_limit_requests_per_minute - запросов в минуту\n" metrics_text += f"• rate_limit_total_requests - общее количество запросов\n" metrics_text += f"• rate_limit_total_errors - количество ошибок\n" metrics_text += f"• rate_limit_avg_wait_time - среднее время ожидания\n" - - await message.answer(metrics_text, parse_mode='HTML') - + + await message.answer(metrics_text, parse_mode="HTML") + except Exception as e: logger.error(f"Ошибка при получении Prometheus метрик: {e}") await message.answer("Произошла ошибка при получении метрик.") diff --git a/helper_bot/handlers/admin/services.py b/helper_bot/handlers/admin/services.py index 7b92973..6eb625e 100644 --- a/helper_bot/handlers/admin/services.py +++ b/helper_bot/handlers/admin/services.py @@ -1,11 +1,16 @@ from datetime import datetime from typing import List, Optional -from helper_bot.handlers.admin.exceptions import (InvalidInputError, - UserAlreadyBannedError) -from helper_bot.utils.helper_func import (add_days_to_date, - get_banned_users_buttons, - get_banned_users_list) +from helper_bot.handlers.admin.exceptions import ( + InvalidInputError, + UserAlreadyBannedError, +) +from helper_bot.utils.helper_func import ( + add_days_to_date, + get_banned_users_buttons, + get_banned_users_list, +) + # Local imports - metrics from helper_bot.utils.metrics import track_errors, track_time from logs.custom_logger import logger @@ -13,6 +18,7 @@ from logs.custom_logger import logger class User: """Модель пользователя""" + def __init__(self, user_id: int, username: str, full_name: str): self.user_id = user_id self.username = username @@ -21,7 +27,10 @@ class User: class BannedUser: """Модель заблокированного пользователя""" - def __init__(self, user_id: int, username: str, reason: str, unban_date: Optional[datetime]): + + def __init__( + self, user_id: int, username: str, reason: str, unban_date: Optional[datetime] + ): self.user_id = user_id self.username = username self.reason = reason @@ -30,10 +39,10 @@ class BannedUser: class AdminService: """Сервис для административных операций""" - + def __init__(self, bot_db): self.bot_db = bot_db - + @track_time("get_last_users", "admin_service") @track_errors("admin_service", "get_last_users") async def get_last_users(self) -> List[User]: @@ -41,17 +50,13 @@ class AdminService: try: users_data = await self.bot_db.get_last_users(30) return [ - User( - user_id=user[1], - username='Неизвестно', - full_name=user[0] - ) + User(user_id=user[1], username="Неизвестно", full_name=user[0]) for user in users_data ] except Exception as e: logger.error(f"Ошибка при получении списка последних пользователей: {e}") raise - + @track_time("get_banned_users", "admin_service") @track_errors("admin_service", "get_banned_users") async def get_banned_users(self) -> List[BannedUser]: @@ -65,18 +70,22 @@ class AdminService: username = await self.bot_db.get_username(user_id) full_name = await self.bot_db.get_full_name_by_id(user_id) user_name = username or full_name or f"User_{user_id}" - - banned_users.append(BannedUser( - user_id=user_id, - username=user_name, - reason=reason, - unban_date=unban_date - )) + + banned_users.append( + BannedUser( + user_id=user_id, + username=user_name, + reason=reason, + unban_date=unban_date, + ) + ) return banned_users except Exception as e: - logger.error(f"Ошибка при получении списка заблокированных пользователей: {e}") + logger.error( + f"Ошибка при получении списка заблокированных пользователей: {e}" + ) raise - + @track_time("get_user_by_username", "admin_service") @track_errors("admin_service", "get_user_by_username") async def get_user_by_username(self, username: str) -> Optional[User]: @@ -85,17 +94,15 @@ class AdminService: user_id = await self.bot_db.get_user_id_by_username(username) if not user_id: return None - + full_name = await self.bot_db.get_full_name_by_id(user_id) return User( - user_id=user_id, - username=username, - full_name=full_name or 'Неизвестно' + user_id=user_id, username=username, full_name=full_name or "Неизвестно" ) except Exception as e: logger.error(f"Ошибка при поиске пользователя по username {username}: {e}") raise - + @track_time("get_user_by_id", "admin_service") @track_errors("admin_service", "get_user_by_id") async def get_user_by_id(self, user_id: int) -> Optional[User]: @@ -104,39 +111,50 @@ class AdminService: user_info = await self.bot_db.get_user_by_id(user_id) if not user_info: return None - + return User( user_id=user_id, - username=user_info.username or 'Неизвестно', - full_name=user_info.full_name or 'Неизвестно' + username=user_info.username or "Неизвестно", + full_name=user_info.full_name or "Неизвестно", ) except Exception as e: logger.error(f"Ошибка при поиске пользователя по ID {user_id}: {e}") raise - + @track_time("ban_user", "admin_service") @track_errors("admin_service", "ban_user") - async def ban_user(self, user_id: int, username: str, reason: str, ban_days: Optional[int], ban_author_id: int) -> None: + async def ban_user( + self, + user_id: int, + username: str, + reason: str, + ban_days: Optional[int], + ban_author_id: int, + ) -> None: """Заблокировать пользователя""" try: # Проверяем, не заблокирован ли уже пользователь if await self.bot_db.check_user_in_blacklist(user_id): raise UserAlreadyBannedError(f"Пользователь {user_id} уже заблокирован") - + # Рассчитываем дату разблокировки date_to_unban = None if ban_days is not None: date_to_unban = add_days_to_date(ban_days) - + # Сохраняем в БД (username больше не передается, так как не используется в новой схеме) - await self.bot_db.set_user_blacklist(user_id, None, reason, date_to_unban, ban_author=ban_author_id) - - logger.info(f"Пользователь {user_id} ({username}) заблокирован. Причина: {reason}, срок: {ban_days} дней") - + await self.bot_db.set_user_blacklist( + user_id, None, reason, date_to_unban, ban_author=ban_author_id + ) + + logger.info( + f"Пользователь {user_id} ({username}) заблокирован. Причина: {reason}, срок: {ban_days} дней" + ) + except Exception as e: logger.error(f"Ошибка при блокировке пользователя {user_id}: {e}") raise - + @track_time("unban_user", "admin_service") @track_errors("admin_service", "unban_user") async def unban_user(self, user_id: int) -> None: @@ -147,7 +165,7 @@ class AdminService: except Exception as e: logger.error(f"Ошибка при разблокировке пользователя {user_id}: {e}") raise - + @track_time("validate_user_input", "admin_service") @track_errors("admin_service", "validate_user_input") async def validate_user_input(self, input_text: str) -> int: @@ -155,11 +173,13 @@ class AdminService: try: user_id = int(input_text.strip()) if user_id <= 0: - raise InvalidInputError("ID пользователя должен быть положительным числом") + raise InvalidInputError( + "ID пользователя должен быть положительным числом" + ) return user_id except ValueError: raise InvalidInputError("ID пользователя должен быть числом") - + @track_time("get_banned_users_for_display", "admin_service") @track_errors("admin_service", "get_banned_users_for_display") async def get_banned_users_for_display(self, page: int = 0) -> tuple[str, list]: @@ -170,5 +190,7 @@ class AdminService: buttons_list = await get_banned_users_buttons(self.bot_db) return message_text, buttons_list except Exception as e: - logger.error(f"Ошибка при получении данных заблокированных пользователей: {e}") + logger.error( + f"Ошибка при получении данных заблокированных пользователей: {e}" + ) raise diff --git a/helper_bot/handlers/admin/utils.py b/helper_bot/handlers/admin/utils.py index 292dd2b..74fea5c 100644 --- a/helper_bot/handlers/admin/utils.py +++ b/helper_bot/handlers/admin/utils.py @@ -3,6 +3,7 @@ from typing import Optional from aiogram import types from aiogram.fsm.context import FSMContext + from helper_bot.handlers.admin.exceptions import AdminError from helper_bot.keyboards.keyboards import get_reply_keyboard_admin from logs.custom_logger import logger @@ -13,33 +14,41 @@ def escape_html(text: str) -> str: return html.escape(str(text)) if text else "" -async def return_to_admin_menu(message: types.Message, state: FSMContext, - additional_message: Optional[str] = None) -> None: +async def return_to_admin_menu( + message: types.Message, state: FSMContext, additional_message: Optional[str] = None +) -> None: """Универсальная функция для возврата в админ-меню""" - logger.info(f"return_to_admin_menu: Возврат в админ-меню для пользователя {message.from_user.id}") - + logger.info( + f"return_to_admin_menu: Возврат в админ-меню для пользователя {message.from_user.id}" + ) + await state.set_data({}) await state.set_state("ADMIN") markup = get_reply_keyboard_admin() - + if additional_message: - logger.info(f"return_to_admin_menu: Отправка дополнительного сообщения: {additional_message}") + logger.info( + f"return_to_admin_menu: Отправка дополнительного сообщения: {additional_message}" + ) await message.answer(additional_message) - - await message.answer('Вернулись в меню', reply_markup=markup) - logger.info(f"return_to_admin_menu: Пользователь {message.from_user.id} успешно возвращен в админ-меню") + + await message.answer("Вернулись в меню", reply_markup=markup) + logger.info( + f"return_to_admin_menu: Пользователь {message.from_user.id} успешно возвращен в админ-меню" + ) -async def handle_admin_error(message: types.Message, error: Exception, - state: FSMContext, error_context: str = "") -> None: +async def handle_admin_error( + message: types.Message, error: Exception, state: FSMContext, error_context: str = "" +) -> None: """Централизованная обработка ошибок административных операций""" logger.error(f"Ошибка в {error_context}: {error}") - + if isinstance(error, AdminError): await message.answer(f"Ошибка: {str(error)}") else: await message.answer("Произошла внутренняя ошибка. Попробуйте позже.") - + await return_to_admin_menu(message, state) @@ -47,19 +56,23 @@ def format_user_info(user_id: int, username: str, full_name: str) -> str: """Форматирование информации о пользователе для отображения""" safe_username = escape_html(username) safe_full_name = escape_html(full_name) - - return (f"Выбран пользователь:\n" - f"ID: {user_id}\n" - f"Username: {safe_username}\n" - f"Имя: {safe_full_name}") + + return ( + f"Выбран пользователь:\n" + f"ID: {user_id}\n" + f"Username: {safe_username}\n" + f"Имя: {safe_full_name}" + ) def format_ban_confirmation(user_id: int, reason: str, ban_days: Optional[int]) -> str: """Форматирование подтверждения бана""" safe_reason = escape_html(reason) ban_text = "Навсегда" if ban_days is None else f"{ban_days} дней" - - return (f"Необходимо подтверждение:\n" - f"Пользователь: {user_id}\n" - f"Причина бана: {safe_reason}\n" - f"Срок бана: {ban_text}") + + return ( + f"Необходимо подтверждение:\n" + f"Пользователь: {user_id}\n" + f"Причина бана: {safe_reason}\n" + f"Срок бана: {ban_text}" + ) diff --git a/helper_bot/handlers/callback/__init__.py b/helper_bot/handlers/callback/__init__.py index b06d41a..ccb7ddd 100644 --- a/helper_bot/handlers/callback/__init__.py +++ b/helper_bot/handlers/callback/__init__.py @@ -1,23 +1,34 @@ from .callback_handlers import callback_router -from .constants import (CALLBACK_BAN, CALLBACK_DECLINE, CALLBACK_PAGE, - CALLBACK_PUBLISH, CALLBACK_RETURN, CALLBACK_UNLOCK) -from .exceptions import (BanError, PostNotFoundError, PublishError, - UserBlockedBotError, UserNotFoundError) +from .constants import ( + CALLBACK_BAN, + CALLBACK_DECLINE, + CALLBACK_PAGE, + CALLBACK_PUBLISH, + CALLBACK_RETURN, + CALLBACK_UNLOCK, +) +from .exceptions import ( + BanError, + PostNotFoundError, + PublishError, + UserBlockedBotError, + UserNotFoundError, +) from .services import BanService, PostPublishService __all__ = [ - 'callback_router', - 'PostPublishService', - 'BanService', - 'UserBlockedBotError', - 'PostNotFoundError', - 'UserNotFoundError', - 'PublishError', - 'BanError', - 'CALLBACK_PUBLISH', - 'CALLBACK_DECLINE', - 'CALLBACK_BAN', - 'CALLBACK_UNLOCK', - 'CALLBACK_RETURN', - 'CALLBACK_PAGE' + "callback_router", + "PostPublishService", + "BanService", + "UserBlockedBotError", + "PostNotFoundError", + "UserNotFoundError", + "PublishError", + "BanError", + "CALLBACK_PUBLISH", + "CALLBACK_DECLINE", + "CALLBACK_BAN", + "CALLBACK_UNLOCK", + "CALLBACK_RETURN", + "CALLBACK_PAGE", ] diff --git a/helper_bot/handlers/callback/callback_handlers.py b/helper_bot/handlers/callback/callback_handlers.py index 3d06760..d156184 100644 --- a/helper_bot/handlers/callback/callback_handlers.py +++ b/helper_bot/handlers/callback/callback_handlers.py @@ -7,28 +7,49 @@ from aiogram import F, Router from aiogram.filters import MagicData from aiogram.fsm.context import FSMContext from aiogram.types import CallbackQuery + from helper_bot.handlers.admin.utils import format_user_info from helper_bot.handlers.voice.constants import CALLBACK_DELETE, CALLBACK_SAVE from helper_bot.handlers.voice.services import AudioFileService -from helper_bot.keyboards.keyboards import (create_keyboard_for_ban_reason, - create_keyboard_with_pagination, - get_reply_keyboard_admin) +from helper_bot.keyboards.keyboards import ( + create_keyboard_for_ban_reason, + create_keyboard_with_pagination, + get_reply_keyboard_admin, +) from helper_bot.utils.base_dependency_factory import get_global_instance -from helper_bot.utils.helper_func import (get_banned_users_buttons, - get_banned_users_list) +from helper_bot.utils.helper_func import get_banned_users_buttons, get_banned_users_list + # Local imports - metrics -from helper_bot.utils.metrics import (db_query_time, track_errors, - track_file_operations, track_time) +from helper_bot.utils.metrics import ( + db_query_time, + track_errors, + track_file_operations, + track_time, +) from logs.custom_logger import logger -from .constants import (CALLBACK_BAN, CALLBACK_DECLINE, CALLBACK_PAGE, - CALLBACK_PUBLISH, CALLBACK_RETURN, CALLBACK_UNLOCK, - ERROR_BOT_BLOCKED, MESSAGE_DECLINED, MESSAGE_ERROR, - MESSAGE_PUBLISHED, MESSAGE_USER_BANNED, - MESSAGE_USER_UNLOCKED) +from .constants import ( + CALLBACK_BAN, + CALLBACK_DECLINE, + CALLBACK_PAGE, + CALLBACK_PUBLISH, + CALLBACK_RETURN, + CALLBACK_UNLOCK, + ERROR_BOT_BLOCKED, + MESSAGE_DECLINED, + MESSAGE_ERROR, + MESSAGE_PUBLISHED, + MESSAGE_USER_BANNED, + MESSAGE_USER_UNLOCKED, +) from .dependency_factory import get_ban_service, get_post_publish_service -from .exceptions import (BanError, PostNotFoundError, PublishError, - UserBlockedBotError, UserNotFoundError) +from .exceptions import ( + BanError, + PostNotFoundError, + PublishError, + UserBlockedBotError, + UserNotFoundError, +) callback_router = Router() @@ -36,65 +57,61 @@ callback_router = Router() @callback_router.callback_query(F.data == CALLBACK_PUBLISH) @track_time("post_for_group", "callback_handlers") @track_errors("callback_handlers", "post_for_group") -async def post_for_group( - call: CallbackQuery, - settings: MagicData("settings") - ): +async def post_for_group(call: CallbackQuery, settings: MagicData("settings")): publish_service = get_post_publish_service() # TODO: переделать на MagicData logger.info( - f'Получен callback-запрос с действием: {call.data} от пользователя {call.from_user.full_name} (ID сообщения: {call.message.message_id})') - + f"Получен callback-запрос с действием: {call.data} от пользователя {call.from_user.full_name} (ID сообщения: {call.message.message_id})" + ) + try: await publish_service.publish_post(call) await call.answer(text=MESSAGE_PUBLISHED, cache_time=3) except UserBlockedBotError: await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) except (PostNotFoundError, PublishError) as e: - logger.error(f'Ошибка при публикации поста: {str(e)}') + logger.error(f"Ошибка при публикации поста: {str(e)}") await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) except Exception as e: if str(e) == ERROR_BOT_BLOCKED: await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) else: - important_logs = settings['Telegram']['important_logs'] + important_logs = settings["Telegram"]["important_logs"] await call.bot.send_message( chat_id=important_logs, - text=f"Произошла ошибка: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" + text=f"Произошла ошибка: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", ) - logger.error(f'Неожиданная ошибка при публикации поста: {str(e)}') + logger.error(f"Неожиданная ошибка при публикации поста: {str(e)}") await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) @callback_router.callback_query(F.data == CALLBACK_DECLINE) @track_time("decline_post_for_group", "callback_handlers") @track_errors("callback_handlers", "decline_post_for_group") -async def decline_post_for_group( - call: CallbackQuery, - settings: MagicData("settings") - ): +async def decline_post_for_group(call: CallbackQuery, settings: MagicData("settings")): publish_service = get_post_publish_service() # TODO: переделать на MagicData logger.info( - f'Получен callback-запрос с данными: {call.data} от пользователя {call.from_user.full_name} (ID: {call.from_user.id})') + f"Получен callback-запрос с данными: {call.data} от пользователя {call.from_user.full_name} (ID: {call.from_user.id})" + ) try: await publish_service.decline_post(call) await call.answer(text=MESSAGE_DECLINED, cache_time=3) except UserBlockedBotError: await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) except (PostNotFoundError, PublishError) as e: - logger.error(f'Ошибка при отклонении поста: {str(e)}') + logger.error(f"Ошибка при отклонении поста: {str(e)}") await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) except Exception as e: if str(e) == ERROR_BOT_BLOCKED: await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) else: - important_logs = settings['Telegram']['important_logs'] + important_logs = settings["Telegram"]["important_logs"] await call.bot.send_message( chat_id=important_logs, - text=f"Произошла ошибка: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" + text=f"Произошла ошибка: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", ) - logger.error(f'Неожиданная ошибка при отклонении поста: {str(e)}') + logger.error(f"Неожиданная ошибка при отклонении поста: {str(e)}") await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) @@ -110,65 +127,75 @@ async def ban_user_from_post(call: CallbackQuery, **kwargs): except UserBlockedBotError: await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) except (UserNotFoundError, BanError) as e: - logger.error(f'Ошибка при блокировке пользователя: {str(e)}') + logger.error(f"Ошибка при блокировке пользователя: {str(e)}") await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) except Exception as e: if str(e) == ERROR_BOT_BLOCKED: await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) else: - logger.error(f'Неожиданная ошибка при блокировке пользователя: {str(e)}') + logger.error(f"Неожиданная ошибка при блокировке пользователя: {str(e)}") await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) @callback_router.callback_query(F.data.contains(CALLBACK_BAN)) @track_time("process_ban_user", "callback_handlers") @track_errors("callback_handlers", "process_ban_user") -async def process_ban_user(call: CallbackQuery, state: FSMContext, bot_db: MagicData("bot_db"), **kwargs): +async def process_ban_user( + call: CallbackQuery, state: FSMContext, bot_db: MagicData("bot_db"), **kwargs +): ban_service = get_ban_service() # TODO: переделать на MagicData user_id = call.data[4:] - logger.info(f"Вызов функции process_ban_user. Данные callback: {call.data} пользователь: {user_id}") - + logger.info( + f"Вызов функции process_ban_user. Данные callback: {call.data} пользователь: {user_id}" + ) + # Проверяем, что user_id является валидным числом try: user_id_int = int(user_id) except ValueError: logger.error(f"Некорректный user_id в callback: {user_id}") - await call.answer(text="Ошибка: некорректный ID пользователя", show_alert=True, cache_time=3) + await call.answer( + text="Ошибка: некорректный ID пользователя", show_alert=True, cache_time=3 + ) return - + try: # Получаем username пользователя username = await ban_service.ban_user(str(user_id_int), "") if not username: raise UserNotFoundError(f"Пользователь с ID {user_id_int} не найден в базе") - + # Получаем full_name пользователя из базы данных full_name = await bot_db.get_full_name_by_id(user_id_int) if not full_name: - full_name = 'Неизвестно' - + full_name = "Неизвестно" + # Сохраняем данные в формате, совместимом с admin_handlers await state.update_data( target_user_id=user_id_int, target_username=username, - target_full_name=full_name + target_full_name=full_name, ) - + # Используем единый формат отображения информации о пользователе user_info = format_user_info(user_id_int, username, full_name) markup = create_keyboard_for_ban_reason() - + await call.message.answer( text=f"{user_info}\n\nВыбери причину бана из списка или напиши ее в чат", - reply_markup=markup + reply_markup=markup, + ) + await state.set_state("AWAIT_BAN_DETAILS") + logger.info( + f"process_ban_user: Состояние изменено на AWAIT_BAN_DETAILS для пользователя {user_id_int}" ) - await state.set_state('AWAIT_BAN_DETAILS') - logger.info(f"process_ban_user: Состояние изменено на AWAIT_BAN_DETAILS для пользователя {user_id_int}") except UserNotFoundError: markup = get_reply_keyboard_admin() - await call.message.answer(text='Пользователь с таким ID не найден в базе', reply_markup=markup) - await state.set_state('ADMIN') + await call.message.answer( + text="Пользователь с таким ID не найден в базе", reply_markup=markup + ) + await state.set_state("ADMIN") @callback_router.callback_query(F.data.contains(CALLBACK_UNLOCK)) @@ -178,22 +205,26 @@ async def process_unlock_user(call: CallbackQuery, **kwargs): ban_service = get_ban_service() # TODO: переделать на MagicData user_id = call.data[7:] - + # Проверяем, что user_id является валидным числом try: user_id_int = int(user_id) except ValueError: logger.error(f"Некорректный user_id в callback: {user_id}") - await call.answer(text="Ошибка: некорректный ID пользователя", show_alert=True, cache_time=3) + await call.answer( + text="Ошибка: некорректный ID пользователя", show_alert=True, cache_time=3 + ) return - + try: username = await ban_service.unlock_user(str(user_id_int)) - await call.answer(f'{MESSAGE_USER_UNLOCKED} {username}', show_alert=True) + await call.answer(f"{MESSAGE_USER_UNLOCKED} {username}", show_alert=True) except UserNotFoundError: - await call.answer(text='Пользователь не найден в базе', show_alert=True, cache_time=3) + await call.answer( + text="Пользователь не найден в базе", show_alert=True, cache_time=3 + ) except Exception as e: - logger.error(f'Ошибка при разблокировке пользователя: {str(e)}') + logger.error(f"Ошибка при разблокировке пользователя: {str(e)}") await call.answer(text=MESSAGE_ERROR, show_alert=True, cache_time=3) @@ -204,48 +235,52 @@ async def return_to_main_menu(call: CallbackQuery, **kwargs): await call.message.delete() logger.info(f"Запуск админ панели для пользователя: {call.message.from_user.id}") markup = get_reply_keyboard_admin() - await call.message.answer("Добро пожаловать в админку. Выбери что хочешь:", reply_markup=markup) + await call.message.answer( + "Добро пожаловать в админку. Выбери что хочешь:", reply_markup=markup + ) @callback_router.callback_query(F.data.contains(CALLBACK_PAGE)) @track_time("change_page", "callback_handlers") @track_errors("callback_handlers", "change_page") -async def change_page( - call: CallbackQuery, - bot_db: MagicData("bot_db"), - **kwargs - ): +async def change_page(call: CallbackQuery, bot_db: MagicData("bot_db"), **kwargs): try: page_number = int(call.data[5:]) except ValueError: logger.error(f"Некорректный номер страницы в callback: {call.data}") - await call.answer(text="Ошибка: некорректный номер страницы", show_alert=True, cache_time=3) + await call.answer( + text="Ошибка: некорректный номер страницы", show_alert=True, cache_time=3 + ) return - + logger.info(f"Переход на страницу {page_number}") - - if call.message.text == 'Список пользователей которые последними обращались к боту': + + if call.message.text == "Список пользователей которые последними обращались к боту": list_users = await bot_db.get_last_users(30) - keyboard = create_keyboard_with_pagination(page_number, len(list_users), list_users, 'ban') + keyboard = create_keyboard_with_pagination( + page_number, len(list_users), list_users, "ban" + ) await call.bot.edit_message_reply_markup( - chat_id=call.message.chat.id, + chat_id=call.message.chat.id, message_id=call.message.message_id, - reply_markup=keyboard + reply_markup=keyboard, ) else: message_user = await get_banned_users_list(int(page_number) * 7 - 7, bot_db) await call.bot.edit_message_text( - chat_id=call.message.chat.id, + chat_id=call.message.chat.id, message_id=call.message.message_id, - text=message_user + text=message_user, ) - + buttons = await get_banned_users_buttons(bot_db) - keyboard = create_keyboard_with_pagination(page_number, len(buttons), buttons, 'unlock') + keyboard = create_keyboard_with_pagination( + page_number, len(buttons), buttons, "unlock" + ) await call.bot.edit_message_reply_markup( - chat_id=call.message.chat.id, + chat_id=call.message.chat.id, message_id=call.message.message_id, - reply_markup=keyboard + reply_markup=keyboard, ) @@ -255,73 +290,81 @@ async def change_page( @track_file_operations("voice") @db_query_time("save_voice_message", "audio_moderate", "mixed") async def save_voice_message( - call: CallbackQuery, + call: CallbackQuery, bot_db: MagicData("bot_db"), settings: MagicData("settings"), - **kwargs - ): + **kwargs, +): try: - logger.info(f"Начинаем сохранение голосового сообщения. Message ID: {call.message.message_id}") - + logger.info( + f"Начинаем сохранение голосового сообщения. Message ID: {call.message.message_id}" + ) + # Создаем сервис для работы с аудио файлами audio_service = AudioFileService(bot_db) - + # Получаем ID пользователя из базы - user_id = await bot_db.get_user_id_by_message_id_for_voice_bot(call.message.message_id) + user_id = await bot_db.get_user_id_by_message_id_for_voice_bot( + call.message.message_id + ) logger.info(f"Получен user_id: {user_id}") - + # Генерируем имя файла file_name = await audio_service.generate_file_name(user_id) logger.info(f"Сгенерировано имя файла: {file_name}") - + # Собираем инфо о сообщении time_UTC = int(time.time()) date_added = datetime.fromtimestamp(time_UTC) - + # Получаем file_id из voice сообщения file_id = call.message.voice.file_id if call.message.voice else "" logger.info(f"Получен file_id: {file_id}") - + # ВАЖНО: Сначала скачиваем и сохраняем файл на диск logger.info("Начинаем скачивание и сохранение файла на диск...") await audio_service.download_and_save_audio(call.bot, call.message, file_name) logger.info("Файл успешно скачан и сохранен на диск") - + # Только после успешного сохранения файла - сохраняем в базу данных logger.info("Начинаем сохранение информации в базу данных...") await audio_service.save_audio_file(file_name, user_id, date_added, file_id) logger.info("Информация успешно сохранена в базу данных") - + # Удаляем сообщение из предложки logger.info("Удаляем сообщение из предложки...") await call.bot.delete_message( - chat_id=settings['Telegram']['group_for_posts'], - message_id=call.message.message_id + chat_id=settings["Telegram"]["group_for_posts"], + message_id=call.message.message_id, ) logger.info("Сообщение удалено из предложки") - + # Удаляем запись из таблицы audio_moderate logger.info("Удаляем запись из таблицы audio_moderate...") await bot_db.delete_audio_moderate_record(call.message.message_id) logger.info("Запись удалена из таблицы audio_moderate") - - await call.answer(text='Сохранено!', cache_time=3) + + await call.answer(text="Сохранено!", cache_time=3) logger.info(f"Голосовое сообщение успешно сохранено: {file_name}") - + except Exception as e: logger.error(f"Ошибка при сохранении голосового сообщения: {e}") logger.error(f"Traceback: {traceback.format_exc()}") - + # Дополнительная информация для диагностики try: - if 'call' in locals() and call.message: + if "call" in locals() and call.message: logger.error(f"Message ID: {call.message.message_id}") - logger.error(f"User ID: {user_id if 'user_id' in locals() else 'не определен'}") - logger.error(f"File name: {file_name if 'file_name' in locals() else 'не определен'}") + logger.error( + f"User ID: {user_id if 'user_id' in locals() else 'не определен'}" + ) + logger.error( + f"File name: {file_name if 'file_name' in locals() else 'не определен'}" + ) except: pass - - await call.answer(text='Ошибка при сохранении!', cache_time=3) + + await call.answer(text="Ошибка при сохранении!", cache_time=3) @callback_router.callback_query(F.data == CALLBACK_DELETE) @@ -329,23 +372,23 @@ async def save_voice_message( @track_errors("callback_handlers", "delete_voice_message") @db_query_time("delete_voice_message", "audio_moderate", "delete") async def delete_voice_message( - call: CallbackQuery, + call: CallbackQuery, bot_db: MagicData("bot_db"), settings: MagicData("settings"), - **kwargs - ): + **kwargs, +): try: # Удаляем сообщение из предложки await call.bot.delete_message( - chat_id=settings['Telegram']['group_for_posts'], - message_id=call.message.message_id + chat_id=settings["Telegram"]["group_for_posts"], + message_id=call.message.message_id, ) - + # Удаляем запись из таблицы audio_moderate await bot_db.delete_audio_moderate_record(call.message.message_id) - - await call.answer(text='Удалено!', cache_time=3) - + + await call.answer(text="Удалено!", cache_time=3) + except Exception as e: logger.error(f"Ошибка при удалении голосового сообщения: {e}") - await call.answer(text='Ошибка при удалении!', cache_time=3) + await call.answer(text="Ошибка при удалении!", cache_time=3) diff --git a/helper_bot/handlers/callback/constants.py b/helper_bot/handlers/callback/constants.py index 1ce2b75..6aee85a 100644 --- a/helper_bot/handlers/callback/constants.py +++ b/helper_bot/handlers/callback/constants.py @@ -33,9 +33,9 @@ ERROR_BOT_BLOCKED = "Forbidden: bot was blocked by the user" # Callback to command mapping for metrics CALLBACK_COMMAND_MAPPING: Final[Dict[str, str]] = { "publish": "publish", - "decline": "decline", + "decline": "decline", "ban": "ban", "unlock": "unlock", "return": "return", - "page": "page" + "page": "page", } diff --git a/helper_bot/handlers/callback/dependency_factory.py b/helper_bot/handlers/callback/dependency_factory.py index ec3f563..a8b376f 100644 --- a/helper_bot/handlers/callback/dependency_factory.py +++ b/helper_bot/handlers/callback/dependency_factory.py @@ -3,6 +3,7 @@ from typing import Callable from aiogram import Bot from aiogram.client.default import DefaultBotProperties from aiogram.fsm.context import FSMContext + from helper_bot.utils.base_dependency_factory import get_global_instance from .services import BanService, PostPublishService @@ -22,7 +23,7 @@ def get_post_publish_service() -> PostPublishService: def get_ban_service() -> BanService: """Фабрика для BanService""" bdf = get_global_instance() - + db = bdf.get_db() settings = bdf.settings return BanService(None, db, settings) diff --git a/helper_bot/handlers/callback/exceptions.py b/helper_bot/handlers/callback/exceptions.py index 5b1dc73..d135cfc 100644 --- a/helper_bot/handlers/callback/exceptions.py +++ b/helper_bot/handlers/callback/exceptions.py @@ -1,23 +1,28 @@ class UserBlockedBotError(Exception): """Исключение, возникающее когда пользователь заблокировал бота""" + pass class PostNotFoundError(Exception): """Исключение, возникающее когда пост не найден в базе данных""" + pass class UserNotFoundError(Exception): """Исключение, возникающее когда пользователь не найден в базе данных""" + pass class PublishError(Exception): """Общее исключение для ошибок публикации""" + pass class BanError(Exception): """Исключение для ошибок бана/разбана пользователей""" + pass diff --git a/helper_bot/handlers/callback/services.py b/helper_bot/handlers/callback/services.py index 51d0337..f9aff0d 100644 --- a/helper_bot/handlers/callback/services.py +++ b/helper_bot/handlers/callback/services.py @@ -4,42 +4,70 @@ from typing import Any, Dict from aiogram import Bot, types from aiogram.types import CallbackQuery + from helper_bot.keyboards.keyboards import create_keyboard_for_ban_reason -from helper_bot.utils.helper_func import (delete_user_blacklist, - get_text_message, send_audio_message, - send_media_group_to_channel, - send_photo_message, - send_text_message, - send_video_message, - send_video_note_message, - send_voice_message) +from helper_bot.utils.helper_func import ( + delete_user_blacklist, + get_text_message, + send_audio_message, + send_media_group_to_channel, + send_photo_message, + send_text_message, + send_video_message, + send_video_note_message, + send_voice_message, +) + # Local imports - metrics -from helper_bot.utils.metrics import (db_query_time, track_errors, - track_media_processing, track_time) +from helper_bot.utils.metrics import ( + db_query_time, + track_errors, + track_media_processing, + track_time, +) from logs.custom_logger import logger -from .constants import (CONTENT_TYPE_AUDIO, CONTENT_TYPE_MEDIA_GROUP, - CONTENT_TYPE_PHOTO, CONTENT_TYPE_TEXT, - CONTENT_TYPE_VIDEO, CONTENT_TYPE_VIDEO_NOTE, - CONTENT_TYPE_VOICE, ERROR_BOT_BLOCKED, - MESSAGE_POST_DECLINED, MESSAGE_POST_PUBLISHED, - MESSAGE_USER_BANNED_SPAM) -from .exceptions import (BanError, PostNotFoundError, PublishError, - UserBlockedBotError, UserNotFoundError) +from .constants import ( + CONTENT_TYPE_AUDIO, + CONTENT_TYPE_MEDIA_GROUP, + CONTENT_TYPE_PHOTO, + CONTENT_TYPE_TEXT, + CONTENT_TYPE_VIDEO, + CONTENT_TYPE_VIDEO_NOTE, + CONTENT_TYPE_VOICE, + ERROR_BOT_BLOCKED, + MESSAGE_POST_DECLINED, + MESSAGE_POST_PUBLISHED, + MESSAGE_USER_BANNED_SPAM, +) +from .exceptions import ( + BanError, + PostNotFoundError, + PublishError, + UserBlockedBotError, + UserNotFoundError, +) class PostPublishService: - def __init__(self, bot: Bot, db, settings: Dict[str, Any], s3_storage=None, scoring_manager=None): + def __init__( + self, + bot: Bot, + db, + settings: Dict[str, Any], + s3_storage=None, + scoring_manager=None, + ): # bot может быть None - в этом случае используем бота из контекста сообщения self.bot = bot self.db = db self.settings = settings self.s3_storage = s3_storage self.scoring_manager = scoring_manager - self.group_for_posts = settings['Telegram']['group_for_posts'] - self.main_public = settings['Telegram']['main_public'] - self.important_logs = settings['Telegram']['important_logs'] - + self.group_for_posts = settings["Telegram"]["group_for_posts"] + self.main_public = settings["Telegram"]["main_public"] + self.important_logs = settings["Telegram"]["important_logs"] + def _get_bot(self, message) -> Bot: """Получает бота из контекста сообщения или использует переданного""" if self.bot: @@ -54,14 +82,14 @@ class PostPublishService: if call.message.text == CONTENT_TYPE_MEDIA_GROUP: await self._publish_media_group(call) return - + # Проверяем, является ли сообщение частью медиагруппы (для обратной совместимости) if call.message.media_group_id: await self._publish_media_group(call) return - + content_type = call.message.content_type - + if content_type == CONTENT_TYPE_TEXT: await self._publish_text_post(call) elif content_type == CONTENT_TYPE_PHOTO: @@ -83,34 +111,50 @@ class PostPublishService: """Публикация текстового поста""" author_id = await self._get_author_id(call.message.message_id) - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "approved") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "approved" + ) if updated_rows == 0: - logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'") - raise PostNotFoundError(f"Пост с message_id={call.message.message_id} не найден в базе данных") - + logger.error( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'" + ) + raise PostNotFoundError( + f"Пост с message_id={call.message.message_id} не найден в базе данных" + ) + # Получаем сырой текст и is_anonymous из базы - raw_text, is_anonymous = await self.db.get_post_text_and_anonymity_by_message_id(call.message.message_id) + raw_text, is_anonymous = ( + await self.db.get_post_text_and_anonymity_by_message_id( + call.message.message_id + ) + ) if raw_text is None: raw_text = "" - + # Получаем данные автора user = await self.db.get_user_by_id(author_id) if not user: raise PostNotFoundError(f"Пользователь {author_id} не найден в базе данных") - + # Формируем финальный текст с учетом is_anonymous - formatted_text = get_text_message(raw_text, user.first_name, user.username, is_anonymous) - - sent_message = await send_text_message(self.main_public, call.message, formatted_text) - + formatted_text = get_text_message( + raw_text, user.first_name, user.username, is_anonymous + ) + + sent_message = await send_text_message( + self.main_public, call.message, formatted_text + ) + # Сохраняем published_message_id await self.db.update_published_message_id( original_message_id=call.message.message_id, - published_message_id=sent_message.message_id + published_message_id=sent_message.message_id, ) - + await self._delete_post_and_notify_author(call, author_id) - logger.info(f'Текст сообщение опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}.') + logger.info( + f"Текст сообщение опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}." + ) @track_time("_publish_photo_post", "post_publish_service") @track_errors("post_publish_service", "_publish_photo_post") @@ -118,37 +162,58 @@ class PostPublishService: """Публикация поста с фото""" author_id = await self._get_author_id(call.message.message_id) - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "approved") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "approved" + ) if updated_rows == 0: - logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'") - raise PostNotFoundError(f"Пост с message_id={call.message.message_id} не найден в базе данных") - + logger.error( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'" + ) + raise PostNotFoundError( + f"Пост с message_id={call.message.message_id} не найден в базе данных" + ) + # Получаем сырой текст и is_anonymous из базы - raw_text, is_anonymous = await self.db.get_post_text_and_anonymity_by_message_id(call.message.message_id) + raw_text, is_anonymous = ( + await self.db.get_post_text_and_anonymity_by_message_id( + call.message.message_id + ) + ) if raw_text is None: raw_text = "" - + # Получаем данные автора user = await self.db.get_user_by_id(author_id) if not user: raise PostNotFoundError(f"Пользователь {author_id} не найден в базе данных") - + # Формируем финальный текст с учетом is_anonymous - formatted_text = get_text_message(raw_text, user.first_name, user.username, is_anonymous) - - sent_message = await send_photo_message(self.main_public, call.message, call.message.photo[-1].file_id, formatted_text) - + formatted_text = get_text_message( + raw_text, user.first_name, user.username, is_anonymous + ) + + sent_message = await send_photo_message( + self.main_public, + call.message, + call.message.photo[-1].file_id, + formatted_text, + ) + # Сохраняем published_message_id await self.db.update_published_message_id( original_message_id=call.message.message_id, - published_message_id=sent_message.message_id + published_message_id=sent_message.message_id, ) - + # Сохраняем медиафайл из опубликованного поста (используем уже сохраненный файл) - await self._save_published_post_content(sent_message, sent_message.message_id, call.message.message_id) - + await self._save_published_post_content( + sent_message, sent_message.message_id, call.message.message_id + ) + await self._delete_post_and_notify_author(call, author_id) - logger.info(f'Пост с фото опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}.') + logger.info( + f"Пост с фото опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}." + ) @track_time("_publish_video_post", "post_publish_service") @track_errors("post_publish_service", "_publish_video_post") @@ -156,37 +221,55 @@ class PostPublishService: """Публикация поста с видео""" author_id = await self._get_author_id(call.message.message_id) - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "approved") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "approved" + ) if updated_rows == 0: - logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'") - raise PostNotFoundError(f"Пост с message_id={call.message.message_id} не найден в базе данных") - + logger.error( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'" + ) + raise PostNotFoundError( + f"Пост с message_id={call.message.message_id} не найден в базе данных" + ) + # Получаем сырой текст и is_anonymous из базы - raw_text, is_anonymous = await self.db.get_post_text_and_anonymity_by_message_id(call.message.message_id) + raw_text, is_anonymous = ( + await self.db.get_post_text_and_anonymity_by_message_id( + call.message.message_id + ) + ) if raw_text is None: raw_text = "" - + # Получаем данные автора user = await self.db.get_user_by_id(author_id) if not user: raise PostNotFoundError(f"Пользователь {author_id} не найден в базе данных") - + # Формируем финальный текст с учетом is_anonymous - formatted_text = get_text_message(raw_text, user.first_name, user.username, is_anonymous) - - sent_message = await send_video_message(self.main_public, call.message, call.message.video.file_id, formatted_text) - + formatted_text = get_text_message( + raw_text, user.first_name, user.username, is_anonymous + ) + + sent_message = await send_video_message( + self.main_public, call.message, call.message.video.file_id, formatted_text + ) + # Сохраняем published_message_id await self.db.update_published_message_id( original_message_id=call.message.message_id, - published_message_id=sent_message.message_id + published_message_id=sent_message.message_id, ) - + # Сохраняем медиафайл из опубликованного поста (используем уже сохраненный файл) - await self._save_published_post_content(sent_message, sent_message.message_id, call.message.message_id) - + await self._save_published_post_content( + sent_message, sent_message.message_id, call.message.message_id + ) + await self._delete_post_and_notify_author(call, author_id) - logger.info(f'Пост с видео опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}.') + logger.info( + f"Пост с видео опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}." + ) @track_time("_publish_video_note_post", "post_publish_service") @track_errors("post_publish_service", "_publish_video_note_post") @@ -194,24 +277,36 @@ class PostPublishService: """Публикация поста с кружком""" author_id = await self._get_author_id(call.message.message_id) - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "approved") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "approved" + ) if updated_rows == 0: - logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'") - raise PostNotFoundError(f"Пост с message_id={call.message.message_id} не найден в базе данных") - - sent_message = await send_video_note_message(self.main_public, call.message, call.message.video_note.file_id) - + logger.error( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'" + ) + raise PostNotFoundError( + f"Пост с message_id={call.message.message_id} не найден в базе данных" + ) + + sent_message = await send_video_note_message( + self.main_public, call.message, call.message.video_note.file_id + ) + # Сохраняем published_message_id await self.db.update_published_message_id( original_message_id=call.message.message_id, - published_message_id=sent_message.message_id + published_message_id=sent_message.message_id, ) - + # Сохраняем медиафайл из опубликованного поста (используем уже сохраненный файл) - await self._save_published_post_content(sent_message, sent_message.message_id, call.message.message_id) - + await self._save_published_post_content( + sent_message, sent_message.message_id, call.message.message_id + ) + await self._delete_post_and_notify_author(call, author_id) - logger.info(f'Пост с кружком опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}.') + logger.info( + f"Пост с кружком опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}." + ) @track_time("_publish_audio_post", "post_publish_service") @track_errors("post_publish_service", "_publish_audio_post") @@ -219,37 +314,55 @@ class PostPublishService: """Публикация поста с аудио""" author_id = await self._get_author_id(call.message.message_id) - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "approved") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "approved" + ) if updated_rows == 0: - logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'") - raise PostNotFoundError(f"Пост с message_id={call.message.message_id} не найден в базе данных") - + logger.error( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'" + ) + raise PostNotFoundError( + f"Пост с message_id={call.message.message_id} не найден в базе данных" + ) + # Получаем сырой текст и is_anonymous из базы - raw_text, is_anonymous = await self.db.get_post_text_and_anonymity_by_message_id(call.message.message_id) + raw_text, is_anonymous = ( + await self.db.get_post_text_and_anonymity_by_message_id( + call.message.message_id + ) + ) if raw_text is None: raw_text = "" - + # Получаем данные автора user = await self.db.get_user_by_id(author_id) if not user: raise PostNotFoundError(f"Пользователь {author_id} не найден в базе данных") - + # Формируем финальный текст с учетом is_anonymous - formatted_text = get_text_message(raw_text, user.first_name, user.username, is_anonymous) - - sent_message = await send_audio_message(self.main_public, call.message, call.message.audio.file_id, formatted_text) - + formatted_text = get_text_message( + raw_text, user.first_name, user.username, is_anonymous + ) + + sent_message = await send_audio_message( + self.main_public, call.message, call.message.audio.file_id, formatted_text + ) + # Сохраняем published_message_id await self.db.update_published_message_id( original_message_id=call.message.message_id, - published_message_id=sent_message.message_id + published_message_id=sent_message.message_id, ) - + # Сохраняем медиафайл из опубликованного поста (используем уже сохраненный файл) - await self._save_published_post_content(sent_message, sent_message.message_id, call.message.message_id) - + await self._save_published_post_content( + sent_message, sent_message.message_id, call.message.message_id + ) + await self._delete_post_and_notify_author(call, author_id) - logger.info(f'Пост с аудио опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}.') + logger.info( + f"Пост с аудио опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}." + ) @track_time("_publish_voice_post", "post_publish_service") @track_errors("post_publish_service", "_publish_voice_post") @@ -257,24 +370,36 @@ class PostPublishService: """Публикация поста с войсом""" author_id = await self._get_author_id(call.message.message_id) - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "approved") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "approved" + ) if updated_rows == 0: - logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'") - raise PostNotFoundError(f"Пост с message_id={call.message.message_id} не найден в базе данных") - - sent_message = await send_voice_message(self.main_public, call.message, call.message.voice.file_id) - + logger.error( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'approved'" + ) + raise PostNotFoundError( + f"Пост с message_id={call.message.message_id} не найден в базе данных" + ) + + sent_message = await send_voice_message( + self.main_public, call.message, call.message.voice.file_id + ) + # Сохраняем published_message_id await self.db.update_published_message_id( original_message_id=call.message.message_id, - published_message_id=sent_message.message_id + published_message_id=sent_message.message_id, ) - + # Сохраняем медиафайл из опубликованного поста (используем уже сохраненный файл) - await self._save_published_post_content(sent_message, sent_message.message_id, call.message.message_id) - + await self._save_published_post_content( + sent_message, sent_message.message_id, call.message.message_id + ) + await self._delete_post_and_notify_author(call, author_id) - logger.info(f'Пост с войсом опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}.') + logger.info( + f"Пост с войсом опубликован в канале {self.main_public}, published_message_id={sent_message.message_id}." + ) @track_time("_publish_media_group", "post_publish_service") @track_errors("post_publish_service", "_publish_media_group") @@ -283,91 +408,130 @@ class PostPublishService: """Публикация медиагруппы""" try: helper_message_id = call.message.message_id - - media_group_message_ids = await self.db.get_post_ids_by_helper_id(helper_message_id) + + media_group_message_ids = await self.db.get_post_ids_by_helper_id( + helper_message_id + ) if not media_group_message_ids: - logger.error(f"_publish_media_group: Не найдены message_id медиагруппы для helper_message_id={helper_message_id}") + logger.error( + f"_publish_media_group: Не найдены message_id медиагруппы для helper_message_id={helper_message_id}" + ) raise PublishError("Не найдены message_id медиагруппы в базе данных") - - post_content = await self.db.get_post_content_by_helper_id(helper_message_id) + + post_content = await self.db.get_post_content_by_helper_id( + helper_message_id + ) if not post_content: - logger.error(f"_publish_media_group: Контент медиагруппы не найден в базе данных для helper_message_id={helper_message_id}") + logger.error( + f"_publish_media_group: Контент медиагруппы не найден в базе данных для helper_message_id={helper_message_id}" + ) raise PublishError("Контент медиагруппы не найден в базе данных") - - raw_text, is_anonymous = await self.db.get_post_text_and_anonymity_by_helper_id(helper_message_id) + + raw_text, is_anonymous = ( + await self.db.get_post_text_and_anonymity_by_helper_id( + helper_message_id + ) + ) if raw_text is None: raw_text = "" - - author_id = await self.db.get_author_id_by_helper_message_id(helper_message_id) + + author_id = await self.db.get_author_id_by_helper_message_id( + helper_message_id + ) if not author_id: - logger.error(f"_publish_media_group: Автор не найден для медиагруппы helper_message_id={helper_message_id}") - raise PostNotFoundError(f"Автор не найден для медиагруппы {helper_message_id}") - + logger.error( + f"_publish_media_group: Автор не найден для медиагруппы helper_message_id={helper_message_id}" + ) + raise PostNotFoundError( + f"Автор не найден для медиагруппы {helper_message_id}" + ) + user = await self.db.get_user_by_id(author_id) if not user: - raise PostNotFoundError(f"Пользователь {author_id} не найден в базе данных") - - formatted_text = get_text_message(raw_text, user.first_name, user.username, is_anonymous) - + raise PostNotFoundError( + f"Пользователь {author_id} не найден в базе данных" + ) + + formatted_text = get_text_message( + raw_text, user.first_name, user.username, is_anonymous + ) + try: await self._get_bot(call.message).delete_messages( - chat_id=self.group_for_posts, - message_ids=media_group_message_ids + chat_id=self.group_for_posts, message_ids=media_group_message_ids ) except Exception as e: - logger.warning(f"_publish_media_group: Ошибка при удалении медиагруппы из чата модерации: {e}") - + logger.warning( + f"_publish_media_group: Ошибка при удалении медиагруппы из чата модерации: {e}" + ) + sent_messages = await send_media_group_to_channel( - bot=self._get_bot(call.message), - chat_id=self.main_public, - post_content=post_content, + bot=self._get_bot(call.message), + chat_id=self.main_public, + post_content=post_content, post_text=formatted_text, - s3_storage=self.s3_storage + s3_storage=self.s3_storage, ) - + if len(sent_messages) == len(media_group_message_ids): for i, original_message_id in enumerate(media_group_message_ids): published_message_id = sent_messages[i].message_id try: await self.db.update_published_message_id( original_message_id=original_message_id, - published_message_id=published_message_id + published_message_id=published_message_id, + ) + await self._save_published_post_content( + sent_messages[i], published_message_id, original_message_id ) - await self._save_published_post_content(sent_messages[i], published_message_id, original_message_id) except Exception as e: - logger.warning(f"_publish_media_group: Ошибка при сохранении published_message_id для {original_message_id}: {e}") + logger.warning( + f"_publish_media_group: Ошибка при сохранении published_message_id для {original_message_id}: {e}" + ) else: - logger.warning(f"_publish_media_group: Количество опубликованных сообщений ({len(sent_messages)}) не совпадает с количеством оригинальных ({len(media_group_message_ids)})") + logger.warning( + f"_publish_media_group: Количество опубликованных сообщений ({len(sent_messages)}) не совпадает с количеством оригинальных ({len(media_group_message_ids)})" + ) + + await self.db.update_status_for_media_group_by_helper_id( + helper_message_id, "approved" + ) - await self.db.update_status_for_media_group_by_helper_id(helper_message_id, "approved") - # Удаляем helper сообщение - это критично, делаем это всегда try: await self._get_bot(call.message).delete_message( - chat_id=self.group_for_posts, - message_id=helper_message_id + chat_id=self.group_for_posts, message_id=helper_message_id ) except Exception as e: - logger.warning(f"_publish_media_group: Ошибка при удалении helper сообщения: {e}") - + logger.warning( + f"_publish_media_group: Ошибка при удалении helper сообщения: {e}" + ) + try: await send_text_message(author_id, call.message, MESSAGE_POST_PUBLISHED) except Exception as e: if str(e) == ERROR_BOT_BLOCKED: - logger.warning(f"_publish_media_group: Пользователь {author_id} заблокировал бота") + logger.warning( + f"_publish_media_group: Пользователь {author_id} заблокировал бота" + ) raise UserBlockedBotError("Пользователь заблокировал бота") - logger.error(f"_publish_media_group: Ошибка при отправке уведомления автору: {e}") - + logger.error( + f"_publish_media_group: Ошибка при отправке уведомления автору: {e}" + ) + except Exception as e: - logger.error(f"_publish_media_group: Ошибка при публикации медиагруппы: {e}") + logger.error( + f"_publish_media_group: Ошибка при публикации медиагруппы: {e}" + ) # Пытаемся удалить helper сообщение даже при ошибке try: await self._get_bot(call.message).delete_message( - chat_id=self.group_for_posts, - message_id=call.message.message_id + chat_id=self.group_for_posts, message_id=call.message.message_id ) except Exception as delete_error: - logger.warning(f"_publish_media_group: Не удалось удалить helper сообщение при ошибке: {delete_error}") + logger.warning( + f"_publish_media_group: Не удалось удалить helper сообщение при ошибке: {delete_error}" + ) raise PublishError(f"Не удалось опубликовать медиагруппу: {str(e)}") @track_time("decline_post", "post_publish_service") @@ -378,32 +542,50 @@ class PostPublishService: if call.message.text == CONTENT_TYPE_MEDIA_GROUP: await self._decline_media_group(call) return - + content_type = call.message.content_type - - if content_type in [CONTENT_TYPE_TEXT, CONTENT_TYPE_PHOTO, CONTENT_TYPE_AUDIO, - CONTENT_TYPE_VOICE, CONTENT_TYPE_VIDEO, CONTENT_TYPE_VIDEO_NOTE]: + + if content_type in [ + CONTENT_TYPE_TEXT, + CONTENT_TYPE_PHOTO, + CONTENT_TYPE_AUDIO, + CONTENT_TYPE_VOICE, + CONTENT_TYPE_VIDEO, + CONTENT_TYPE_VIDEO_NOTE, + ]: await self._decline_single_post(call) else: - logger.error(f"Неподдерживаемый тип контента для отклонения: {content_type}") - raise PublishError(f"Неподдерживаемый тип контента для отклонения: {content_type}") + logger.error( + f"Неподдерживаемый тип контента для отклонения: {content_type}" + ) + raise PublishError( + f"Неподдерживаемый тип контента для отклонения: {content_type}" + ) @track_time("_decline_single_post", "post_publish_service") @track_errors("post_publish_service", "_decline_single_post") async def _decline_single_post(self, call: CallbackQuery) -> None: """Отклонение одиночного поста""" author_id = await self._get_author_id(call.message.message_id) - + # Обучаем RAG на отклоненном посте перед удалением await self._train_on_declined(call.message.message_id) - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "declined") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "declined" + ) if updated_rows == 0: - logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'declined'") - raise PostNotFoundError(f"Пост с message_id={call.message.message_id} не найден в базе данных") - - await self._get_bot(call.message).delete_message(chat_id=self.group_for_posts, message_id=call.message.message_id) - + logger.error( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'declined'" + ) + raise PostNotFoundError( + f"Пост с message_id={call.message.message_id} не найден в базе данных" + ) + + await self._get_bot(call.message).delete_message( + chat_id=self.group_for_posts, message_id=call.message.message_id + ) + try: await send_text_message(author_id, call.message, MESSAGE_POST_DECLINED) except Exception as e: @@ -412,7 +594,9 @@ class PostPublishService: raise UserBlockedBotError("Пользователь заблокировал бота") logger.error(f"Ошибка при отправке уведомления автору {author_id}: {e}") raise - logger.info(f'Сообщение отклонено админом {call.from_user.full_name} (ID: {call.from_user.id}).') + logger.info( + f"Сообщение отклонено админом {call.from_user.full_name} (ID: {call.from_user.id})." + ) @track_time("_decline_media_group", "post_publish_service") @track_errors("post_publish_service", "_decline_media_group") @@ -420,10 +604,14 @@ class PostPublishService: async def _decline_media_group(self, call: CallbackQuery) -> None: """Отклонение медиагруппы""" helper_message_id = call.message.message_id - - await self.db.update_status_for_media_group_by_helper_id(helper_message_id, "declined") - media_group_message_ids = await self.db.get_post_ids_by_helper_id(helper_message_id) + await self.db.update_status_for_media_group_by_helper_id( + helper_message_id, "declined" + ) + + media_group_message_ids = await self.db.get_post_ids_by_helper_id( + helper_message_id + ) message_ids_to_delete = media_group_message_ids.copy() message_ids_to_delete.append(helper_message_id) @@ -432,19 +620,22 @@ class PostPublishService: try: await self._get_bot(call.message).delete_messages( - chat_id=self.group_for_posts, - message_ids=message_ids_to_delete + chat_id=self.group_for_posts, message_ids=message_ids_to_delete ) except Exception as e: logger.warning(f"_decline_media_group: Ошибка при удалении сообщений: {e}") - + try: await send_text_message(author_id, call.message, MESSAGE_POST_DECLINED) except Exception as e: if str(e) == ERROR_BOT_BLOCKED: - logger.warning(f"_decline_media_group: Пользователь {author_id} заблокировал бота") + logger.warning( + f"_decline_media_group: Пользователь {author_id} заблокировал бота" + ) raise UserBlockedBotError("Пользователь заблокировал бота") - logger.error(f"_decline_media_group: Ошибка при отправке уведомления автору {author_id}: {e}") + logger.error( + f"_decline_media_group: Ошибка при отправке уведомления автору {author_id}: {e}" + ) raise @track_time("_get_author_id", "post_publish_service") @@ -464,7 +655,7 @@ class PostPublishService: author_id = await self.db.get_author_id_by_helper_message_id(message_id) if author_id: return author_id - + # Если не найден, ищем по основному message_id медиагруппы # Для этого нужно найти связанные сообщения медиагруппы try: @@ -478,7 +669,7 @@ class PostPublishService: return author_id except Exception as e: logger.warning(f"Не удалось найти автора через связанные сообщения: {e}") - + # Если все способы не сработали, ищем напрямую author_id = await self.db.get_author_id_by_message_id(message_id) if not author_id: @@ -487,38 +678,44 @@ class PostPublishService: @track_time("_delete_post_and_notify_author", "post_publish_service") @track_errors("post_publish_service", "_delete_post_and_notify_author") - async def _delete_post_and_notify_author(self, call: CallbackQuery, author_id: int) -> None: + async def _delete_post_and_notify_author( + self, call: CallbackQuery, author_id: int + ) -> None: """Удаление поста и уведомление автора""" # Получаем текст поста для обучения RAG перед удалением await self._train_on_published(call.message.message_id) - - await self._get_bot(call.message).delete_message(chat_id=self.group_for_posts, message_id=call.message.message_id) - + + await self._get_bot(call.message).delete_message( + chat_id=self.group_for_posts, message_id=call.message.message_id + ) + try: await send_text_message(author_id, call.message, MESSAGE_POST_PUBLISHED) except Exception as e: if str(e) == ERROR_BOT_BLOCKED: raise UserBlockedBotError("Пользователь заблокировал бота") raise - + async def _train_on_published(self, message_id: int) -> None: """Обучает RAG на опубликованном посте.""" if not self.scoring_manager: return - + try: text = await self.db.get_post_text_by_message_id(message_id) if text and text.strip() and text != "^": await self.scoring_manager.on_post_published(text) logger.debug(f"RAG обучен на опубликованном посте: {message_id}") except Exception as e: - logger.error(f"Ошибка обучения RAG на опубликованном посте {message_id}: {e}") - + logger.error( + f"Ошибка обучения RAG на опубликованном посте {message_id}: {e}" + ) + async def _train_on_declined(self, message_id: int) -> None: """Обучает RAG на отклоненном посте.""" if not self.scoring_manager: return - + try: text = await self.db.get_post_text_by_message_id(message_id) if text and text.strip() and text != "^": @@ -530,16 +727,22 @@ class PostPublishService: @track_time("_delete_media_group_and_notify_author", "post_publish_service") @track_errors("post_publish_service", "_delete_media_group_and_notify_author") @track_media_processing("media_group") - async def _delete_media_group_and_notify_author(self, call: CallbackQuery, author_id: int) -> None: + async def _delete_media_group_and_notify_author( + self, call: CallbackQuery, author_id: int + ) -> None: """Удаление медиагруппы и уведомление автора (legacy метод, используется для обратной совместимости)""" helper_message_id = call.message.message_id - - media_group_message_ids = await self.db.get_post_ids_by_helper_id(helper_message_id) + + media_group_message_ids = await self.db.get_post_ids_by_helper_id( + helper_message_id + ) message_ids_to_delete = media_group_message_ids.copy() message_ids_to_delete.append(helper_message_id) - - await self._get_bot(call.message).delete_messages(chat_id=self.group_for_posts, message_ids=message_ids_to_delete) + + await self._get_bot(call.message).delete_messages( + chat_id=self.group_for_posts, message_ids=message_ids_to_delete + ) try: await send_text_message(author_id, call.message, MESSAGE_POST_PUBLISHED) except Exception as e: @@ -549,30 +752,47 @@ class PostPublishService: @track_time("_save_published_post_content", "post_publish_service") @track_errors("post_publish_service", "_save_published_post_content") - async def _save_published_post_content(self, published_message: types.Message, published_message_id: int, original_message_id: int) -> None: + async def _save_published_post_content( + self, + published_message: types.Message, + published_message_id: int, + original_message_id: int, + ) -> None: """Сохраняет ссылку на медиафайл из опубликованного поста (файл уже в S3 или на диске).""" try: # Получаем уже сохраненный путь/S3 ключ из оригинального поста - saved_content = await self.db.get_post_content_by_message_id(original_message_id) - + saved_content = await self.db.get_post_content_by_message_id( + original_message_id + ) + if saved_content and len(saved_content) > 0: # Копируем тот же путь/S3 ключ file_path, content_type = saved_content[0] - logger.debug(f"Копируем путь/S3 ключ для опубликованного поста: {file_path}") - + logger.debug( + f"Копируем путь/S3 ключ для опубликованного поста: {file_path}" + ) + success = await self.db.add_published_post_content( published_message_id=published_message_id, content_path=file_path, # Тот же путь/S3 ключ - content_type=content_type + content_type=content_type, ) if success: - logger.info(f"Ссылка на файл сохранена для опубликованного поста: published_message_id={published_message_id}, path={file_path}") + logger.info( + f"Ссылка на файл сохранена для опубликованного поста: published_message_id={published_message_id}, path={file_path}" + ) else: - logger.warning(f"Не удалось сохранить ссылку на файл: published_message_id={published_message_id}") + logger.warning( + f"Не удалось сохранить ссылку на файл: published_message_id={published_message_id}" + ) else: - logger.warning(f"Контент не найден для оригинального поста message_id={original_message_id}") + logger.warning( + f"Контент не найден для оригинального поста message_id={original_message_id}" + ) except Exception as e: - logger.error(f"Ошибка при сохранении ссылки на контент опубликованного поста {published_message_id}: {e}") + logger.error( + f"Ошибка при сохранении ссылки на контент опубликованного поста {published_message_id}: {e}" + ) # Не прерываем публикацию, если сохранение контента не удалось @@ -581,8 +801,8 @@ class BanService: self.bot = bot self.db = db self.settings = settings - self.group_for_posts = settings['Telegram']['group_for_posts'] - self.important_logs = settings['Telegram']['important_logs'] + self.group_for_posts = settings["Telegram"]["group_for_posts"] + self.important_logs = settings["Telegram"]["important_logs"] def _get_bot(self, message) -> Bot: """Получает бота из контекста сообщения или использует переданного""" @@ -597,16 +817,22 @@ class BanService: """Бан пользователя за спам""" # Если это helper-сообщение медиагруппы, используем специальный метод if call.message.text == CONTENT_TYPE_MEDIA_GROUP: - author_id = await self.db.get_author_id_by_helper_message_id(call.message.message_id) + author_id = await self.db.get_author_id_by_helper_message_id( + call.message.message_id + ) else: - author_id = await self.db.get_author_id_by_message_id(call.message.message_id) - + author_id = await self.db.get_author_id_by_message_id( + call.message.message_id + ) + if not author_id: - raise UserNotFoundError(f"Автор не найден для сообщения {call.message.message_id}") - + raise UserNotFoundError( + f"Автор не найден для сообщения {call.message.message_id}" + ) + current_date = datetime.now() date_to_unban = int((current_date + timedelta(days=7)).timestamp()) - + ban_author_id = call.from_user.id await self.db.set_user_blacklist( @@ -616,7 +842,7 @@ class BanService: date_to_unban=date_to_unban, ban_author=ban_author_id, ) - + # Обновляем статус поста на declined if call.message.text == CONTENT_TYPE_MEDIA_GROUP: # Для медиагруппы обновляем статус по helper_message_id @@ -624,23 +850,33 @@ class BanService: call.message.message_id, "declined" ) if updated_rows == 0: - logger.warning(f"Не удалось обновить статус медиагруппы helper_message_id={call.message.message_id} на 'declined'") + logger.warning( + f"Не удалось обновить статус медиагруппы helper_message_id={call.message.message_id} на 'declined'" + ) else: # Для одиночного поста обновляем статус по message_id - updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "declined") + updated_rows = await self.db.update_status_by_message_id( + call.message.message_id, "declined" + ) if updated_rows == 0: - logger.warning(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'declined'") - - await self._get_bot(call.message).delete_message(chat_id=self.group_for_posts, message_id=call.message.message_id) - + logger.warning( + f"Не удалось обновить статус поста message_id={call.message.message_id} на 'declined'" + ) + + await self._get_bot(call.message).delete_message( + chat_id=self.group_for_posts, message_id=call.message.message_id + ) + date_str = (current_date + timedelta(days=7)).strftime("%d.%m.%Y %H:%M") try: - await send_text_message(author_id, call.message, MESSAGE_USER_BANNED_SPAM.format(date=date_str)) + await send_text_message( + author_id, call.message, MESSAGE_USER_BANNED_SPAM.format(date=date_str) + ) except Exception as e: if str(e) == ERROR_BOT_BLOCKED: raise UserBlockedBotError("Пользователь заблокировал бота") raise - + logger.info(f"Пользователь {author_id} заблокирован за спам до {date_str}") @track_time("ban_user", "ban_service") @@ -650,7 +886,7 @@ class BanService: user_name = await self.db.get_username(int(user_id)) if not user_name: raise UserNotFoundError(f"Пользователь с ID {user_id} не найден в базе") - + return user_name @track_time("unlock_user", "ban_service") @@ -661,7 +897,7 @@ class BanService: user_name = await self.db.get_username(int(user_id)) if not user_name: raise UserNotFoundError(f"Пользователь с ID {user_id} не найден в базе") - + await delete_user_blacklist(int(user_id), self.db) logger.info(f"Разблокирован пользователь с ID: {user_id} username:{user_name}") return user_name diff --git a/helper_bot/handlers/group/__init__.py b/helper_bot/handlers/group/__init__.py index c78eba7..a060f32 100644 --- a/helper_bot/handlers/group/__init__.py +++ b/helper_bot/handlers/group/__init__.py @@ -6,27 +6,24 @@ from .constants import ERROR_MESSAGES, FSM_STATES from .decorators import error_handler from .exceptions import NoReplyToMessageError, UserNotFoundError from .group_handlers import GroupHandlers, create_group_handlers, group_router + # Local imports - services from .services import AdminReplyService, DatabaseProtocol __all__ = [ # Main components - 'group_router', - 'create_group_handlers', - 'GroupHandlers', - + "group_router", + "create_group_handlers", + "GroupHandlers", # Services - 'AdminReplyService', - 'DatabaseProtocol', - + "AdminReplyService", + "DatabaseProtocol", # Constants - 'FSM_STATES', - 'ERROR_MESSAGES', - + "FSM_STATES", + "ERROR_MESSAGES", # Exceptions - 'NoReplyToMessageError', - 'UserNotFoundError', - + "NoReplyToMessageError", + "UserNotFoundError", # Utilities - 'error_handler' + "error_handler", ] diff --git a/helper_bot/handlers/group/constants.py b/helper_bot/handlers/group/constants.py index 96446f2..e893148 100644 --- a/helper_bot/handlers/group/constants.py +++ b/helper_bot/handlers/group/constants.py @@ -3,12 +3,10 @@ from typing import Dict, Final # FSM States -FSM_STATES: Final[Dict[str, str]] = { - "CHAT": "CHAT" -} +FSM_STATES: Final[Dict[str, str]] = {"CHAT": "CHAT"} # Error messages ERROR_MESSAGES: Final[Dict[str, str]] = { "NO_REPLY_TO_MESSAGE": "Блять, выдели сообщение!", - "USER_NOT_FOUND": "Не могу найти кому ответить в базе, проебали сообщение." + "USER_NOT_FOUND": "Не могу найти кому ответить в базе, проебали сообщение.", } diff --git a/helper_bot/handlers/group/decorators.py b/helper_bot/handlers/group/decorators.py index 8cb0d3a..b1511a0 100644 --- a/helper_bot/handlers/group/decorators.py +++ b/helper_bot/handlers/group/decorators.py @@ -6,12 +6,14 @@ from typing import Any, Callable # Third-party imports from aiogram import types + # Local imports from logs.custom_logger import logger def error_handler(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator for centralized error handling""" + async def wrapper(*args: Any, **kwargs: Any) -> Any: try: return await func(*args, **kwargs) @@ -19,18 +21,23 @@ def error_handler(func: Callable[..., Any]) -> Callable[..., Any]: logger.error(f"Error in {func.__name__}: {str(e)}") # Try to send error to logs if possible try: - message = next((arg for arg in args if isinstance(arg, types.Message)), None) - if message and hasattr(message, 'bot'): - from helper_bot.utils.base_dependency_factory import \ - get_global_instance + message = next( + (arg for arg in args if isinstance(arg, types.Message)), None + ) + if message and hasattr(message, "bot"): + from helper_bot.utils.base_dependency_factory import ( + get_global_instance, + ) + bdf = get_global_instance() - important_logs = bdf.settings['Telegram']['important_logs'] + important_logs = bdf.settings["Telegram"]["important_logs"] await message.bot.send_message( chat_id=important_logs, - text=f"Произошла ошибка в {func.__name__}: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" + text=f"Произошла ошибка в {func.__name__}: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", ) except Exception: # If we can't log the error, at least it was logged to logger pass raise + return wrapper diff --git a/helper_bot/handlers/group/exceptions.py b/helper_bot/handlers/group/exceptions.py index e10a41c..bccef54 100644 --- a/helper_bot/handlers/group/exceptions.py +++ b/helper_bot/handlers/group/exceptions.py @@ -3,9 +3,11 @@ class NoReplyToMessageError(Exception): """Raised when admin tries to reply without selecting a message""" + pass class UserNotFoundError(Exception): """Raised when user is not found in database for the given message_id""" + pass diff --git a/helper_bot/handlers/group/group_handlers.py b/helper_bot/handlers/group/group_handlers.py index 8d14db8..c21fddc 100644 --- a/helper_bot/handlers/group/group_handlers.py +++ b/helper_bot/handlers/group/group_handlers.py @@ -3,11 +3,14 @@ # Third-party imports from aiogram import Router, types from aiogram.fsm.context import FSMContext + # Local imports - filters from database.async_db import AsyncBotDB from helper_bot.filters.main import ChatTypeFilter + # Local imports - metrics from helper_bot.utils.metrics import metrics, track_errors, track_time + # Local imports - utilities from logs.custom_logger import logger @@ -20,25 +23,24 @@ from .services import AdminReplyService class GroupHandlers: """Main handler class for group messages""" - + def __init__(self, db: AsyncBotDB, keyboard_markup: types.ReplyKeyboardMarkup): self.db = db self.keyboard_markup = keyboard_markup self.admin_reply_service = AdminReplyService(db) - + # Create router self.router = Router() - + # Register handlers self._register_handlers() - + def _register_handlers(self): """Register all message handlers""" self.router.message.register( - self.handle_message, - ChatTypeFilter(chat_type=["group", "supergroup"]) + self.handle_message, ChatTypeFilter(chat_type=["group", "supergroup"]) ) - + @error_handler @track_errors("group_handlers", "handle_message") @track_time("handle_message", "group_handlers") @@ -46,44 +48,46 @@ class GroupHandlers: """Handle admin reply to user through group chat""" logger.info( - f'Получено сообщение в группе {message.chat.title} (ID: {message.chat.id}) ' + f"Получено сообщение в группе {message.chat.title} (ID: {message.chat.id}) " f'от пользователя {message.from_user.full_name} (ID: {message.from_user.id}): "{message.text}"' ) - + # Check if message is a reply if not message.reply_to_message: await message.answer(ERROR_MESSAGES["NO_REPLY_TO_MESSAGE"]) logger.warning( - f'В группе {message.chat.title} (ID: {message.chat.id}) ' - f'админ не выделил сообщение для ответа.' + f"В группе {message.chat.title} (ID: {message.chat.id}) " + f"админ не выделил сообщение для ответа." ) return - + message_id = message.reply_to_message.message_id reply_text = message.text - + try: # Get user ID for reply chat_id = await self.admin_reply_service.get_user_id_for_reply(message_id) - + # Send reply to user await self.admin_reply_service.send_reply_to_user( chat_id, message, reply_text, self.keyboard_markup ) - + # Set state await state.set_state(FSM_STATES["CHAT"]) - + except UserNotFoundError: await message.answer(ERROR_MESSAGES["USER_NOT_FOUND"]) logger.error( - f'Ошибка при поиске пользователя в базе для ответа на сообщение: {reply_text} ' - f'в группе {message.chat.title} (ID сообщения: {message.message_id})' + f"Ошибка при поиске пользователя в базе для ответа на сообщение: {reply_text} " + f"в группе {message.chat.title} (ID сообщения: {message.message_id})" ) # Factory function to create handlers with dependencies -def create_group_handlers(db: AsyncBotDB, keyboard_markup: types.ReplyKeyboardMarkup) -> GroupHandlers: +def create_group_handlers( + db: AsyncBotDB, keyboard_markup: types.ReplyKeyboardMarkup +) -> GroupHandlers: """Create group handlers instance with dependencies""" return GroupHandlers(db, keyboard_markup) @@ -91,21 +95,23 @@ def create_group_handlers(db: AsyncBotDB, keyboard_markup: types.ReplyKeyboardMa # Legacy router for backward compatibility group_router = Router() + # Initialize with global dependencies (for backward compatibility) def init_legacy_router(): """Initialize legacy router with global dependencies""" global group_router - + from helper_bot.keyboards.keyboards import get_reply_keyboard_leave_chat from helper_bot.utils.base_dependency_factory import get_global_instance - + bdf = get_global_instance() - #TODO: поменять архитектуру и подключить правильный BotDB + # TODO: поменять архитектуру и подключить правильный BotDB db = bdf.get_db() keyboard_markup = get_reply_keyboard_leave_chat() - + handlers = create_group_handlers(db, keyboard_markup) group_router = handlers.router + # Initialize legacy router init_legacy_router() diff --git a/helper_bot/handlers/group/services.py b/helper_bot/handlers/group/services.py index 320f466..934887b 100644 --- a/helper_bot/handlers/group/services.py +++ b/helper_bot/handlers/group/services.py @@ -5,8 +5,10 @@ from typing import Optional, Protocol # Third-party imports from aiogram import types + # Local imports from helper_bot.utils.helper_func import send_text_message + # Local imports - metrics from helper_bot.utils.metrics import db_query_time, track_errors, track_time from logs.custom_logger import logger @@ -16,29 +18,32 @@ from .exceptions import NoReplyToMessageError, UserNotFoundError class DatabaseProtocol(Protocol): """Protocol for database operations""" + async def get_user_by_message_id(self, message_id: int) -> Optional[int]: ... - async def add_message(self, message_text: str, user_id: int, message_id: int, date: int = None): ... + async def add_message( + self, message_text: str, user_id: int, message_id: int, date: int = None + ): ... class AdminReplyService: """Service for admin reply operations""" - + def __init__(self, db: DatabaseProtocol) -> None: self.db = db - + @track_time("get_user_id_for_reply", "admin_reply_service") @track_errors("admin_reply_service", "get_user_id_for_reply") @db_query_time("get_user_id_for_reply", "users", "select") async def get_user_id_for_reply(self, message_id: int) -> int: """ Get user ID for reply by message ID. - + Args: message_id: ID of the message to reply to - + Returns: User ID for the reply - + Raises: UserNotFoundError: If user is not found in database """ @@ -46,19 +51,19 @@ class AdminReplyService: if user_id is None: raise UserNotFoundError(f"User not found for message_id: {message_id}") return user_id - + @track_time("send_reply_to_user", "admin_reply_service") @track_errors("admin_reply_service", "send_reply_to_user") async def send_reply_to_user( - self, - chat_id: int, - message: types.Message, - reply_text: str, - markup: types.ReplyKeyboardMarkup + self, + chat_id: int, + message: types.Message, + reply_text: str, + markup: types.ReplyKeyboardMarkup, ) -> None: """ Send reply to user. - + Args: chat_id: User's chat ID message: Original message from admin diff --git a/helper_bot/handlers/private/__init__.py b/helper_bot/handlers/private/__init__.py index a8e44f7..0a8be50 100644 --- a/helper_bot/handlers/private/__init__.py +++ b/helper_bot/handlers/private/__init__.py @@ -4,28 +4,25 @@ # Local imports - constants and utilities from .constants import BUTTON_TEXTS, ERROR_MESSAGES, FSM_STATES from .decorators import error_handler -from .private_handlers import (PrivateHandlers, create_private_handlers, - private_router) +from .private_handlers import PrivateHandlers, create_private_handlers, private_router + # Local imports - services from .services import BotSettings, PostService, StickerService, UserService __all__ = [ # Main components - 'private_router', - 'create_private_handlers', - 'PrivateHandlers', - + "private_router", + "create_private_handlers", + "PrivateHandlers", # Services - 'BotSettings', - 'UserService', - 'PostService', - 'StickerService', - + "BotSettings", + "UserService", + "PostService", + "StickerService", # Constants - 'FSM_STATES', - 'BUTTON_TEXTS', - 'ERROR_MESSAGES', - + "FSM_STATES", + "BUTTON_TEXTS", + "ERROR_MESSAGES", # Utilities - 'error_handler' + "error_handler", ] diff --git a/helper_bot/handlers/private/constants.py b/helper_bot/handlers/private/constants.py index 09ee96f..8f3ce85 100644 --- a/helper_bot/handlers/private/constants.py +++ b/helper_bot/handlers/private/constants.py @@ -7,7 +7,7 @@ FSM_STATES: Final[Dict[str, str]] = { "START": "START", "SUGGEST": "SUGGEST", "PRE_CHAT": "PRE_CHAT", - "CHAT": "CHAT" + "CHAT": "CHAT", } # Button texts @@ -18,7 +18,7 @@ BUTTON_TEXTS: Final[Dict[str, str]] = { "RETURN_TO_BOT": "Вернуться в бота", "WANT_STICKERS": "🤪Хочу стикеры", "CONNECT_ADMIN": "📩Связаться с админами", - "VOICE_BOT": "🎤Голосовой бот" + "VOICE_BOT": "🎤Голосовой бот", } # Button to command mapping for metrics @@ -29,15 +29,15 @@ BUTTON_COMMAND_MAPPING: Final[Dict[str, str]] = { "Вернуться в бота": "return_to_bot", "🤪Хочу стикеры": "want_stickers", "📩Связаться с админами": "connect_admin", - "🎤Голосовой бот": "voice_bot" + "🎤Голосовой бот": "voice_bot", } # Error messages ERROR_MESSAGES: Final[Dict[str, str]] = { "UNSUPPORTED_CONTENT": ( - 'Я пока не умею работать с таким сообщением. ' - 'Пришли текст и фото/фоты(ы). А лучше перешли это сообщение админу @kerrad1\n' - 'Мы добавим его к обработке если необходимо' + "Я пока не умею работать с таким сообщением. " + "Пришли текст и фото/фоты(ы). А лучше перешли это сообщение админу @kerrad1\n" + "Мы добавим его к обработке если необходимо" ), - "STICKERS_LINK": "Хорошо, лови, добавить можно отсюда: https://t.me/addstickers/love_biysk" + "STICKERS_LINK": "Хорошо, лови, добавить можно отсюда: https://t.me/addstickers/love_biysk", } diff --git a/helper_bot/handlers/private/decorators.py b/helper_bot/handlers/private/decorators.py index 2905664..1adabc0 100644 --- a/helper_bot/handlers/private/decorators.py +++ b/helper_bot/handlers/private/decorators.py @@ -6,12 +6,14 @@ from typing import Any, Callable # Third-party imports from aiogram import types + # Local imports from logs.custom_logger import logger def error_handler(func: Callable[..., Any]) -> Callable[..., Any]: """Decorator for centralized error handling""" + async def wrapper(*args: Any, **kwargs: Any) -> Any: try: return await func(*args, **kwargs) @@ -19,18 +21,23 @@ def error_handler(func: Callable[..., Any]) -> Callable[..., Any]: logger.error(f"Error in {func.__name__}: {str(e)}") # Try to send error to logs if possible try: - message = next((arg for arg in args if isinstance(arg, types.Message)), None) - if message and hasattr(message, 'bot'): - from helper_bot.utils.base_dependency_factory import \ - get_global_instance + message = next( + (arg for arg in args if isinstance(arg, types.Message)), None + ) + if message and hasattr(message, "bot"): + from helper_bot.utils.base_dependency_factory import ( + get_global_instance, + ) + bdf = get_global_instance() - important_logs = bdf.settings['Telegram']['important_logs'] + important_logs = bdf.settings["Telegram"]["important_logs"] await message.bot.send_message( chat_id=important_logs, - text=f"Произошла ошибка в {func.__name__}: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" + text=f"Произошла ошибка в {func.__name__}: {str(e)}\n\nTraceback:\n{traceback.format_exc()}", ) except Exception: # If we can't log the error, at least it was logged to logger pass raise + return wrapper diff --git a/helper_bot/handlers/private/private_handlers.py b/helper_bot/handlers/private/private_handlers.py index f34af93..c8d0f8b 100644 --- a/helper_bot/handlers/private/private_handlers.py +++ b/helper_bot/handlers/private/private_handlers.py @@ -8,18 +8,23 @@ from datetime import datetime from aiogram import F, Router, types from aiogram.filters import Command, StateFilter from aiogram.fsm.context import FSMContext + # Local imports - filters and middlewares from database.async_db import AsyncBotDB from helper_bot.filters.main import ChatTypeFilter + # Local imports - utilities -from helper_bot.keyboards import (get_reply_keyboard, - get_reply_keyboard_for_post) +from helper_bot.keyboards import get_reply_keyboard, get_reply_keyboard_for_post from helper_bot.keyboards.keyboards import get_reply_keyboard_leave_chat from helper_bot.middlewares.album_middleware import AlbumMiddleware from helper_bot.middlewares.blacklist_middleware import BlacklistMiddleware from helper_bot.utils import messages -from helper_bot.utils.helper_func import (check_user_emoji, get_first_name, - update_user_info) +from helper_bot.utils.helper_func import ( + check_user_emoji, + get_first_name, + update_user_info, +) + # Local imports - metrics from helper_bot.utils.metrics import db_query_time, track_errors, track_time @@ -34,83 +39,144 @@ sleep = asyncio.sleep class PrivateHandlers: """Main handler class for private messages""" - - def __init__(self, db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None): + + def __init__( + self, + db: AsyncBotDB, + settings: BotSettings, + s3_storage=None, + scoring_manager=None, + ): self.db = db self.settings = settings self.user_service = UserService(db, settings) self.post_service = PostService(db, settings, s3_storage, scoring_manager) self.sticker_service = StickerService(settings) - + self.router = Router() self.router.message.middleware(AlbumMiddleware(latency=5.0)) self.router.message.middleware(BlacklistMiddleware()) - + # Register handlers self._register_handlers() - + def _register_handlers(self): """Register all message handlers""" # Command handlers - self.router.message.register(self.handle_emoji_message, ChatTypeFilter(chat_type=["private"]), Command("emoji")) - self.router.message.register(self.handle_restart_message, ChatTypeFilter(chat_type=["private"]), Command("restart")) - self.router.message.register(self.handle_start_message, ChatTypeFilter(chat_type=["private"]), Command("start")) - self.router.message.register(self.handle_start_message, ChatTypeFilter(chat_type=["private"]), F.text == BUTTON_TEXTS["RETURN_TO_BOT"]) - - # Button handlers - self.router.message.register(self.suggest_post, StateFilter(FSM_STATES["START"]), ChatTypeFilter(chat_type=["private"]), F.text == BUTTON_TEXTS["SUGGEST_POST"]) - self.router.message.register(self.end_message, ChatTypeFilter(chat_type=["private"]), F.text == BUTTON_TEXTS["SAY_GOODBYE"]) - self.router.message.register(self.end_message, ChatTypeFilter(chat_type=["private"]), F.text == BUTTON_TEXTS["LEAVE_CHAT"]) - self.router.message.register(self.stickers, ChatTypeFilter(chat_type=["private"]), F.text == BUTTON_TEXTS["WANT_STICKERS"]) - self.router.message.register(self.connect_with_admin, StateFilter(FSM_STATES["START"]), ChatTypeFilter(chat_type=["private"]), F.text == BUTTON_TEXTS["CONNECT_ADMIN"]) + self.router.message.register( + self.handle_emoji_message, + ChatTypeFilter(chat_type=["private"]), + Command("emoji"), + ) + self.router.message.register( + self.handle_restart_message, + ChatTypeFilter(chat_type=["private"]), + Command("restart"), + ) + self.router.message.register( + self.handle_start_message, + ChatTypeFilter(chat_type=["private"]), + Command("start"), + ) + self.router.message.register( + self.handle_start_message, + ChatTypeFilter(chat_type=["private"]), + F.text == BUTTON_TEXTS["RETURN_TO_BOT"], + ) + + # Button handlers + self.router.message.register( + self.suggest_post, + StateFilter(FSM_STATES["START"]), + ChatTypeFilter(chat_type=["private"]), + F.text == BUTTON_TEXTS["SUGGEST_POST"], + ) + self.router.message.register( + self.end_message, + ChatTypeFilter(chat_type=["private"]), + F.text == BUTTON_TEXTS["SAY_GOODBYE"], + ) + self.router.message.register( + self.end_message, + ChatTypeFilter(chat_type=["private"]), + F.text == BUTTON_TEXTS["LEAVE_CHAT"], + ) + self.router.message.register( + self.stickers, + ChatTypeFilter(chat_type=["private"]), + F.text == BUTTON_TEXTS["WANT_STICKERS"], + ) + self.router.message.register( + self.connect_with_admin, + StateFilter(FSM_STATES["START"]), + ChatTypeFilter(chat_type=["private"]), + F.text == BUTTON_TEXTS["CONNECT_ADMIN"], + ) - # State handlers - self.router.message.register(self.suggest_router, StateFilter(FSM_STATES["SUGGEST"]), ChatTypeFilter(chat_type=["private"])) - self.router.message.register(self.resend_message_in_group_for_message, StateFilter(FSM_STATES["PRE_CHAT"]), ChatTypeFilter(chat_type=["private"])) - self.router.message.register(self.resend_message_in_group_for_message, StateFilter(FSM_STATES["CHAT"]), ChatTypeFilter(chat_type=["private"])) + self.router.message.register( + self.suggest_router, + StateFilter(FSM_STATES["SUGGEST"]), + ChatTypeFilter(chat_type=["private"]), + ) + self.router.message.register( + self.resend_message_in_group_for_message, + StateFilter(FSM_STATES["PRE_CHAT"]), + ChatTypeFilter(chat_type=["private"]), + ) + self.router.message.register( + self.resend_message_in_group_for_message, + StateFilter(FSM_STATES["CHAT"]), + ChatTypeFilter(chat_type=["private"]), + ) @error_handler @track_errors("private_handlers", "handle_emoji_message") @track_time("handle_emoji_message", "private_handlers") - async def handle_emoji_message(self, message: types.Message, state: FSMContext, **kwargs): + async def handle_emoji_message( + self, message: types.Message, state: FSMContext, **kwargs + ): """Handle emoji command""" await self.user_service.log_user_message(message) user_emoji = await check_user_emoji(message) await state.set_state(FSM_STATES["START"]) if user_emoji is not None: - await message.answer(f'Твоя эмодзя - {user_emoji}', parse_mode='HTML') - + await message.answer(f"Твоя эмодзя - {user_emoji}", parse_mode="HTML") + @error_handler @track_errors("private_handlers", "handle_restart_message") @track_time("handle_restart_message", "private_handlers") - async def handle_restart_message(self, message: types.Message, state: FSMContext, **kwargs): + async def handle_restart_message( + self, message: types.Message, state: FSMContext, **kwargs + ): """Handle restart command""" markup = await get_reply_keyboard(self.db, message.from_user.id) await self.user_service.log_user_message(message) await state.set_state(FSM_STATES["START"]) - await update_user_info('love', message) + await update_user_info("love", message) await check_user_emoji(message) - await message.answer('Я перезапущен!', reply_markup=markup, parse_mode='HTML') - + await message.answer("Я перезапущен!", reply_markup=markup, parse_mode="HTML") + @error_handler @track_errors("private_handlers", "handle_start_message") @track_time("handle_start_message", "private_handlers") - async def handle_start_message(self, message: types.Message, state: FSMContext, **kwargs): + async def handle_start_message( + self, message: types.Message, state: FSMContext, **kwargs + ): """Handle start command and return to bot button with metrics tracking""" # User service operations with metrics await self.user_service.log_user_message(message) await self.user_service.ensure_user_exists(message) await state.set_state(FSM_STATES["START"]) - + # Send sticker with metrics await self.sticker_service.send_random_hello_sticker(message) - + # Send welcome message with metrics markup = await get_reply_keyboard(self.db, message.from_user.id) - hello_message = messages.get_message(get_first_name(message), 'HELLO_MESSAGE') - await message.answer(hello_message, reply_markup=markup, parse_mode='HTML') - + hello_message = messages.get_message(get_first_name(message), "HELLO_MESSAGE") + await message.answer(hello_message, reply_markup=markup, parse_mode="HTML") + @error_handler @track_errors("private_handlers", "suggest_post") @track_time("suggest_post", "private_handlers") @@ -120,11 +186,11 @@ class PrivateHandlers: await self.user_service.update_user_activity(message.from_user.id) await self.user_service.log_user_message(message) await state.set_state(FSM_STATES["SUGGEST"]) - + markup = types.ReplyKeyboardRemove() - suggest_news = messages.get_message(get_first_name(message), 'SUGGEST_NEWS') + suggest_news = messages.get_message(get_first_name(message), "SUGGEST_NEWS") await message.answer(suggest_news, reply_markup=markup) - + @error_handler @track_errors("private_handlers", "end_message") @track_time("end_message", "private_handlers") @@ -133,40 +199,44 @@ class PrivateHandlers: # User service operations with metrics await self.user_service.update_user_activity(message.from_user.id) await self.user_service.log_user_message(message) - + # Send sticker await self.sticker_service.send_random_goodbye_sticker(message) - + # Send goodbye message markup = types.ReplyKeyboardRemove() - bye_message = messages.get_message(get_first_name(message), 'BYE_MESSAGE') + bye_message = messages.get_message(get_first_name(message), "BYE_MESSAGE") await message.answer(bye_message, reply_markup=markup) await state.set_state(FSM_STATES["START"]) - + @error_handler @track_errors("private_handlers", "suggest_router") @track_time("suggest_router", "private_handlers") - async def suggest_router(self, message: types.Message, state: FSMContext, album: list = None, **kwargs): + async def suggest_router( + self, message: types.Message, state: FSMContext, album: list = None, **kwargs + ): """Handle post submission in suggest state - сразу отвечает пользователю, обработка в фоне""" # Сразу отвечаем пользователю markup_for_user = await get_reply_keyboard(self.db, message.from_user.id) - success_send_message = messages.get_message(get_first_name(message), 'SUCCESS_SEND_MESSAGE') + success_send_message = messages.get_message( + get_first_name(message), "SUCCESS_SEND_MESSAGE" + ) await message.answer(success_send_message, reply_markup=markup_for_user) await state.set_state(FSM_STATES["START"]) - + # Проверяем, есть ли механизм для получения полной медиагруппы (для медиагрупп) album_getter = kwargs.get("album_getter") - + # В фоне обрабатываем пост async def process_post_background(): try: # Обновляем активность пользователя await self.user_service.update_user_activity(message.from_user.id) - + # Логируем сообщение (только для одиночных сообщений, не медиагрупп) if message.media_group_id is None: await self.user_service.log_user_message(message) - + # Для медиагрупп ждем полную медиагруппу if album_getter and message.media_group_id: full_album = await album_getter.get_album(timeout=10.0) @@ -177,10 +247,11 @@ class PrivateHandlers: await self.post_service.process_post(message, album) except Exception as e: from logs.custom_logger import logger + logger.error(f"Ошибка при фоновой обработке поста: {e}") - + asyncio.create_task(process_post_background()) - + @error_handler @track_errors("private_handlers", "stickers") @track_time("stickers", "private_handlers") @@ -191,41 +262,46 @@ class PrivateHandlers: markup = await get_reply_keyboard(self.db, message.from_user.id) await self.db.update_stickers_info(message.from_user.id) await self.user_service.log_user_message(message) - await message.answer( - text=ERROR_MESSAGES["STICKERS_LINK"], - reply_markup=markup - ) + await message.answer(text=ERROR_MESSAGES["STICKERS_LINK"], reply_markup=markup) await state.set_state(FSM_STATES["START"]) - + @error_handler @track_errors("private_handlers", "connect_with_admin") @track_time("connect_with_admin", "private_handlers") - async def connect_with_admin(self, message: types.Message, state: FSMContext, **kwargs): + async def connect_with_admin( + self, message: types.Message, state: FSMContext, **kwargs + ): """Handle connect with admin button""" # User service operations with metrics await self.user_service.update_user_activity(message.from_user.id) - admin_message = messages.get_message(get_first_name(message), 'CONNECT_WITH_ADMIN') + admin_message = messages.get_message( + get_first_name(message), "CONNECT_WITH_ADMIN" + ) await message.answer(admin_message, parse_mode="html") await self.user_service.log_user_message(message) await state.set_state(FSM_STATES["PRE_CHAT"]) - + @error_handler @track_errors("private_handlers", "resend_message_in_group_for_message") @track_time("resend_message_in_group_for_message", "private_handlers") @db_query_time("resend_message_in_group_for_message", "messages", "insert") - async def resend_message_in_group_for_message(self, message: types.Message, state: FSMContext, **kwargs): + async def resend_message_in_group_for_message( + self, message: types.Message, state: FSMContext, **kwargs + ): """Handle messages in admin chat states""" # User service operations with metrics await self.user_service.update_user_activity(message.from_user.id) await message.forward(chat_id=self.settings.group_for_message) - + current_date = datetime.now() date = int(current_date.timestamp()) - await self.db.add_message(message.text, message.from_user.id, message.message_id + 1, date) - - question = messages.get_message(get_first_name(message), 'QUESTION') + await self.db.add_message( + message.text, message.from_user.id, message.message_id + 1, date + ) + + question = messages.get_message(get_first_name(message), "QUESTION") user_state = await state.get_state() - + if user_state == FSM_STATES["PRE_CHAT"]: markup = await get_reply_keyboard(self.db, message.from_user.id) await message.answer(question, reply_markup=markup) @@ -236,7 +312,9 @@ class PrivateHandlers: # Factory function to create handlers with dependencies -def create_private_handlers(db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None) -> PrivateHandlers: +def create_private_handlers( + db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None +) -> PrivateHandlers: """Create private handlers instance with dependencies""" return PrivateHandlers(db, settings, s3_storage, scoring_manager) @@ -247,37 +325,39 @@ private_router = Router() # Флаг инициализации для защиты от повторного вызова _legacy_router_initialized = False + # Initialize with global dependencies (for backward compatibility) def init_legacy_router(): """Initialize legacy router with global dependencies""" global private_router, _legacy_router_initialized - + if _legacy_router_initialized: return - + from helper_bot.utils.base_dependency_factory import get_global_instance - + bdf = get_global_instance() settings = BotSettings( - group_for_posts=bdf.settings['Telegram']['group_for_posts'], - group_for_message=bdf.settings['Telegram']['group_for_message'], - main_public=bdf.settings['Telegram']['main_public'], - group_for_logs=bdf.settings['Telegram']['group_for_logs'], - important_logs=bdf.settings['Telegram']['important_logs'], - preview_link=bdf.settings['Telegram']['preview_link'], - logs=bdf.settings['Settings']['logs'], - test=bdf.settings['Settings']['test'] + group_for_posts=bdf.settings["Telegram"]["group_for_posts"], + group_for_message=bdf.settings["Telegram"]["group_for_message"], + main_public=bdf.settings["Telegram"]["main_public"], + group_for_logs=bdf.settings["Telegram"]["group_for_logs"], + important_logs=bdf.settings["Telegram"]["important_logs"], + preview_link=bdf.settings["Telegram"]["preview_link"], + logs=bdf.settings["Settings"]["logs"], + test=bdf.settings["Settings"]["test"], ) - + db = bdf.get_db() s3_storage = bdf.get_s3_storage() scoring_manager = bdf.get_scoring_manager() handlers = create_private_handlers(db, settings, s3_storage, scoring_manager) - + # Instead of trying to copy handlers, we'll use the new router directly # This maintains backward compatibility while using the new architecture private_router = handlers.router _legacy_router_initialized = True + # Initialize legacy router init_legacy_router() diff --git a/helper_bot/handlers/private/services.py b/helper_bot/handlers/private/services.py index 5d96348..5fbaef4 100644 --- a/helper_bot/handlers/private/services.py +++ b/helper_bot/handlers/private/services.py @@ -12,37 +12,61 @@ from typing import Any, Callable, Dict, Protocol, Union # Third-party imports from aiogram import types from aiogram.types import FSInputFile + from database.models import TelegramPost, User from helper_bot.keyboards import get_reply_keyboard_for_post + # Local imports - utilities from helper_bot.utils.helper_func import ( - add_in_db_media, check_username_and_full_name, determine_anonymity, - get_first_name, get_text_message, prepare_media_group_from_middlewares, - send_audio_message, send_media_group_message_to_private_chat, - send_photo_message, send_text_message, send_video_message, - send_video_note_message, send_voice_message) + add_in_db_media, + check_username_and_full_name, + determine_anonymity, + get_first_name, + get_text_message, + prepare_media_group_from_middlewares, + send_audio_message, + send_media_group_message_to_private_chat, + send_photo_message, + send_text_message, + send_video_message, + send_video_note_message, + send_voice_message, +) + # Local imports - metrics -from helper_bot.utils.metrics import (db_query_time, track_errors, - track_file_operations, - track_media_processing, track_time) +from helper_bot.utils.metrics import ( + db_query_time, + track_errors, + track_file_operations, + track_media_processing, + track_time, +) from logs.custom_logger import logger class DatabaseProtocol(Protocol): """Protocol for database operations""" + async def user_exists(self, user_id: int) -> bool: ... async def add_user(self, user: User) -> None: ... - async def update_user_info(self, user_id: int, username: str = None, full_name: str = None) -> None: ... + async def update_user_info( + self, user_id: int, username: str = None, full_name: str = None + ) -> None: ... async def update_user_date(self, user_id: int) -> None: ... async def add_post(self, post: TelegramPost) -> None: ... async def update_stickers_info(self, user_id: int) -> None: ... - async def add_message(self, message_text: str, user_id: int, message_id: int, date: int = None) -> None: ... - async def update_helper_message(self, message_id: int, helper_message_id: int) -> None: ... + async def add_message( + self, message_text: str, user_id: int, message_id: int, date: int = None + ) -> None: ... + async def update_helper_message( + self, message_id: int, helper_message_id: int + ) -> None: ... @dataclass class BotSettings: """Bot configuration settings""" + group_for_posts: str group_for_message: str main_public: str @@ -55,18 +79,18 @@ class BotSettings: class UserService: """Service for user-related operations""" - + def __init__(self, db: DatabaseProtocol, settings: BotSettings) -> None: self.db = db self.settings = settings - + @track_time("update_user_activity", "user_service") @track_errors("user_service", "update_user_activity") @db_query_time("update_user_activity", "users", "update") async def update_user_activity(self, user_id: int) -> None: """Update user's last activity timestamp with metrics tracking""" await self.db.update_user_date(user_id) - + @track_time("ensure_user_exists", "user_service") @track_errors("user_service", "ensure_user_exists") @db_query_time("ensure_user_exists", "users", "insert") @@ -79,7 +103,7 @@ class UserService: first_name = get_first_name(message) is_bot = message.from_user.is_bot language_code = message.from_user.language_code - + # Create User object with current timestamp current_timestamp = int(datetime.now().timestamp()) user = User( @@ -93,33 +117,39 @@ class UserService: has_stickers=False, date_added=current_timestamp, date_changed=current_timestamp, - voice_bot_welcome_received=False + voice_bot_welcome_received=False, ) - + # Пытаемся создать пользователя (если уже существует - игнорируем) # Это устраняет race condition и упрощает логику await self.db.add_user(user) - + # Проверяем, нужно ли обновить информацию о существующем пользователе - is_need_update = await check_username_and_full_name(user_id, username, full_name, self.db) + is_need_update = await check_username_and_full_name( + user_id, username, full_name, self.db + ) if is_need_update: await self.db.update_user_info(user_id, username, full_name) - safe_full_name = html.escape(full_name) if full_name else "Неизвестный пользователь" + safe_full_name = ( + html.escape(full_name) if full_name else "Неизвестный пользователь" + ) # Для отображения используем подстановочное значение, но в БД сохраняем только реальный username safe_username = html.escape(username) if username else "Без никнейма" - + await message.answer( - f"Давно не виделись! Вижу что ты изменился;) Теперь буду звать тебя: {safe_full_name} и ник @{safe_username}") + f"Давно не виделись! Вижу что ты изменился;) Теперь буду звать тебя: {safe_full_name} и ник @{safe_username}" + ) await message.bot.send_message( chat_id=self.settings.group_for_logs, - text=f'Для пользователя: {user_id} обновлены данные в БД.\nНовое имя: {safe_full_name}\nНовый ник:{safe_username}') - + text=f"Для пользователя: {user_id} обновлены данные в БД.\nНовое имя: {safe_full_name}\nНовый ник:{safe_username}", + ) + await self.db.update_user_date(user_id) async def log_user_message(self, message: types.Message) -> None: """Forward user message to logs group with metrics tracking""" await message.forward(chat_id=self.settings.group_for_logs) - + def get_safe_user_info(self, message: types.Message) -> tuple[str, str]: """Get safely escaped user information for logging""" full_name = message.from_user.full_name or "Неизвестный пользователь" @@ -129,60 +159,87 @@ class UserService: class PostService: """Service for post-related operations""" - - def __init__(self, db: DatabaseProtocol, settings: BotSettings, s3_storage=None, scoring_manager=None) -> None: + + def __init__( + self, + db: DatabaseProtocol, + settings: BotSettings, + s3_storage=None, + scoring_manager=None, + ) -> None: self.db = db self.settings = settings self.s3_storage = s3_storage self.scoring_manager = scoring_manager - - async def _save_media_background(self, sent_message: types.Message, bot_db: Any, s3_storage) -> None: + + async def _save_media_background( + self, sent_message: types.Message, bot_db: Any, s3_storage + ) -> None: """Сохраняет медиа в фоне, чтобы не блокировать ответ пользователю""" try: success = await add_in_db_media(sent_message, bot_db, s3_storage) if not success: - logger.warning(f"_save_media_background: Не удалось сохранить медиа для поста {sent_message.message_id}") + logger.warning( + f"_save_media_background: Не удалось сохранить медиа для поста {sent_message.message_id}" + ) except Exception as e: - logger.error(f"_save_media_background: Ошибка при сохранении медиа для поста {sent_message.message_id}: {e}") - + logger.error( + f"_save_media_background: Ошибка при сохранении медиа для поста {sent_message.message_id}: {e}" + ) + async def _get_scores(self, text: str) -> tuple: """ Получает скоры для текста поста. - + Returns: Tuple (deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json) """ if not self.scoring_manager or not text or not text.strip(): return None, None, None, None, None - + try: scores = await self.scoring_manager.score_post(text) - + # Формируем JSON для сохранения в БД import json - ml_scores_json = json.dumps(scores.to_json_dict()) if scores.has_any_score() else None - + + ml_scores_json = ( + json.dumps(scores.to_json_dict()) if scores.has_any_score() else None + ) + # Получаем данные от RAG rag_confidence = scores.rag.confidence if scores.rag else None - rag_score_pos_only = scores.rag.metadata.get("rag_score_pos_only") if scores.rag else None - - return scores.deepseek_score, scores.rag_score, rag_confidence, rag_score_pos_only, ml_scores_json + rag_score_pos_only = ( + scores.rag.metadata.get("rag_score_pos_only") if scores.rag else None + ) + + return ( + scores.deepseek_score, + scores.rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + ) except Exception as e: logger.error(f"PostService: Ошибка получения скоров: {e}") return None, None, None, None, None - - async def _save_scores_background(self, message_id: int, ml_scores_json: str) -> None: + + async def _save_scores_background( + self, message_id: int, ml_scores_json: str + ) -> None: """Сохраняет скоры в БД в фоне.""" if ml_scores_json: try: await self.db.update_ml_scores(message_id, ml_scores_json) except Exception as e: - logger.error(f"PostService: Ошибка сохранения скоров для {message_id}: {e}") - + logger.error( + f"PostService: Ошибка сохранения скоров для {message_id}: {e}" + ) + async def _get_scores_with_error_handling(self, text: str) -> tuple: """ Получает скоры для текста поста с обработкой ошибок. - + Returns: Tuple (deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json, error_message) error_message будет None если все ок, или строка с описанием ошибки @@ -190,28 +247,40 @@ class PostService: if not self.scoring_manager: # Скоры выключены в .env - это нормально return None, None, None, None, None, None - + if not text or not text.strip(): return None, None, None, None, None, None - + try: scores = await self.scoring_manager.score_post(text) - + # Формируем JSON для сохранения в БД import json - ml_scores_json = json.dumps(scores.to_json_dict()) if scores.has_any_score() else None - + + ml_scores_json = ( + json.dumps(scores.to_json_dict()) if scores.has_any_score() else None + ) + # Получаем данные от RAG rag_confidence = scores.rag.confidence if scores.rag else None - rag_score_pos_only = scores.rag.metadata.get("rag_score_pos_only") if scores.rag else None - - return scores.deepseek_score, scores.rag_score, rag_confidence, rag_score_pos_only, ml_scores_json, None + rag_score_pos_only = ( + scores.rag.metadata.get("rag_score_pos_only") if scores.rag else None + ) + + return ( + scores.deepseek_score, + scores.rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + None, + ) except Exception as e: logger.error(f"PostService: Ошибка получения скоров: {e}") # Возвращаем частичные скоры если есть, или сообщение об ошибке error_message = "Не удалось рассчитать скоры" return None, None, None, None, None, error_message - + @track_time("_process_post_background", "post_service") @track_errors("post_service", "_process_post_background") async def _process_post_background( @@ -219,11 +288,11 @@ class PostService: message: types.Message, first_name: str, content_type: str, - album: Union[list, None] = None + album: Union[list, None] = None, ) -> None: """ Обрабатывает пост в фоне: получает скоры, отправляет в группу модерации, сохраняет в БД. - + Args: message: Сообщение от пользователя first_name: Имя пользователя @@ -236,14 +305,22 @@ class PostService: if content_type == "text": original_raw_text = message.text or "" elif content_type == "media_group": - original_raw_text = album[0].caption or "" if album and album[0].caption else "" + original_raw_text = ( + album[0].caption or "" if album and album[0].caption else "" + ) else: original_raw_text = message.caption or "" - + # Получаем скоры с обработкой ошибок - deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json, error_message = \ - await self._get_scores_with_error_handling(original_raw_text) - + ( + deepseek_score, + rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + error_message, + ) = await self._get_scores_with_error_handling(original_raw_text) + # Формируем текст для поста (с сообщением об ошибке если есть) text_for_post = original_raw_text if error_message: @@ -253,7 +330,7 @@ class PostService: # Для медиа добавляем в caption elif content_type in ("photo", "video", "audio") and original_raw_text: text_for_post = f"{original_raw_text}\n\n⚠️ {error_message}" - + # Формируем текст/caption с учетом скоров post_text = "" if text_for_post or content_type == "text": @@ -273,13 +350,13 @@ class PostService: rag_confidence=rag_confidence, rag_score_pos_only=rag_score_pos_only, ) - + # Определяем анонимность по исходному тексту (без сообщения об ошибке) is_anonymous = determine_anonymity(original_raw_text) - + markup = get_reply_keyboard_for_post() sent_message = None - + # Отправляем пост в группу модерации в зависимости от типа if content_type == "text": sent_message = await send_text_message( @@ -287,59 +364,95 @@ class PostService: ) elif content_type == "photo": sent_message = await send_photo_message( - self.settings.group_for_posts, message, message.photo[-1].file_id, post_text, markup + self.settings.group_for_posts, + message, + message.photo[-1].file_id, + post_text, + markup, ) elif content_type == "video": sent_message = await send_video_message( - self.settings.group_for_posts, message, message.video.file_id, post_text, markup + self.settings.group_for_posts, + message, + message.video.file_id, + post_text, + markup, ) elif content_type == "audio": sent_message = await send_audio_message( - self.settings.group_for_posts, message, message.audio.file_id, post_text, markup + self.settings.group_for_posts, + message, + message.audio.file_id, + post_text, + markup, ) elif content_type == "voice": sent_message = await send_voice_message( - self.settings.group_for_posts, message, message.voice.file_id, markup + self.settings.group_for_posts, + message, + message.voice.file_id, + markup, ) elif content_type == "video_note": sent_message = await send_video_note_message( - self.settings.group_for_posts, message, message.video_note.file_id, markup + self.settings.group_for_posts, + message, + message.video_note.file_id, + markup, ) elif content_type == "media_group": # Для медиагруппы используем специальную обработку # Передаем ml_scores_json для сохранения в БД await self._process_media_group_background( - message, album, first_name, post_text, is_anonymous, original_raw_text, ml_scores_json + message, + album, + first_name, + post_text, + is_anonymous, + original_raw_text, + ml_scores_json, ) return else: - logger.error(f"PostService: Неподдерживаемый тип контента: {content_type}") + logger.error( + f"PostService: Неподдерживаемый тип контента: {content_type}" + ) return - + if not sent_message: - logger.error(f"PostService: Не удалось отправить пост типа {content_type}") + logger.error( + f"PostService: Не удалось отправить пост типа {content_type}" + ) return - + # Сохраняем пост в БД (сохраняем исходный текст, без сообщения об ошибке) post = TelegramPost( message_id=sent_message.message_id, text=original_raw_text, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(post) - + # Сохраняем медиа и скоры в фоне if content_type in ("photo", "video", "audio", "voice", "video_note"): - asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) - + asyncio.create_task( + self._save_media_background(sent_message, self.db, self.s3_storage) + ) + if ml_scores_json: - asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) - + asyncio.create_task( + self._save_scores_background( + sent_message.message_id, ml_scores_json + ) + ) + except Exception as e: - logger.error(f"PostService: Критическая ошибка в _process_post_background для {content_type}: {e}") - + logger.error( + f"PostService: Критическая ошибка в _process_post_background для {content_type}: {e}" + ) + async def _process_media_group_background( self, message: types.Message, @@ -348,70 +461,81 @@ class PostService: post_caption: str, is_anonymous: bool, original_raw_text: str, - ml_scores_json: str = None + ml_scores_json: str = None, ) -> None: """Обрабатывает медиагруппу в фоне""" try: - media_group = await prepare_media_group_from_middlewares(album, post_caption) - - media_group_message_ids = await send_media_group_message_to_private_chat( - self.settings.group_for_posts, message, media_group, self.db, None, self.s3_storage + media_group = await prepare_media_group_from_middlewares( + album, post_caption ) - + + media_group_message_ids = await send_media_group_message_to_private_chat( + self.settings.group_for_posts, + message, + media_group, + self.db, + None, + self.s3_storage, + ) + main_post_id = media_group_message_ids[-1] - + main_post = TelegramPost( message_id=main_post_id, text=original_raw_text, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(main_post) - + # Сохраняем скоры в фоне (если они были получены) if ml_scores_json: - asyncio.create_task(self._save_scores_background(main_post_id, ml_scores_json)) - + asyncio.create_task( + self._save_scores_background(main_post_id, ml_scores_json) + ) + for msg_id in media_group_message_ids: await self.db.add_message_link(main_post_id, msg_id) - + await asyncio.sleep(0.2) - + markup = get_reply_keyboard_for_post() helper_message = await send_text_message( - self.settings.group_for_posts, - message, - "^", - markup + self.settings.group_for_posts, message, "^", markup ) helper_message_id = helper_message.message_id - + helper_post = TelegramPost( message_id=helper_message_id, text="^", author_id=message.from_user.id, helper_text_message_id=main_post_id, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await self.db.add_post(helper_post) - + await self.db.update_helper_message( - message_id=main_post_id, - helper_message_id=helper_message_id + message_id=main_post_id, helper_message_id=helper_message_id ) except Exception as e: logger.error(f"PostService: Ошибка в _process_media_group_background: {e}") - + @track_time("handle_text_post", "post_service") @track_errors("post_service", "handle_text_post") @db_query_time("handle_text_post", "posts", "insert") async def handle_text_post(self, message: types.Message, first_name: str) -> None: """Handle text post submission""" raw_text = message.text or "" - + # Получаем скоры для текста - deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_text) + ( + deepseek_score, + rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + ) = await self._get_scores(raw_text) logger.debug( f"PostService.handle_text_post: Передача скоров в get_text_message - " @@ -420,11 +544,11 @@ class PostService: f"rag_confidence={rag_confidence} (type: {type(rag_confidence).__name__ if rag_confidence is not None else 'None'}), " f"message_id={message.message_id}" ) - + # Формируем текст с учетом скоров post_text = get_text_message( - message.text.lower(), - first_name, + message.text.lower(), + first_name, message.from_user.username, deepseek_score=deepseek_score, rag_score=rag_score, @@ -432,34 +556,44 @@ class PostService: rag_score_pos_only=rag_score_pos_only, ) markup = get_reply_keyboard_for_post() - - sent_message = await send_text_message(self.settings.group_for_posts, message, post_text, markup) - + + sent_message = await send_text_message( + self.settings.group_for_posts, message, post_text, markup + ) + # Определяем анонимность is_anonymous = determine_anonymity(raw_text) - + post = TelegramPost( message_id=sent_message.message_id, text=raw_text, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(post) - + # Сохраняем скоры в фоне if ml_scores_json: - asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) - + asyncio.create_task( + self._save_scores_background(sent_message.message_id, ml_scores_json) + ) + @track_time("handle_photo_post", "post_service") @track_errors("post_service", "handle_photo_post") @db_query_time("handle_photo_post", "posts", "insert") async def handle_photo_post(self, message: types.Message, first_name: str) -> None: """Handle photo post submission""" raw_caption = message.caption or "" - + # Получаем скоры для текста - deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + ( + deepseek_score, + rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + ) = await self._get_scores(raw_caption) logger.debug( f"PostService.handle_photo_post: Передача скоров в get_text_message - " @@ -468,50 +602,64 @@ class PostService: f"rag_confidence={rag_confidence} (type: {type(rag_confidence).__name__ if rag_confidence is not None else 'None'}), " f"message_id={message.message_id}" ) - + post_caption = "" if message.caption: post_caption = get_text_message( - message.caption.lower(), - first_name, + message.caption.lower(), + first_name, message.from_user.username, deepseek_score=deepseek_score, rag_score=rag_score, rag_confidence=rag_confidence, rag_score_pos_only=rag_score_pos_only, ) - + markup = get_reply_keyboard_for_post() sent_message = await send_photo_message( - self.settings.group_for_posts, message, message.photo[-1].file_id, post_caption, markup + self.settings.group_for_posts, + message, + message.photo[-1].file_id, + post_caption, + markup, ) - + # Определяем анонимность is_anonymous = determine_anonymity(raw_caption) - + post = TelegramPost( message_id=sent_message.message_id, text=raw_caption, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(post) - + # Сохраняем медиа и скоры в фоне - asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) + asyncio.create_task( + self._save_media_background(sent_message, self.db, self.s3_storage) + ) if ml_scores_json: - asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) - + asyncio.create_task( + self._save_scores_background(sent_message.message_id, ml_scores_json) + ) + @track_time("handle_video_post", "post_service") @track_errors("post_service", "handle_video_post") @db_query_time("handle_video_post", "posts", "insert") async def handle_video_post(self, message: types.Message, first_name: str) -> None: """Handle video post submission""" raw_caption = message.caption or "" - + # Получаем скоры для текста - deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + ( + deepseek_score, + rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + ) = await self._get_scores(raw_caption) logger.debug( f"PostService.handle_video_post: Передача скоров в get_text_message - " @@ -520,41 +668,49 @@ class PostService: f"rag_confidence={rag_confidence} (type: {type(rag_confidence).__name__ if rag_confidence is not None else 'None'}), " f"message_id={message.message_id}" ) - + post_caption = "" if message.caption: post_caption = get_text_message( - message.caption.lower(), - first_name, + message.caption.lower(), + first_name, message.from_user.username, deepseek_score=deepseek_score, rag_score=rag_score, rag_confidence=rag_confidence, rag_score_pos_only=rag_score_pos_only, ) - + markup = get_reply_keyboard_for_post() sent_message = await send_video_message( - self.settings.group_for_posts, message, message.video.file_id, post_caption, markup + self.settings.group_for_posts, + message, + message.video.file_id, + post_caption, + markup, ) - + # Определяем анонимность is_anonymous = determine_anonymity(raw_caption) - + post = TelegramPost( message_id=sent_message.message_id, text=raw_caption, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(post) - + # Сохраняем медиа и скоры в фоне - asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) + asyncio.create_task( + self._save_media_background(sent_message, self.db, self.s3_storage) + ) if ml_scores_json: - asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) - + asyncio.create_task( + self._save_scores_background(sent_message.message_id, ml_scores_json) + ) + @track_time("handle_video_note_post", "post_service") @track_errors("post_service", "handle_video_note_post") @db_query_time("handle_video_note_post", "posts", "insert") @@ -564,31 +720,39 @@ class PostService: sent_message = await send_video_note_message( self.settings.group_for_posts, message, message.video_note.file_id, markup ) - + # Сохраняем пустую строку, так как video_note не имеет caption raw_caption = "" is_anonymous = determine_anonymity(raw_caption) - + post = TelegramPost( message_id=sent_message.message_id, text=raw_caption, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(post) # Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю - asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) - + asyncio.create_task( + self._save_media_background(sent_message, self.db, self.s3_storage) + ) + @track_time("handle_audio_post", "post_service") @track_errors("post_service", "handle_audio_post") @db_query_time("handle_audio_post", "posts", "insert") async def handle_audio_post(self, message: types.Message, first_name: str) -> None: """Handle audio post submission""" raw_caption = message.caption or "" - + # Получаем скоры для текста - deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + ( + deepseek_score, + rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + ) = await self._get_scores(raw_caption) logger.debug( f"PostService.handle_audio_post: Передача скоров в get_text_message - " @@ -597,41 +761,49 @@ class PostService: f"rag_confidence={rag_confidence} (type: {type(rag_confidence).__name__ if rag_confidence is not None else 'None'}), " f"message_id={message.message_id}" ) - + post_caption = "" if message.caption: post_caption = get_text_message( - message.caption.lower(), - first_name, + message.caption.lower(), + first_name, message.from_user.username, deepseek_score=deepseek_score, rag_score=rag_score, rag_confidence=rag_confidence, rag_score_pos_only=rag_score_pos_only, ) - + markup = get_reply_keyboard_for_post() sent_message = await send_audio_message( - self.settings.group_for_posts, message, message.audio.file_id, post_caption, markup + self.settings.group_for_posts, + message, + message.audio.file_id, + post_caption, + markup, ) - + # Определяем анонимность is_anonymous = determine_anonymity(raw_caption) - + post = TelegramPost( message_id=sent_message.message_id, text=raw_caption, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(post) - + # Сохраняем медиа и скоры в фоне - asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) + asyncio.create_task( + self._save_media_background(sent_message, self.db, self.s3_storage) + ) if ml_scores_json: - asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) - + asyncio.create_task( + self._save_scores_background(sent_message.message_id, ml_scores_json) + ) + @track_time("handle_voice_post", "post_service") @track_errors("post_service", "handle_voice_post") @db_query_time("handle_voice_post", "posts", "insert") @@ -641,37 +813,47 @@ class PostService: sent_message = await send_voice_message( self.settings.group_for_posts, message, message.voice.file_id, markup ) - + # Сохраняем пустую строку, так как voice не имеет caption raw_caption = "" is_anonymous = determine_anonymity(raw_caption) - + post = TelegramPost( message_id=sent_message.message_id, text=raw_caption, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(post) # Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю - asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) - + asyncio.create_task( + self._save_media_background(sent_message, self.db, self.s3_storage) + ) + @track_time("handle_media_group_post", "post_service") - @track_errors("post_service", "handle_media_group_post") + @track_errors("post_service", "handle_media_group_post") @db_query_time("handle_media_group_post", "posts", "insert") @track_media_processing("media_group") - async def handle_media_group_post(self, message: types.Message, album: list, first_name: str) -> None: + async def handle_media_group_post( + self, message: types.Message, album: list, first_name: str + ) -> None: """Handle media group post submission""" post_caption = " " raw_caption = "" ml_scores_json = None - + if album and album[0].caption: raw_caption = album[0].caption or "" - + # Получаем скоры для текста - deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + ( + deepseek_score, + rag_score, + rag_confidence, + rag_score_pos_only, + ml_scores_json, + ) = await self._get_scores(raw_caption) logger.debug( f"PostService.handle_media_group_post: Передача скоров в get_text_message - " @@ -680,80 +862,89 @@ class PostService: f"rag_confidence={rag_confidence} (type: {type(rag_confidence).__name__ if rag_confidence is not None else 'None'}), " f"message_id={message.message_id}" ) - + post_caption = get_text_message( - album[0].caption.lower(), - first_name, + album[0].caption.lower(), + first_name, message.from_user.username, deepseek_score=deepseek_score, rag_score=rag_score, rag_confidence=rag_confidence, rag_score_pos_only=rag_score_pos_only, ) - + is_anonymous = determine_anonymity(raw_caption) media_group = await prepare_media_group_from_middlewares(album, post_caption) - + media_group_message_ids = await send_media_group_message_to_private_chat( - self.settings.group_for_posts, message, media_group, self.db, None, self.s3_storage + self.settings.group_for_posts, + message, + media_group, + self.db, + None, + self.s3_storage, ) - + main_post_id = media_group_message_ids[-1] - + main_post = TelegramPost( message_id=main_post_id, text=raw_caption, author_id=message.from_user.id, created_at=int(datetime.now().timestamp()), - is_anonymous=is_anonymous + is_anonymous=is_anonymous, ) await self.db.add_post(main_post) - + # Сохраняем скоры в фоне if ml_scores_json: - asyncio.create_task(self._save_scores_background(main_post_id, ml_scores_json)) - + asyncio.create_task( + self._save_scores_background(main_post_id, ml_scores_json) + ) + for msg_id in media_group_message_ids: await self.db.add_message_link(main_post_id, msg_id) - + await asyncio.sleep(0.2) - + markup = get_reply_keyboard_for_post() helper_message = await send_text_message( - self.settings.group_for_posts, - message, - "^", - markup + self.settings.group_for_posts, message, "^", markup ) helper_message_id = helper_message.message_id - + helper_post = TelegramPost( message_id=helper_message_id, text="^", author_id=message.from_user.id, helper_text_message_id=main_post_id, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await self.db.add_post(helper_post) - + await self.db.update_helper_message( - message_id=main_post_id, - helper_message_id=helper_message_id + message_id=main_post_id, helper_message_id=helper_message_id ) - + @track_time("process_post", "post_service") @track_errors("post_service", "process_post") @track_media_processing("media_group") - async def process_post(self, message: types.Message, album: Union[list, None] = None) -> None: + async def process_post( + self, message: types.Message, album: Union[list, None] = None + ) -> None: """ Запускает обработку поста в фоне. Не блокирует выполнение - сразу возвращает управление. """ first_name = get_first_name(message) - + # Определяем тип контента - content_type = "media_group" if message.media_group_id is not None else message.content_type - + content_type = ( + "media_group" + if message.media_group_id is not None + else message.content_type + ) + # Запускаем фоновую обработку asyncio.create_task( self._process_post_background(message, first_name, content_type, album) @@ -762,29 +953,29 @@ class PostService: class StickerService: """Service for sticker-related operations""" - + def __init__(self, settings: BotSettings) -> None: self.settings = settings - + @track_time("send_random_hello_sticker", "sticker_service") @track_errors("sticker_service", "send_random_hello_sticker") @track_file_operations("sticker") async def send_random_hello_sticker(self, message: types.Message) -> None: """Send random hello sticker with metrics tracking""" - name_stick_hello = list(Path('Stick').rglob('Hello_*')) + name_stick_hello = list(Path("Stick").rglob("Hello_*")) if not name_stick_hello: return random_stick_hello = random.choice(name_stick_hello) random_stick_hello = FSInputFile(path=random_stick_hello) await message.answer_sticker(random_stick_hello) await asyncio.sleep(0.3) - + @track_time("send_random_goodbye_sticker", "sticker_service") @track_errors("sticker_service", "send_random_goodbye_sticker") @track_file_operations("sticker") async def send_random_goodbye_sticker(self, message: types.Message) -> None: """Send random goodbye sticker with metrics tracking""" - name_stick_bye = list(Path('Stick').rglob('Universal_*')) + name_stick_bye = list(Path("Stick").rglob("Universal_*")) if not name_stick_bye: return random_stick_bye = random.choice(name_stick_bye) diff --git a/helper_bot/handlers/voice/cleanup_utils.py b/helper_bot/handlers/voice/cleanup_utils.py index 1c4d004..b6e1af1 100644 --- a/helper_bot/handlers/voice/cleanup_utils.py +++ b/helper_bot/handlers/voice/cleanup_utils.py @@ -1,6 +1,7 @@ """ Утилиты для очистки и диагностики проблем с голосовыми файлами """ + import asyncio import os from pathlib import Path @@ -12,108 +13,122 @@ from logs.custom_logger import logger class VoiceFileCleanupUtils: """Утилиты для очистки и диагностики голосовых файлов""" - + def __init__(self, bot_db): self.bot_db = bot_db - + async def find_orphaned_db_records(self) -> List[Tuple[str, int]]: """Найти записи в БД, для которых нет соответствующих файлов""" try: # Получаем все записи из БД all_audio_records = await self.bot_db.get_all_audio_records() orphaned_records = [] - + for record in all_audio_records: - file_name = record.get('file_name', '') - user_id = record.get('author_id', 0) - - file_path = f'{VOICE_USERS_DIR}/{file_name}.ogg' + file_name = record.get("file_name", "") + user_id = record.get("author_id", 0) + + file_path = f"{VOICE_USERS_DIR}/{file_name}.ogg" if not os.path.exists(file_path): orphaned_records.append((file_name, user_id)) - logger.warning(f"Найдена запись в БД без файла: {file_name} (user_id: {user_id})") - - logger.info(f"Найдено {len(orphaned_records)} записей в БД без соответствующих файлов") + logger.warning( + f"Найдена запись в БД без файла: {file_name} (user_id: {user_id})" + ) + + logger.info( + f"Найдено {len(orphaned_records)} записей в БД без соответствующих файлов" + ) return orphaned_records - + except Exception as e: logger.error(f"Ошибка при поиске orphaned записей: {e}") return [] - + async def find_orphaned_files(self) -> List[str]: """Найти файлы на диске, для которых нет записей в БД""" try: if not os.path.exists(VOICE_USERS_DIR): logger.warning(f"Директория {VOICE_USERS_DIR} не существует") return [] - + # Получаем все файлы .ogg в директории ogg_files = list(Path(VOICE_USERS_DIR).glob("*.ogg")) orphaned_files = [] - + # Получаем все записи из БД all_audio_records = await self.bot_db.get_all_audio_records() - db_file_names = {record.get('file_name', '') for record in all_audio_records} - + db_file_names = { + record.get("file_name", "") for record in all_audio_records + } + for file_path in ogg_files: file_name = file_path.stem # Имя файла без расширения if file_name not in db_file_names: orphaned_files.append(str(file_path)) logger.warning(f"Найден файл без записи в БД: {file_path}") - + logger.info(f"Найдено {len(orphaned_files)} файлов без записей в БД") return orphaned_files - + except Exception as e: logger.error(f"Ошибка при поиске orphaned файлов: {e}") return [] - + async def cleanup_orphaned_db_records(self, dry_run: bool = True) -> int: """Удалить записи в БД, для которых нет файлов""" try: orphaned_records = await self.find_orphaned_db_records() - + if not orphaned_records: logger.info("Нет orphaned записей для удаления") return 0 - + if dry_run: - logger.info(f"DRY RUN: Найдено {len(orphaned_records)} записей для удаления") + logger.info( + f"DRY RUN: Найдено {len(orphaned_records)} записей для удаления" + ) for file_name, user_id in orphaned_records: - logger.info(f"DRY RUN: Будет удалена запись: {file_name} (user_id: {user_id})") + logger.info( + f"DRY RUN: Будет удалена запись: {file_name} (user_id: {user_id})" + ) return len(orphaned_records) - + # Удаляем записи deleted_count = 0 for file_name, user_id in orphaned_records: try: await self.bot_db.delete_audio_record_by_file_name(file_name) deleted_count += 1 - logger.info(f"Удалена запись в БД: {file_name} (user_id: {user_id})") + logger.info( + f"Удалена запись в БД: {file_name} (user_id: {user_id})" + ) except Exception as e: logger.error(f"Ошибка при удалении записи {file_name}: {e}") - + logger.info(f"Удалено {deleted_count} orphaned записей из БД") return deleted_count - + except Exception as e: logger.error(f"Ошибка при очистке orphaned записей: {e}") return 0 - + async def cleanup_orphaned_files(self, dry_run: bool = True) -> int: """Удалить файлы на диске, для которых нет записей в БД""" try: orphaned_files = await self.find_orphaned_files() - + if not orphaned_files: logger.info("Нет orphaned файлов для удаления") return 0 - + if dry_run: - logger.info(f"DRY RUN: Найдено {len(orphaned_files)} файлов для удаления") + logger.info( + f"DRY RUN: Найдено {len(orphaned_files)} файлов для удаления" + ) for file_path in orphaned_files: logger.info(f"DRY RUN: Будет удален файл: {file_path}") return len(orphaned_files) - + # Удаляем файлы deleted_count = 0 for file_path in orphaned_files: @@ -123,70 +138,76 @@ class VoiceFileCleanupUtils: logger.info(f"Удален файл: {file_path}") except Exception as e: logger.error(f"Ошибка при удалении файла {file_path}: {e}") - + logger.info(f"Удалено {deleted_count} orphaned файлов") return deleted_count - + except Exception as e: logger.error(f"Ошибка при очистке orphaned файлов: {e}") return 0 - + async def get_disk_usage_stats(self) -> dict: """Получить статистику использования диска""" try: if not os.path.exists(VOICE_USERS_DIR): return {"error": f"Директория {VOICE_USERS_DIR} не существует"} - + total_size = 0 file_count = 0 - + for file_path in Path(VOICE_USERS_DIR).glob("*.ogg"): if file_path.is_file(): total_size += file_path.stat().st_size file_count += 1 - + return { "total_files": file_count, "total_size_bytes": total_size, "total_size_mb": round(total_size / (1024 * 1024), 2), - "directory": VOICE_USERS_DIR + "directory": VOICE_USERS_DIR, } - + except Exception as e: logger.error(f"Ошибка при получении статистики диска: {e}") return {"error": str(e)} - + async def run_full_diagnostic(self) -> dict: """Запустить полную диагностику""" try: logger.info("Запуск полной диагностики голосовых файлов...") - + # Статистика диска disk_stats = await self.get_disk_usage_stats() - + # Orphaned записи в БД orphaned_db_records = await self.find_orphaned_db_records() - + # Orphaned файлы orphaned_files = await self.find_orphaned_files() - + # Количество записей в БД all_audio_records = await self.bot_db.get_all_audio_records() db_records_count = len(all_audio_records) - + diagnostic_result = { "disk_stats": disk_stats, "db_records_count": db_records_count, "orphaned_db_records_count": len(orphaned_db_records), "orphaned_files_count": len(orphaned_files), - "orphaned_db_records": orphaned_db_records[:10], # Первые 10 для примера + "orphaned_db_records": orphaned_db_records[ + :10 + ], # Первые 10 для примера "orphaned_files": orphaned_files[:10], # Первые 10 для примера - "status": "healthy" if len(orphaned_db_records) == 0 and len(orphaned_files) == 0 else "issues_found" + "status": ( + "healthy" + if len(orphaned_db_records) == 0 and len(orphaned_files) == 0 + else "issues_found" + ), } - + logger.info(f"Диагностика завершена. Статус: {diagnostic_result['status']}") return diagnostic_result - + except Exception as e: logger.error(f"Ошибка при диагностике: {e}") return {"error": str(e)} diff --git a/helper_bot/handlers/voice/constants.py b/helper_bot/handlers/voice/constants.py index 18e5060..8774595 100644 --- a/helper_bot/handlers/voice/constants.py +++ b/helper_bot/handlers/voice/constants.py @@ -17,10 +17,10 @@ CMD_REFRESH = "refresh" # Command to command mapping for metrics COMMAND_MAPPING: Final[Dict[str, str]] = { "start": "voice_start", - "help": "voice_help", + "help": "voice_help", "restart": "voice_restart", "emoji": "voice_emoji", - "refresh": "voice_refresh" + "refresh": "voice_refresh", } # Button texts @@ -33,7 +33,7 @@ BUTTON_COMMAND_MAPPING: Final[Dict[str, str]] = { "🎧Послушать": "voice_listen", "Отменить": "voice_cancel", "🔄Сбросить прослушивания": "voice_refresh_listen", - "😊Узнать эмодзи": "voice_emoji" + "😊Узнать эмодзи": "voice_emoji", } # Callback data @@ -43,7 +43,7 @@ CALLBACK_DELETE = "delete" # Callback to command mapping for metrics CALLBACK_COMMAND_MAPPING: Final[Dict[str, str]] = { "save": "voice_save", - "delete": "voice_delete" + "delete": "voice_delete", } # File paths diff --git a/helper_bot/handlers/voice/exceptions.py b/helper_bot/handlers/voice/exceptions.py index ac43ef8..f991cdc 100644 --- a/helper_bot/handlers/voice/exceptions.py +++ b/helper_bot/handlers/voice/exceptions.py @@ -1,23 +1,28 @@ class VoiceBotError(Exception): """Базовое исключение для voice_bot""" + pass class VoiceMessageError(VoiceBotError): """Ошибка при работе с голосовыми сообщениями""" + pass class AudioProcessingError(VoiceBotError): """Ошибка при обработке аудио""" + pass class DatabaseError(VoiceBotError): """Ошибка базы данных""" + pass class FileOperationError(VoiceBotError): """Ошибка при работе с файлами""" + pass diff --git a/helper_bot/handlers/voice/services.py b/helper_bot/handlers/voice/services.py index d08ff59..36808ee 100644 --- a/helper_bot/handlers/voice/services.py +++ b/helper_bot/handlers/voice/services.py @@ -7,16 +7,24 @@ from pathlib import Path from typing import List, Optional, Tuple from aiogram.types import FSInputFile -from helper_bot.handlers.voice.constants import (MESSAGE_DELAY_1, - MESSAGE_DELAY_2, - MESSAGE_DELAY_3, - MESSAGE_DELAY_4, STICK_DIR, - STICK_PATTERN, STICKER_DELAY, - VOICE_USERS_DIR) -from helper_bot.handlers.voice.exceptions import (AudioProcessingError, - DatabaseError, - FileOperationError, - VoiceMessageError) + +from helper_bot.handlers.voice.constants import ( + MESSAGE_DELAY_1, + MESSAGE_DELAY_2, + MESSAGE_DELAY_3, + MESSAGE_DELAY_4, + STICK_DIR, + STICK_PATTERN, + STICKER_DELAY, + VOICE_USERS_DIR, +) +from helper_bot.handlers.voice.exceptions import ( + AudioProcessingError, + DatabaseError, + FileOperationError, + VoiceMessageError, +) + # Local imports - metrics from helper_bot.utils.metrics import db_query_time, track_errors, track_time from logs.custom_logger import logger @@ -24,19 +32,23 @@ from logs.custom_logger import logger class VoiceMessage: """Модель голосового сообщения""" - def __init__(self, file_name: str, user_id: int, date_added: datetime, file_id: int): + + def __init__( + self, file_name: str, user_id: int, date_added: datetime, file_id: int + ): self.file_name = file_name self.user_id = user_id self.date_added = date_added self.file_id = file_id + class VoiceBotService: """Сервис для работы с голосовыми сообщениями""" - + def __init__(self, bot_db, settings): self.bot_db = bot_db self.settings = settings - + @track_time("get_welcome_sticker", "voice_bot_service") @track_errors("voice_bot_service", "get_welcome_sticker") async def get_welcome_sticker(self) -> Optional[FSInputFile]: @@ -45,17 +57,21 @@ class VoiceBotService: name_stick_hello = list(Path(STICK_DIR).rglob(STICK_PATTERN)) if not name_stick_hello: return None - + random_stick_hello = random.choice(name_stick_hello) random_stick_hello = FSInputFile(path=random_stick_hello) - logger.info(f"Стикер успешно получен. Наименование стикера: {random_stick_hello}") + logger.info( + f"Стикер успешно получен. Наименование стикера: {random_stick_hello}" + ) return random_stick_hello except Exception as e: logger.error(f"Ошибка при получении стикера: {e}") - if self.settings['Settings']['logs']: - await self._send_error_to_logs(f'Отправка приветственных стикеров лажает. Ошибка: {e}') + if self.settings["Settings"]["logs"]: + await self._send_error_to_logs( + f"Отправка приветственных стикеров лажает. Ошибка: {e}" + ) return None - + @track_time("send_welcome_messages", "voice_bot_service") @track_errors("voice_bot_service", "send_welcome_messages") async def send_welcome_messages(self, message, user_emoji: str): @@ -66,92 +82,94 @@ class VoiceBotService: if sticker: await message.answer_sticker(sticker) await asyncio.sleep(STICKER_DELAY) - + # Отправляем приветственное сообщение markup = self._get_main_keyboard() await message.answer( - text="Привет.", - parse_mode='html', + text="Привет.", + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(STICKER_DELAY) - + # Отправляем описание await message.answer( text="Здесь можно послушать голосовые сообщения от совершенно незнакомых людей из Бийска", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(MESSAGE_DELAY_1) - + # Отправляем аналогию await message.answer( text="Это почти как написать письмо, положить его в бутылку и швырнуть в океан. Никогда не узнаешь, послушал его кто-то или нет и ответить тоже не получится..", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(MESSAGE_DELAY_2) - + # Отправляем правила await message.answer( text="Записывать можно всё что угодно — никаких правил нет. Главное — твой голос, хотя бы на 5-10 секунд", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(MESSAGE_DELAY_3) - + # Отправляем информацию об анонимности await message.answer( text="Здесь всё анонимно: тот, кому я отправлю твое сообщение, не узнает ни твое имя, ни твой аккаунт (так что можно не стесняться говорить то, что не стал(а) бы выкладывать в собственные соцсети)", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(MESSAGE_DELAY_4) - + # Отправляем предложения await message.answer( text="Если не знаешь, что сказать, можешь просто прочитать любое текстовое сообщение из недавно полученных или отправленных (или спеть, рассказать стихотворенье)", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(MESSAGE_DELAY_4) - + # Отправляем информацию об эмодзи await message.answer( text=f"Любые войсы будут помечены эмоджи. Твой эмоджи - {user_emoji}Таким эмоджи будут помечены твои сообщения для других Но другие люди не узнают кто за каким эмоджи скрывается:)", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(MESSAGE_DELAY_4) - + # Отправляем информацию о помощи await message.answer( text="Так же можешь ознакомиться с инструкцией к боту по команде /help", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) await asyncio.sleep(MESSAGE_DELAY_4) - + # Отправляем финальное сообщение await message.answer( text="Ну всё, достаточно инструкций. записывайся! Микрофон твой - 🎤", - parse_mode='html', + parse_mode="html", reply_markup=markup, - disable_web_page_preview=not self.settings['Telegram']['preview_link'] + disable_web_page_preview=not self.settings["Telegram"]["preview_link"], ) - + except Exception as e: logger.error(f"Ошибка при отправке приветственных сообщений: {e}") - raise VoiceMessageError(f"Не удалось отправить приветственные сообщения: {e}") - + raise VoiceMessageError( + f"Не удалось отправить приветственные сообщения: {e}" + ) + @track_time("get_random_audio", "voice_bot_service") @track_errors("voice_bot_service", "get_random_audio") async def get_random_audio(self, user_id: int) -> Optional[Tuple[str, str, str]]: @@ -159,25 +177,25 @@ class VoiceBotService: try: check_audio = await self.bot_db.check_listen_audio(user_id=user_id) list_audio = list(check_audio) - + if not list_audio: return None - + # Получаем случайное аудио number_element = random.randint(0, len(list_audio) - 1) audio_for_user = check_audio[number_element] - + # Получаем информацию об авторе user_id_author = await self.bot_db.get_user_id_by_file_name(audio_for_user) date_added = await self.bot_db.get_date_by_file_name(audio_for_user) user_emoji = await self.bot_db.get_user_emoji(user_id_author) - + return audio_for_user, date_added, user_emoji - + except Exception as e: logger.error(f"Ошибка при получении случайного аудио: {e}") raise AudioProcessingError(f"Не удалось получить случайное аудио: {e}") - + @track_time("mark_audio_as_listened", "voice_bot_service") @track_errors("voice_bot_service", "mark_audio_as_listened") async def mark_audio_as_listened(self, file_name: str, user_id: int) -> None: @@ -187,7 +205,7 @@ class VoiceBotService: except Exception as e: logger.error(f"Ошибка при пометке аудио как прослушанного: {e}") raise DatabaseError(f"Не удалось пометить аудио как прослушанное: {e}") - + @track_time("clear_user_listenings", "voice_bot_service") @track_errors("voice_bot_service", "clear_user_listenings") @db_query_time("clear_user_listenings", "audio_moderate", "delete") @@ -198,7 +216,7 @@ class VoiceBotService: except Exception as e: logger.error(f"Ошибка при очистке прослушиваний: {e}") raise DatabaseError(f"Не удалось очистить прослушивания: {e}") - + @track_time("get_remaining_audio_count", "voice_bot_service") @track_errors("voice_bot_service", "get_remaining_audio_count") async def get_remaining_audio_count(self, user_id: int) -> int: @@ -209,25 +227,24 @@ class VoiceBotService: except Exception as e: logger.error(f"Ошибка при получении количества аудио: {e}") raise DatabaseError(f"Не удалось получить количество аудио: {e}") - + @track_time("get_main_keyboard", "voice_bot_service") @track_errors("voice_bot_service", "get_main_keyboard") def _get_main_keyboard(self): """Получить основную клавиатуру""" from helper_bot.keyboards.keyboards import get_main_keyboard + return get_main_keyboard() - + @track_time("send_error_to_logs", "voice_bot_service") @track_errors("voice_bot_service", "send_error_to_logs") async def _send_error_to_logs(self, message: str) -> None: """Отправить ошибку в логи""" try: from helper_bot.utils.helper_func import send_voice_message + await send_voice_message( - self.settings['Telegram']['important_logs'], - None, - None, - None + self.settings["Telegram"]["important_logs"], None, None, None ) except Exception as e: logger.error(f"Не удалось отправить ошибку в логи: {e}") @@ -235,45 +252,49 @@ class VoiceBotService: class AudioFileService: """Сервис для работы с аудио файлами""" - + def __init__(self, bot_db): self.bot_db = bot_db - + @track_time("generate_file_name", "audio_file_service") @track_errors("audio_file_service", "generate_file_name") async def generate_file_name(self, user_id: int) -> str: """Сгенерировать имя файла для аудио""" try: # Проверяем есть ли запись о файле в базе данных - user_audio_count = await self.bot_db.get_user_audio_records_count(user_id=user_id) - + user_audio_count = await self.bot_db.get_user_audio_records_count( + user_id=user_id + ) + if user_audio_count == 0: # Если нет, то генерируем имя файла - file_name = f'message_from_{user_id}_number_1' + file_name = f"message_from_{user_id}_number_1" else: # Иначе берем последнюю запись из БД, добавляем к ней 1 file_name = await self.bot_db.get_path_for_audio_record(user_id=user_id) if file_name: # Извлекаем номер из имени файла и увеличиваем на 1 try: - current_number = int(file_name.split('_')[-1]) + current_number = int(file_name.split("_")[-1]) new_number = current_number + 1 except (ValueError, IndexError): new_number = user_audio_count + 1 else: new_number = user_audio_count + 1 - - file_name = f'message_from_{user_id}_number_{new_number}' - + + file_name = f"message_from_{user_id}_number_{new_number}" + return file_name - + except Exception as e: logger.error(f"Ошибка при генерации имени файла: {e}") raise FileOperationError(f"Не удалось сгенерировать имя файла: {e}") - + @track_time("save_audio_file", "audio_file_service") @track_errors("audio_file_service", "save_audio_file") - async def save_audio_file(self, file_name: str, user_id: int, date_added: datetime, file_id: str) -> None: + async def save_audio_file( + self, file_name: str, user_id: int, date_added: datetime, file_id: str + ) -> None: """Сохранить информацию об аудио файле в базу данных""" try: # Проверяем существование файла перед сохранением в БД @@ -281,16 +302,20 @@ class AudioFileService: error_msg = f"Файл {file_name} не существует или поврежден, отменяем сохранение в БД" logger.error(error_msg) raise FileOperationError(error_msg) - + await self.bot_db.add_audio_record_simple(file_name, user_id, date_added) - logger.info(f"Информация об аудио файле успешно сохранена в БД: {file_name}") + logger.info( + f"Информация об аудио файле успешно сохранена в БД: {file_name}" + ) except Exception as e: logger.error(f"Ошибка при сохранении аудио файла в БД: {e}") raise DatabaseError(f"Не удалось сохранить аудио файл в БД: {e}") - + @track_time("save_audio_file_with_transaction", "audio_file_service") @track_errors("audio_file_service", "save_audio_file_with_transaction") - async def save_audio_file_with_transaction(self, file_name: str, user_id: int, date_added: datetime, file_id: str) -> None: + async def save_audio_file_with_transaction( + self, file_name: str, user_id: int, date_added: datetime, file_id: str + ) -> None: """Сохранить информацию об аудио файле в базу данных с транзакцией""" try: # Проверяем существование файла перед сохранением в БД @@ -298,68 +323,80 @@ class AudioFileService: error_msg = f"Файл {file_name} не существует или поврежден, отменяем сохранение в БД" logger.error(error_msg) raise FileOperationError(error_msg) - + # Используем транзакцию для атомарности операции await self.bot_db.add_audio_record_simple(file_name, user_id, date_added) - logger.info(f"Информация об аудио файле успешно сохранена в БД с транзакцией: {file_name}") + logger.info( + f"Информация об аудио файле успешно сохранена в БД с транзакцией: {file_name}" + ) except Exception as e: logger.error(f"Ошибка при сохранении аудио файла в БД с транзакцией: {e}") - raise DatabaseError(f"Не удалось сохранить аудио файл в БД с транзакцией: {e}") - + raise DatabaseError( + f"Не удалось сохранить аудио файл в БД с транзакцией: {e}" + ) + @track_time("download_and_save_audio", "audio_file_service") @track_errors("audio_file_service", "download_and_save_audio") - async def download_and_save_audio(self, bot, message, file_name: str, max_retries: int = 3) -> None: + async def download_and_save_audio( + self, bot, message, file_name: str, max_retries: int = 3 + ) -> None: """Скачать и сохранить аудио файл с retry механизмом""" last_exception = None - + for attempt in range(max_retries): try: - logger.info(f"Попытка {attempt + 1}/{max_retries} скачивания и сохранения аудио: {file_name}") - + logger.info( + f"Попытка {attempt + 1}/{max_retries} скачивания и сохранения аудио: {file_name}" + ) + # Проверяем наличие голосового сообщения if not message or not message.voice: error_msg = "Сообщение или голосовое сообщение не найдено" logger.error(error_msg) raise FileOperationError(error_msg) - + file_id = message.voice.file_id logger.info(f"Получен file_id: {file_id}") - + # Получаем информацию о файле try: file_info = await bot.get_file(file_id=file_id) logger.info(f"Получена информация о файле: {file_info.file_path}") except Exception as e: logger.error(f"Ошибка при получении информации о файле: {e}") - raise FileOperationError(f"Не удалось получить информацию о файле: {e}") - + raise FileOperationError( + f"Не удалось получить информацию о файле: {e}" + ) + # Скачиваем файл try: - downloaded_file = await bot.download_file(file_path=file_info.file_path) + downloaded_file = await bot.download_file( + file_path=file_info.file_path + ) except Exception as e: logger.error(f"Ошибка при скачивании файла: {e}") raise FileOperationError(f"Не удалось скачать файл: {e}") - + # Проверяем что файл успешно скачан if not downloaded_file: error_msg = "Не удалось скачать файл - получен пустой объект" logger.error(error_msg) raise FileOperationError(error_msg) - + # Получаем размер файла без изменения позиции current_pos = downloaded_file.tell() downloaded_file.seek(0, 2) # Переходим в конец файла file_size = downloaded_file.tell() downloaded_file.seek(current_pos) # Возвращаемся в исходную позицию - + logger.info(f"Файл скачан, размер: {file_size} bytes") - + # Проверяем минимальный размер файла if file_size < 100: # Минимальный размер для аудио файла error_msg = f"Файл слишком маленький: {file_size} bytes" logger.error(error_msg) raise FileOperationError(error_msg) - + # Создаем директорию если она не существует try: os.makedirs(VOICE_USERS_DIR, exist_ok=True) @@ -367,27 +404,27 @@ class AudioFileService: except Exception as e: logger.error(f"Ошибка при создании директории: {e}") raise FileOperationError(f"Не удалось создать директорию: {e}") - - file_path = f'{VOICE_USERS_DIR}/{file_name}.ogg' + + file_path = f"{VOICE_USERS_DIR}/{file_name}.ogg" logger.info(f"Сохраняем файл по пути: {file_path}") - + # Сбрасываем позицию в файле перед сохранением downloaded_file.seek(0) - + # Сохраняем файл try: - with open(file_path, 'wb') as new_file: + with open(file_path, "wb") as new_file: new_file.write(downloaded_file.read()) except Exception as e: logger.error(f"Ошибка при записи файла на диск: {e}") raise FileOperationError(f"Не удалось записать файл на диск: {e}") - + # Проверяем что файл действительно создался и имеет правильный размер if not os.path.exists(file_path): error_msg = f"Файл не был создан: {file_path}" logger.error(error_msg) raise FileOperationError(error_msg) - + saved_file_size = os.path.getsize(file_path) if saved_file_size != file_size: error_msg = f"Размер сохраненного файла не совпадает: ожидалось {file_size}, получено {saved_file_size}" @@ -398,48 +435,62 @@ class AudioFileService: except: pass raise FileOperationError(error_msg) - - logger.info(f"Файл успешно сохранен: {file_path}, размер: {saved_file_size} bytes") + + logger.info( + f"Файл успешно сохранен: {file_path}, размер: {saved_file_size} bytes" + ) return # Успешное завершение - + except Exception as e: last_exception = e logger.error(f"Попытка {attempt + 1}/{max_retries} неудачна: {e}") - + if attempt < max_retries - 1: - wait_time = (attempt + 1) * 2 # Экспоненциальная задержка: 2, 4, 6 секунд - logger.info(f"Ожидание {wait_time} секунд перед следующей попыткой...") + wait_time = ( + attempt + 1 + ) * 2 # Экспоненциальная задержка: 2, 4, 6 секунд + logger.info( + f"Ожидание {wait_time} секунд перед следующей попыткой..." + ) await asyncio.sleep(wait_time) else: logger.error(f"Все {max_retries} попыток скачивания неудачны") - logger.error(f"Traceback последней ошибки: {traceback.format_exc()}") - + logger.error( + f"Traceback последней ошибки: {traceback.format_exc()}" + ) + # Если все попытки неудачны - raise FileOperationError(f"Не удалось скачать и сохранить аудио после {max_retries} попыток. Последняя ошибка: {last_exception}") - + raise FileOperationError( + f"Не удалось скачать и сохранить аудио после {max_retries} попыток. Последняя ошибка: {last_exception}" + ) + @track_time("verify_file_exists", "audio_file_service") @track_errors("audio_file_service", "verify_file_exists") async def verify_file_exists(self, file_name: str) -> bool: """Проверить существование и валидность файла""" try: - file_path = f'{VOICE_USERS_DIR}/{file_name}.ogg' - + file_path = f"{VOICE_USERS_DIR}/{file_name}.ogg" + if not os.path.exists(file_path): logger.warning(f"Файл не существует: {file_path}") return False - + file_size = os.path.getsize(file_path) if file_size == 0: logger.warning(f"Файл пустой: {file_path}") return False - + if file_size < 100: # Минимальный размер для аудио файла - logger.warning(f"Файл слишком маленький: {file_path}, размер: {file_size} bytes") + logger.warning( + f"Файл слишком маленький: {file_path}, размер: {file_size} bytes" + ) return False - - logger.info(f"Файл проверен и валиден: {file_path}, размер: {file_size} bytes") + + logger.info( + f"Файл проверен и валиден: {file_path}, размер: {file_size} bytes" + ) return True - + except Exception as e: logger.error(f"Ошибка при проверке файла {file_name}: {e}") return False diff --git a/helper_bot/handlers/voice/utils.py b/helper_bot/handlers/voice/utils.py index ea64bfe..2ec5d8d 100644 --- a/helper_bot/handlers/voice/utils.py +++ b/helper_bot/handlers/voice/utils.py @@ -18,31 +18,37 @@ def format_time_ago(date_from_db: str) -> Optional[str]: last_voice_time_timestamp = time.mktime(parse_date.timetuple()) time_now_timestamp = time.time() date_difference = time_now_timestamp - last_voice_time_timestamp - + # Считаем минуты, часы, дни much_minutes_ago = round(date_difference / 60, 0) much_hour_ago = round(date_difference / 3600, 0) much_days_ago = int(round(much_hour_ago / 24, 0)) - - message_with_date = '' + + message_with_date = "" if much_minutes_ago <= 60: word_minute = plural_time(1, much_minutes_ago) # Экранируем потенциально проблемные символы word_minute_escaped = html.escape(word_minute) - message_with_date = f'Последнее сообщение было записано {word_minute_escaped} назад' + message_with_date = ( + f"Последнее сообщение было записано {word_minute_escaped} назад" + ) elif much_minutes_ago > 60 and much_hour_ago <= 24: word_hour = plural_time(2, much_hour_ago) # Экранируем потенциально проблемные символы word_hour_escaped = html.escape(word_hour) - message_with_date = f'Последнее сообщение было записано {word_hour_escaped} назад' + message_with_date = ( + f"Последнее сообщение было записано {word_hour_escaped} назад" + ) elif much_hour_ago > 24: word_day = plural_time(3, much_days_ago) # Экранируем потенциально проблемные символы word_day_escaped = html.escape(word_day) - message_with_date = f'Последнее сообщение было записано {word_day_escaped} назад' - + message_with_date = ( + f"Последнее сообщение было записано {word_day_escaped} назад" + ) + return message_with_date - + except Exception as e: logger.error(f"Ошибка при форматировании времени: {e}") return None @@ -52,11 +58,11 @@ def plural_time(type: int, n: float) -> str: """Форматировать множественное число для времени""" word = [] if type == 1: - word = ['минуту', 'минуты', 'минут'] + word = ["минуту", "минуты", "минут"] elif type == 2: - word = ['час', 'часа', 'часов'] + word = ["час", "часа", "часов"] elif type == 3: - word = ['день', 'дня', 'дней'] + word = ["день", "дня", "дней"] else: return str(int(n)) @@ -66,9 +72,10 @@ def plural_time(type: int, n: float) -> str: p = 1 else: p = 2 - + new_number = int(n) - return str(new_number) + ' ' + word[p] + return str(new_number) + " " + word[p] + @track_time("get_last_message_text", "voice_utils") @track_errors("voice_utils", "get_last_message_text") @@ -89,7 +96,8 @@ async def get_last_message_text(bot_db) -> Optional[str]: async def validate_voice_message(message) -> bool: """Проверить валидность голосового сообщения""" - return message.content_type == 'voice' + return message.content_type == "voice" + @track_time("get_user_emoji_safe", "voice_utils") @track_errors("voice_utils", "get_user_emoji_safe") @@ -98,7 +106,11 @@ async def get_user_emoji_safe(bot_db, user_id: int) -> str: """Безопасно получить эмодзи пользователя""" try: user_emoji = await bot_db.get_user_emoji(user_id) - return user_emoji if user_emoji and user_emoji != "Смайл еще не определен" else "😊" + return ( + user_emoji + if user_emoji and user_emoji != "Смайл еще не определен" + else "😊" + ) except Exception as e: logger.error(f"Ошибка при получении эмодзи пользователя {user_id}: {e}") return "😊" diff --git a/helper_bot/handlers/voice/voice_handler.py b/helper_bot/handlers/voice/voice_handler.py index f3ed377..3ea223e 100644 --- a/helper_bot/handlers/voice/voice_handler.py +++ b/helper_bot/handlers/voice/voice_handler.py @@ -6,31 +6,44 @@ from aiogram import F, Router, types from aiogram.filters import Command, MagicData, StateFilter from aiogram.fsm.context import FSMContext from aiogram.types import FSInputFile + from helper_bot.filters.main import ChatTypeFilter from helper_bot.handlers.private.constants import BUTTON_TEXTS, FSM_STATES from helper_bot.handlers.voice.constants import * from helper_bot.handlers.voice.services import VoiceBotService -from helper_bot.handlers.voice.utils import (get_last_message_text, - get_user_emoji_safe, - validate_voice_message) +from helper_bot.handlers.voice.utils import ( + get_last_message_text, + get_user_emoji_safe, + validate_voice_message, +) from helper_bot.keyboards import get_reply_keyboard -from helper_bot.keyboards.keyboards import (get_main_keyboard, - get_reply_keyboard_for_voice) +from helper_bot.keyboards.keyboards import ( + get_main_keyboard, + get_reply_keyboard_for_voice, +) from helper_bot.middlewares.blacklist_middleware import BlacklistMiddleware -from helper_bot.middlewares.dependencies_middleware import \ - DependenciesMiddleware +from helper_bot.middlewares.dependencies_middleware import DependenciesMiddleware from helper_bot.utils import messages -from helper_bot.utils.helper_func import (check_user_emoji, get_first_name, - send_voice_message, update_user_info) +from helper_bot.utils.helper_func import ( + check_user_emoji, + get_first_name, + send_voice_message, + update_user_info, +) + # Local imports - metrics -from helper_bot.utils.metrics import (db_query_time, track_errors, - track_file_operations, track_time) +from helper_bot.utils.metrics import ( + db_query_time, + track_errors, + track_file_operations, + track_time, +) from logs.custom_logger import logger class VoiceHandlers: def __init__(self, db, settings): - self.db = db.get_db() if hasattr(db, 'get_db') else db + self.db = db.get_db() if hasattr(db, "get_db") else db self.settings = settings self.router = Router() self._setup_handlers() @@ -44,102 +57,114 @@ class VoiceHandlers: self.router.message.register( self.cancel_handler, ChatTypeFilter(chat_type=["private"]), - F.text == "Отменить" + F.text == "Отменить", ) - + # Обработчик кнопки "Голосовой бот" self.router.message.register( self.voice_bot_button_handler, ChatTypeFilter(chat_type=["private"]), - F.text == BUTTON_TEXTS["VOICE_BOT"] + F.text == BUTTON_TEXTS["VOICE_BOT"], ) - + # Команды self.router.message.register( self.restart_function, ChatTypeFilter(chat_type=["private"]), - Command(CMD_RESTART) + Command(CMD_RESTART), ) - + self.router.message.register( self.handle_emoji_message, ChatTypeFilter(chat_type=["private"]), - Command(CMD_EMOJI) + Command(CMD_EMOJI), ) - + self.router.message.register( - self.help_function, - ChatTypeFilter(chat_type=["private"]), - Command(CMD_HELP) + self.help_function, ChatTypeFilter(chat_type=["private"]), Command(CMD_HELP) ) - + self.router.message.register( - self.start, - ChatTypeFilter(chat_type=["private"]), - Command(CMD_START) + self.start, ChatTypeFilter(chat_type=["private"]), Command(CMD_START) ) - + # Дополнительные команды self.router.message.register( self.refresh_listen_function, ChatTypeFilter(chat_type=["private"]), - Command(CMD_REFRESH) + Command(CMD_REFRESH), ) - + # Обработчики состояний и кнопок self.router.message.register( self.standup_write, StateFilter(STATE_START), ChatTypeFilter(chat_type=["private"]), - F.text == BTN_SPEAK + F.text == BTN_SPEAK, ) - + self.router.message.register( self.suggest_voice, StateFilter(STATE_STANDUP_WRITE), ChatTypeFilter(chat_type=["private"]), ) - + self.router.message.register( self.standup_listen_audio, StateFilter(STATE_START), ChatTypeFilter(chat_type=["private"]), - F.text == BTN_LISTEN + F.text == BTN_LISTEN, ) - + # Новые обработчики кнопок self.router.message.register( self.refresh_listen_function, ChatTypeFilter(chat_type=["private"]), - F.text == "🔄Сбросить прослушивания" + F.text == "🔄Сбросить прослушивания", ) - + self.router.message.register( self.handle_emoji_message, ChatTypeFilter(chat_type=["private"]), - F.text == "😊Узнать эмодзи" + F.text == "😊Узнать эмодзи", ) @track_time("voice_bot_button_handler", "voice_handlers") @track_errors("voice_handlers", "voice_bot_button_handler") - async def voice_bot_button_handler(self, message: types.Message, state: FSMContext, bot_db: MagicData("bot_db"), settings: MagicData("settings")): + async def voice_bot_button_handler( + self, + message: types.Message, + state: FSMContext, + bot_db: MagicData("bot_db"), + settings: MagicData("settings"), + ): """Обработчик кнопки 'Голосовой бот' из основной клавиатуры""" - logger.info(f"Пользователь {message.from_user.id} ({message.from_user.full_name}) нажал кнопку 'Голосовой бот'") + logger.info( + f"Пользователь {message.from_user.id} ({message.from_user.full_name}) нажал кнопку 'Голосовой бот'" + ) try: # Проверяем, получал ли пользователь приветственное сообщение - welcome_received = await bot_db.check_voice_bot_welcome_received(message.from_user.id) - logger.info(f"Пользователь {message.from_user.id}: welcome_received = {welcome_received}") - + welcome_received = await bot_db.check_voice_bot_welcome_received( + message.from_user.id + ) + logger.info( + f"Пользователь {message.from_user.id}: welcome_received = {welcome_received}" + ) + if welcome_received: # Если уже получал приветствие, вызываем restart_function - logger.info(f"Пользователь {message.from_user.id}: вызываем restart_function") + logger.info( + f"Пользователь {message.from_user.id}: вызываем restart_function" + ) await self.restart_function(message, state, bot_db, settings) else: # Если не получал, вызываем start logger.info(f"Пользователь {message.from_user.id}: вызываем start") await self.start(message, state, bot_db, settings) except Exception as e: - logger.error(f"Ошибка при проверке приветственного сообщения для пользователя {message.from_user.id}: {e}") + logger.error( + f"Ошибка при проверке приветственного сообщения для пользователя {message.from_user.id}: {e}" + ) # В случае ошибки вызываем start await self.start(message, state, bot_db, settings) @@ -147,49 +172,49 @@ class VoiceHandlers: @track_errors("voice_handlers", "restart_function") async def restart_function( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") + settings: MagicData("settings"), ): - logger.info(f"Пользователь {message.from_user.id}: вызывается функция restart_function") - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + logger.info( + f"Пользователь {message.from_user.id}: вызывается функция restart_function" + ) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) await update_user_info(VOICE_BOT_NAME, message) await check_user_emoji(message) markup = get_main_keyboard() - await message.answer(text='🎤 Записывайся или слушай!', reply_markup=markup) + await message.answer(text="🎤 Записывайся или слушай!", reply_markup=markup) await state.set_state(STATE_START) @track_time("handle_emoji_message", "voice_handlers") @track_errors("voice_handlers", "handle_emoji_message") async def handle_emoji_message( - self, - message: types.Message, - state: FSMContext, - settings: MagicData("settings") + self, message: types.Message, state: FSMContext, settings: MagicData("settings") ): - logger.info(f"Пользователь {message.from_user.id} ({message.from_user.full_name}) запросил информацию об эмодзи") - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + logger.info( + f"Пользователь {message.from_user.id} ({message.from_user.full_name}) запросил информацию об эмодзи" + ) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) user_emoji = await check_user_emoji(message) await state.set_state(STATE_START) if user_emoji is not None: - await message.answer(f'Твоя эмодзя - {user_emoji}', parse_mode='HTML') + await message.answer(f"Твоя эмодзя - {user_emoji}", parse_mode="HTML") @track_time("help_function", "voice_handlers") @track_errors("voice_handlers", "help_function") async def help_function( - self, - message: types.Message, - state: FSMContext, - settings: MagicData("settings") + self, message: types.Message, state: FSMContext, settings: MagicData("settings") ): - logger.info(f"Пользователь {message.from_user.id} ({message.from_user.full_name}) вызвал функцию help_function") - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + logger.info( + f"Пользователь {message.from_user.id} ({message.from_user.full_name}) вызвал функцию help_function" + ) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) await update_user_info(VOICE_BOT_NAME, message) - help_message = messages.get_message(get_first_name(message), 'HELP_MESSAGE') + help_message = messages.get_message(get_first_name(message), "HELP_MESSAGE") await message.answer( text=help_message, - disable_web_page_preview=not settings['Telegram']['preview_link'] + disable_web_page_preview=not settings["Telegram"]["preview_link"], ) await state.set_state(STATE_START) @@ -198,43 +223,53 @@ class VoiceHandlers: @db_query_time("mark_voice_bot_welcome_received", "audio_moderate", "update") async def start( self, - message: types.Message, - state: FSMContext, - bot_db: MagicData("bot_db"), - settings: MagicData("settings") + message: types.Message, + state: FSMContext, + bot_db: MagicData("bot_db"), + settings: MagicData("settings"), ): - logger.info(f"Пользователь {message.from_user.id} ({message.from_user.full_name}): вызывается функция start") + logger.info( + f"Пользователь {message.from_user.id} ({message.from_user.full_name}): вызывается функция start" + ) await state.set_state(STATE_START) - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) await update_user_info(VOICE_BOT_NAME, message) user_emoji = await get_user_emoji_safe(bot_db, message.from_user.id) - + # Создаем сервис и отправляем приветственные сообщения voice_service = VoiceBotService(bot_db, settings) await voice_service.send_welcome_messages(message, user_emoji) - logger.info(f"Приветственные сообщения отправлены пользователю {message.from_user.id}") - + logger.info( + f"Приветственные сообщения отправлены пользователю {message.from_user.id}" + ) + # Отмечаем, что пользователь получил приветственное сообщение try: await bot_db.mark_voice_bot_welcome_received(message.from_user.id) - logger.info(f"Пользователь {message.from_user.id}: отмечен как получивший приветствие") + logger.info( + f"Пользователь {message.from_user.id}: отмечен как получивший приветствие" + ) except Exception as e: - logger.error(f"Ошибка при отметке получения приветствия для пользователя {message.from_user.id}: {e}") + logger.error( + f"Ошибка при отметке получения приветствия для пользователя {message.from_user.id}: {e}" + ) @track_time("cancel_handler", "voice_handlers") @track_errors("voice_handlers", "cancel_handler") async def cancel_handler( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") + settings: MagicData("settings"), ): """Обработчик кнопки 'Отменить' - возвращает в начальное состояние""" - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) await update_user_info(VOICE_BOT_NAME, message) markup = await get_reply_keyboard(self.db, message.from_user.id) - await message.answer(text='Добро пожаловать в меню!', reply_markup=markup, parse_mode='HTML') + await message.answer( + text="Добро пожаловать в меню!", reply_markup=markup, parse_mode="HTML" + ) await state.set_state(FSM_STATES["START"]) logger.info(f"Пользователь {message.from_user.id} возвращен в главное меню") @@ -242,208 +277,253 @@ class VoiceHandlers: @track_errors("voice_handlers", "refresh_listen_function") async def refresh_listen_function( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") - ): - logger.info(f"Пользователь {message.from_user.id} ({message.from_user.full_name}) вызвал функцию refresh_listen_function") - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + settings: MagicData("settings"), + ): + logger.info( + f"Пользователь {message.from_user.id} ({message.from_user.full_name}) вызвал функцию refresh_listen_function" + ) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) await update_user_info(VOICE_BOT_NAME, message) markup = get_main_keyboard() - + # Очищаем прослушивания через сервис voice_service = VoiceBotService(bot_db, settings) await voice_service.clear_user_listenings(message.from_user.id) - listenings_cleared_message = messages.get_message(get_first_name(message), 'LISTENINGS_CLEARED_MESSAGE') + listenings_cleared_message = messages.get_message( + get_first_name(message), "LISTENINGS_CLEARED_MESSAGE" + ) await message.answer( text=listenings_cleared_message, - disable_web_page_preview=not settings['Telegram']['preview_link'], - reply_markup=markup + disable_web_page_preview=not settings["Telegram"]["preview_link"], + reply_markup=markup, ) await state.set_state(STATE_START) - @track_time("standup_write", "voice_handlers") @track_errors("voice_handlers", "standup_write") async def standup_write( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") - ): - logger.info(f"Пользователь {message.from_user.id} ({message.from_user.full_name}) вызвал функцию standup_write") - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + settings: MagicData("settings"), + ): + logger.info( + f"Пользователь {message.from_user.id} ({message.from_user.full_name}) вызвал функцию standup_write" + ) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) markup = types.ReplyKeyboardRemove() - record_voice_message = messages.get_message(get_first_name(message), 'RECORD_VOICE_MESSAGE') + record_voice_message = messages.get_message( + get_first_name(message), "RECORD_VOICE_MESSAGE" + ) await message.answer(text=record_voice_message, reply_markup=markup) - + try: message_with_date = await get_last_message_text(bot_db) if message_with_date: await message.answer(text=message_with_date, parse_mode="html") except Exception as e: - logger.error(f'Не удалось получить дату последнего сообщения для пользователя {message.from_user.id}: {e}') - - await state.set_state(STATE_STANDUP_WRITE) + logger.error( + f"Не удалось получить дату последнего сообщения для пользователя {message.from_user.id}: {e}" + ) + await state.set_state(STATE_STANDUP_WRITE) @track_time("suggest_voice", "voice_handlers") @track_errors("voice_handlers", "suggest_voice") async def suggest_voice( self, - message: types.Message, - state: FSMContext, + message: types.Message, + state: FSMContext, bot_db: MagicData("bot_db"), - settings: MagicData("settings") - ): + settings: MagicData("settings"), + ): logger.info( f"Вызов функции suggest_voice. Пользователь: {message.from_user.id} Имя автора сообщения: {message.from_user.full_name}" ) - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) markup = get_main_keyboard() - + if await validate_voice_message(message): markup_for_voice = get_reply_keyboard_for_voice() - + # Отправляем аудио в приватный канал sent_message = await send_voice_message( - settings['Telegram']['group_for_posts'], + settings["Telegram"]["group_for_posts"], message, - message.voice.file_id, - markup_for_voice + message.voice.file_id, + markup_for_voice, + ) + logger.info( + f"Голосовое сообщение пользователя {message.from_user.id} отправлено в группу постов (message_id: {sent_message.message_id})" ) - logger.info(f"Голосовое сообщение пользователя {message.from_user.id} отправлено в группу постов (message_id: {sent_message.message_id})") # Сохраняем в базу инфо о посте - await bot_db.set_user_id_and_message_id_for_voice_bot(sent_message.message_id, message.from_user.id) + await bot_db.set_user_id_and_message_id_for_voice_bot( + sent_message.message_id, message.from_user.id + ) # Отправляем юзеру ответ и возвращаем его в меню - voice_saved_message = messages.get_message(get_first_name(message), 'VOICE_SAVED_MESSAGE') + voice_saved_message = messages.get_message( + get_first_name(message), "VOICE_SAVED_MESSAGE" + ) await message.answer(text=voice_saved_message, reply_markup=markup) await state.set_state(STATE_START) else: - logger.warning(f"Голосовое сообщение пользователя {message.from_user.id} не прошло валидацию") - unknown_content_message = messages.get_message(get_first_name(message), 'UNKNOWN_CONTENT_MESSAGE') - await message.forward(chat_id=settings['Telegram']['group_for_logs']) + logger.warning( + f"Голосовое сообщение пользователя {message.from_user.id} не прошло валидацию" + ) + unknown_content_message = messages.get_message( + get_first_name(message), "UNKNOWN_CONTENT_MESSAGE" + ) + await message.forward(chat_id=settings["Telegram"]["group_for_logs"]) await message.answer(text=unknown_content_message, reply_markup=markup) await state.set_state(STATE_STANDUP_WRITE) - @track_time("standup_listen_audio", "voice_handlers") @track_errors("voice_handlers", "standup_listen_audio") @track_file_operations("voice") @db_query_time("standup_listen_audio", "audio_moderate", "mixed") async def standup_listen_audio( self, - message: types.Message, + message: types.Message, bot_db: MagicData("bot_db"), - settings: MagicData("settings") - ): - logger.info(f"Пользователь {message.from_user.id} ({message.from_user.full_name}) запросил прослушивание аудио") + settings: MagicData("settings"), + ): + logger.info( + f"Пользователь {message.from_user.id} ({message.from_user.full_name}) запросил прослушивание аудио" + ) markup = get_main_keyboard() - + # Создаем сервис для работы с аудио voice_service = VoiceBotService(bot_db, settings) - + try: - #TODO: удалить логику из хендлера + # TODO: удалить логику из хендлера # Получаем случайное аудио audio_data = await voice_service.get_random_audio(message.from_user.id) - + if not audio_data: - logger.warning(f"Для пользователя {message.from_user.id} не найдено доступных аудио для прослушивания") - no_audio_message = messages.get_message(get_first_name(message), 'NO_AUDIO_MESSAGE') + logger.warning( + f"Для пользователя {message.from_user.id} не найдено доступных аудио для прослушивания" + ) + no_audio_message = messages.get_message( + get_first_name(message), "NO_AUDIO_MESSAGE" + ) await message.answer(text=no_audio_message, reply_markup=markup) try: message_with_date = await get_last_message_text(bot_db) if message_with_date: await message.answer(text=message_with_date, parse_mode="html") except Exception as e: - logger.error(f'Не удалось получить последнюю дату для пользователя {message.from_user.id}: {e}') + logger.error( + f"Не удалось получить последнюю дату для пользователя {message.from_user.id}: {e}" + ) return - + audio_for_user, date_added, user_emoji = audio_data - + # Получаем путь к файлу - path = Path(f'{VOICE_USERS_DIR}/{audio_for_user}.ogg') - + path = Path(f"{VOICE_USERS_DIR}/{audio_for_user}.ogg") + # Проверяем существование файла if not path.exists(): - logger.error(f"Файл не найден: {path} для пользователя {message.from_user.id}") + logger.error( + f"Файл не найден: {path} для пользователя {message.from_user.id}" + ) # Дополнительная диагностика - logger.error(f"Директория {VOICE_USERS_DIR} существует: {Path(VOICE_USERS_DIR).exists()}") + logger.error( + f"Директория {VOICE_USERS_DIR} существует: {Path(VOICE_USERS_DIR).exists()}" + ) if Path(VOICE_USERS_DIR).exists(): files_in_dir = list(Path(VOICE_USERS_DIR).glob("*.ogg")) - logger.error(f"Файлы в директории: {[f.name for f in files_in_dir]}") - + logger.error( + f"Файлы в директории: {[f.name for f in files_in_dir]}" + ) + await message.answer( text="Файл аудио не найден. Обратитесь к администратору.", - reply_markup=markup + reply_markup=markup, ) return - + # Проверяем размер файла if path.stat().st_size == 0: - logger.error(f"Файл пустой: {path} для пользователя {message.from_user.id}") + logger.error( + f"Файл пустой: {path} для пользователя {message.from_user.id}" + ) await message.answer( text="Файл аудио поврежден. Обратитесь к администратору.", - reply_markup=markup + reply_markup=markup, ) return - + voice = FSInputFile(path) # Формируем подпись if user_emoji: - caption = f'{user_emoji}\nДата записи: {date_added}' + caption = f"{user_emoji}\nДата записи: {date_added}" else: - caption = f'Дата записи: {date_added}' - - logger.info(f"Подготовлено голосовое сообщение для пользователя {message.from_user.id}: {caption}") - + caption = f"Дата записи: {date_added}" + + logger.info( + f"Подготовлено голосовое сообщение для пользователя {message.from_user.id}: {caption}" + ) + try: from helper_bot.utils.rate_limiter import send_with_rate_limit - + async def _send_voice(): return await message.bot.send_voice( - chat_id=message.chat.id, - voice=voice, - caption=caption, - reply_markup=markup + chat_id=message.chat.id, + voice=voice, + caption=caption, + reply_markup=markup, ) - + await send_with_rate_limit(_send_voice, message.chat.id) - + # Маркируем сообщение как прослушанное только после успешной отправки - await voice_service.mark_audio_as_listened(audio_for_user, message.from_user.id) - - # Получаем количество оставшихся аудио только после успешной отправки - remaining_count = await voice_service.get_remaining_audio_count(message.from_user.id) - await message.answer( - text=f'Осталось непрослушанных: {remaining_count}', - reply_markup=markup + await voice_service.mark_audio_as_listened( + audio_for_user, message.from_user.id ) - + + # Получаем количество оставшихся аудио только после успешной отправки + remaining_count = await voice_service.get_remaining_audio_count( + message.from_user.id + ) + await message.answer( + text=f"Осталось непрослушанных: {remaining_count}", + reply_markup=markup, + ) + except Exception as voice_error: if "VOICE_MESSAGES_FORBIDDEN" in str(voice_error): # Если голосовые сообщения запрещены, отправляем информативное сообщение - logger.warning(f"Пользователь {message.from_user.id} запретил получение голосовых сообщений") - + logger.warning( + f"Пользователь {message.from_user.id} запретил получение голосовых сообщений" + ) + privacy_message = "🔇 К сожалению, у тебя закрыт доступ к получению голосовых сообщений.\n\nДля продолжения взаимодействия с ботом необходимо дать возможность мне присылать войсы в настройках приватности Telegram.\n\n💡 Как это сделать:\n1. Открой настройки Telegram\n2. Перейди в 'Конфиденциальность и безопасность'\n3. Выбери 'Голосовые сообщения'\n4. Разреши получение от 'Всех' или добавь меня в исключения" - + await message.answer(text=privacy_message, reply_markup=markup) return # Выходим без записи о прослушивании - + else: - logger.error(f"Ошибка при отправке голосового сообщения пользователю {message.from_user.id}: {voice_error}") + logger.error( + f"Ошибка при отправке голосового сообщения пользователю {message.from_user.id}: {voice_error}" + ) raise voice_error - + except Exception as e: - logger.error(f"Ошибка при прослушивании аудио для пользователя {message.from_user.id}: {e}") + logger.error( + f"Ошибка при прослушивании аудио для пользователя {message.from_user.id}: {e}" + ) await message.answer( text="Произошла ошибка при получении аудио. Попробуйте позже.", - reply_markup=markup + reply_markup=markup, ) diff --git a/helper_bot/keyboards/keyboards.py b/helper_bot/keyboards/keyboards.py index 3fd4f3c..ed605ad 100644 --- a/helper_bot/keyboards/keyboards.py +++ b/helper_bot/keyboards/keyboards.py @@ -1,24 +1,21 @@ from aiogram import types from aiogram.utils.keyboard import InlineKeyboardBuilder, ReplyKeyboardBuilder + # Local imports - metrics from helper_bot.utils.metrics import track_errors, track_time def get_reply_keyboard_for_post(): builder = InlineKeyboardBuilder() - builder.row(types.InlineKeyboardButton( - text="Опубликовать", callback_data="publish"), - types.InlineKeyboardButton( - text="Отклонить", callback_data="decline") - ) - builder.row(types.InlineKeyboardButton( - text="👮‍♂️ Забанить", callback_data="ban") + builder.row( + types.InlineKeyboardButton(text="Опубликовать", callback_data="publish"), + types.InlineKeyboardButton(text="Отклонить", callback_data="decline"), ) + builder.row(types.InlineKeyboardButton(text="👮‍♂️ Забанить", callback_data="ban")) markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True) return markup - async def get_reply_keyboard(db, user_id): builder = ReplyKeyboardBuilder() builder.row(types.KeyboardButton(text="📢Предложить свой пост")) @@ -43,21 +40,22 @@ def get_reply_keyboard_admin(): builder.row( types.KeyboardButton(text="Бан (Список)"), types.KeyboardButton(text="Бан по нику"), - types.KeyboardButton(text="Бан по ID") + types.KeyboardButton(text="Бан по ID"), ) builder.row( types.KeyboardButton(text="Разбан (список)"), - types.KeyboardButton(text="📊 ML Статистика") - ) - builder.row( - types.KeyboardButton(text="Вернуться в бота") + types.KeyboardButton(text="📊 ML Статистика"), ) + builder.row(types.KeyboardButton(text="Вернуться в бота")) markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True) return markup + @track_time("create_keyboard_with_pagination", "keyboard_service") @track_errors("keyboard_service", "create_keyboard_with_pagination") -def create_keyboard_with_pagination(page: int, total_items: int, array_items: list, callback: str): +def create_keyboard_with_pagination( + page: int, total_items: int, array_items: list, callback: str +): """ Создает клавиатуру с пагинацией для заданного набора элементов и устанавливает необходимый callback @@ -70,74 +68,79 @@ def create_keyboard_with_pagination(page: int, total_items: int, array_items: li Returns: InlineKeyboardMarkup: Клавиатура с кнопками пагинации. """ - + # Проверяем валидность входных данных if page < 1: page = 1 if not array_items: # Если нет элементов, возвращаем только кнопку "Назад" keyboard = InlineKeyboardBuilder() - home_button = types.InlineKeyboardButton(text="🏠 Назад", callback_data="return") + home_button = types.InlineKeyboardButton( + text="🏠 Назад", callback_data="return" + ) keyboard.row(home_button) return keyboard.as_markup() # Определяем общее количество страниц items_per_page = 9 total_pages = (total_items + items_per_page - 1) // items_per_page - + # Ограничиваем страницу максимальным значением if page > total_pages: page = total_pages # Создаем билдер для клавиатуры keyboard = InlineKeyboardBuilder() - + # Вычисляем стартовый номер для текущей страницы start_index = (page - 1) * items_per_page - + # Кнопки с элементами текущей страницы end_index = min(start_index + items_per_page, len(array_items)) current_row = [] - + for i in range(start_index, end_index): - current_row.append(types.InlineKeyboardButton( - text=f"{array_items[i][0]}", callback_data=f"{callback}_{array_items[i][1]}" - )) - + current_row.append( + types.InlineKeyboardButton( + text=f"{array_items[i][0]}", + callback_data=f"{callback}_{array_items[i][1]}", + ) + ) + # Когда набирается 3 кнопки, добавляем ряд if len(current_row) == 3: keyboard.row(*current_row) current_row = [] - + # Добавляем оставшиеся кнопки, если они есть if current_row: keyboard.row(*current_row) - + # Создаем кнопки навигации только если нужно navigation_buttons = [] - + # Кнопка "Предыдущая" - показываем только если не первая страница if page > 1: prev_button = types.InlineKeyboardButton( text="⬅️ Предыдущая", callback_data=f"page_{page - 1}" ) navigation_buttons.append(prev_button) - + # Кнопка "Следующая" - показываем только если не последняя страница if page < total_pages: next_button = types.InlineKeyboardButton( text="➡️ Следующая", callback_data=f"page_{page + 1}" ) navigation_buttons.append(next_button) - + # Добавляем кнопки навигации, если они есть if navigation_buttons: keyboard.row(*navigation_buttons) - + # Кнопка "Назад" home_button = types.InlineKeyboardButton(text="🏠 Назад", callback_data="return") keyboard.row(home_button) - + return keyboard.as_markup() @@ -146,7 +149,11 @@ def create_keyboard_for_ban_reason(): builder.add(types.KeyboardButton(text="Спам")) builder.add(types.KeyboardButton(text="Заебал стикерами")) builder.row(types.KeyboardButton(text="Реклама здесь: @kerrad1 ")) - builder.row(types.KeyboardButton(text="Тема с лагерями: https://vk.com/topic-75343895_50049913")) + builder.row( + types.KeyboardButton( + text="Тема с лагерями: https://vk.com/topic-75343895_50049913" + ) + ) builder.row(types.KeyboardButton(text="Отменить")) markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True) return markup @@ -176,12 +183,12 @@ def get_main_keyboard(): # Первая строка: Высказаться и послушать builder.row( types.KeyboardButton(text="🎤Высказаться"), - types.KeyboardButton(text="🎧Послушать") + types.KeyboardButton(text="🎧Послушать"), ) # Вторая строка: сбросить прослушивания и узнать эмодзи builder.row( types.KeyboardButton(text="🔄Сбросить прослушивания"), - types.KeyboardButton(text="😊Узнать эмодзи") + types.KeyboardButton(text="😊Узнать эмодзи"), ) # Третья строка: Вернуться в меню builder.row(types.KeyboardButton(text="Отменить")) @@ -191,11 +198,7 @@ def get_main_keyboard(): def get_reply_keyboard_for_voice(): builder = InlineKeyboardBuilder() - builder.row(types.InlineKeyboardButton( - text="Сохранить", callback_data="save") - ) - builder.row(types.InlineKeyboardButton( - text="Удалить", callback_data="delete") - ) + builder.row(types.InlineKeyboardButton(text="Сохранить", callback_data="save")) + builder.row(types.InlineKeyboardButton(text="Удалить", callback_data="delete")) markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True) return markup diff --git a/helper_bot/main.py b/helper_bot/main.py index a85ee0a..f055db6 100644 --- a/helper_bot/main.py +++ b/helper_bot/main.py @@ -6,22 +6,25 @@ from aiogram import Bot, Dispatcher from aiogram.client.default import DefaultBotProperties from aiogram.fsm.storage.memory import MemoryStorage from aiogram.fsm.strategy import FSMStrategy + from helper_bot.handlers.admin import admin_router from helper_bot.handlers.callback import callback_router from helper_bot.handlers.group import group_router from helper_bot.handlers.private import private_router from helper_bot.handlers.voice import VoiceHandlers from helper_bot.middlewares.blacklist_middleware import BlacklistMiddleware -from helper_bot.middlewares.dependencies_middleware import \ - DependenciesMiddleware -from helper_bot.middlewares.metrics_middleware import (ErrorMetricsMiddleware, - MetricsMiddleware) +from helper_bot.middlewares.dependencies_middleware import DependenciesMiddleware +from helper_bot.middlewares.metrics_middleware import ( + ErrorMetricsMiddleware, + MetricsMiddleware, +) from helper_bot.middlewares.rate_limit_middleware import RateLimitMiddleware -from helper_bot.server_prometheus import (start_metrics_server, - stop_metrics_server) +from helper_bot.server_prometheus import start_metrics_server, stop_metrics_server -async def start_bot_with_retry(bot: Bot, dp: Dispatcher, max_retries: int = 5, base_delay: float = 1.0): +async def start_bot_with_retry( + bot: Bot, dp: Dispatcher, max_retries: int = 5, base_delay: float = 1.0 +): """Запуск бота с автоматическим перезапуском при сетевых ошибках""" for attempt in range(max_retries): try: @@ -30,14 +33,21 @@ async def start_bot_with_retry(bot: Bot, dp: Dispatcher, max_retries: int = 5, b break except Exception as e: error_msg = str(e).lower() - if any(keyword in error_msg for keyword in ['network', 'disconnected', 'timeout', 'connection']): + if any( + keyword in error_msg + for keyword in ["network", "disconnected", "timeout", "connection"] + ): if attempt < max_retries - 1: - delay = base_delay * (2 ** attempt) # Exponential backoff - logging.warning(f"Сетевая ошибка при запуске бота: {e}. Повтор через {delay:.1f}с (попытка {attempt + 1}/{max_retries})") + delay = base_delay * (2**attempt) # Exponential backoff + logging.warning( + f"Сетевая ошибка при запуске бота: {e}. Повтор через {delay:.1f}с (попытка {attempt + 1}/{max_retries})" + ) await asyncio.sleep(delay) continue else: - logging.error(f"Превышено максимальное количество попыток запуска бота: {e}") + logging.error( + f"Превышено максимальное количество попыток запуска бота: {e}" + ) raise else: logging.error(f"Критическая ошибка при запуске бота: {e}") @@ -45,30 +55,36 @@ async def start_bot_with_retry(bot: Bot, dp: Dispatcher, max_retries: int = 5, b async def start_bot(bdf): - token = bdf.settings['Telegram']['bot_token'] - bot = Bot(token=token, default=DefaultBotProperties( - parse_mode='HTML', - link_preview_is_disabled=bdf.settings['Telegram']['preview_link'] - ), timeout=60.0) # Увеличиваем timeout для стабильности - + token = bdf.settings["Telegram"]["bot_token"] + bot = Bot( + token=token, + default=DefaultBotProperties( + parse_mode="HTML", + link_preview_is_disabled=bdf.settings["Telegram"]["preview_link"], + ), + timeout=60.0, + ) # Увеличиваем timeout для стабильности + dp = Dispatcher(storage=MemoryStorage(), fsm_strategy=FSMStrategy.GLOBAL_USER) - + # ✅ Оптимизированная регистрация middleware dp.update.outer_middleware(DependenciesMiddleware()) dp.update.outer_middleware(MetricsMiddleware()) dp.update.outer_middleware(BlacklistMiddleware()) dp.update.outer_middleware(RateLimitMiddleware()) - + # Создаем экземпляр VoiceHandlers voice_handlers = VoiceHandlers(bdf, bdf.settings) voice_router = voice_handlers.router - + # Middleware уже добавлены на уровне dispatcher - dp.include_routers(admin_router, private_router, callback_router, group_router, voice_router) - + dp.include_routers( + admin_router, private_router, callback_router, group_router, voice_router + ) + # Получаем scoring_manager для использования в shutdown scoring_manager = bdf.get_scoring_manager() - + # Добавляем обработчик завершения для корректного закрытия @dp.shutdown() async def on_shutdown(): @@ -81,25 +97,25 @@ async def start_bot(bdf): logging.info("ScoringManager закрыт") except Exception as e: logging.error(f"Ошибка закрытия ScoringManager: {e}") - + await bot.session.close() logging.info("Bot session closed successfully") except Exception as e: logging.error(f"Error closing bot session during shutdown: {e}") - + await bot.delete_webhook(drop_pending_updates=True) - + # Запускаем HTTP сервер для метрик параллельно с ботом - metrics_host = bdf.settings.get('Metrics', {}).get('host', '0.0.0.0') - metrics_port = bdf.settings.get('Metrics', {}).get('port', 8080) - + metrics_host = bdf.settings.get("Metrics", {}).get("host", "0.0.0.0") + metrics_port = bdf.settings.get("Metrics", {}).get("port", 8080) + try: # Запускаем метрики сервер await start_metrics_server(metrics_host, metrics_port) - + logging.info(f"✅ Метрики сервер запущен на {metrics_host}:{metrics_port}") logging.info("✅ Метрики будут обновляться в реальном времени через middleware") - + # Запускаем бота с retry логикой await start_bot_with_retry(bot, dp) @@ -115,13 +131,13 @@ async def start_bot(bdf): logging.info("ScoringManager закрыт в finally") except Exception as e: logging.error(f"Ошибка закрытия ScoringManager в finally: {e}") - + # Останавливаем метрики сервер при завершении try: await stop_metrics_server() except Exception as e: logging.error(f"Error stopping metrics server: {e}") - + # Закрываем сессию бота try: await bot.session.close() diff --git a/helper_bot/middlewares/album_middleware.py b/helper_bot/middlewares/album_middleware.py index 57e5ce7..eff4adb 100644 --- a/helper_bot/middlewares/album_middleware.py +++ b/helper_bot/middlewares/album_middleware.py @@ -7,19 +7,21 @@ from aiogram.types import Message class AlbumGetter: """Вспомогательный класс для получения полной медиагруппы из middleware""" - - def __init__(self, album_data: Dict[str, Any], media_group_id: str, event: asyncio.Event): + + def __init__( + self, album_data: Dict[str, Any], media_group_id: str, event: asyncio.Event + ): self.album_data = album_data self.media_group_id = media_group_id self.event = event - + async def get_album(self, timeout: float = 10.0) -> Optional[List[Message]]: """ Ждет полную медиагруппу и возвращает ее. - + Args: timeout: Максимальное время ожидания в секундах - + Returns: Список сообщений медиагруппы или None при таймауте """ @@ -38,11 +40,11 @@ class AlbumMiddleware(BaseMiddleware): Собирает все сообщения одной медиа группы и передает их как album в data. Не блокирует handler - сразу вызывает его, а полную медиагруппу передает через Event. """ - + def __init__(self, latency: Union[int, float] = 5.0): """ Инициализация middleware. - + Args: latency: Задержка в секундах для сбора всех сообщений медиа группы """ @@ -54,43 +56,43 @@ class AlbumMiddleware(BaseMiddleware): def collect_album_messages(self, event: Message) -> int: """ Собирает сообщения одной медиа группы. - + Args: event: Сообщение для обработки - + Returns: Количество сообщений в текущей медиа группе """ if not event.media_group_id: return 0 - + if event.media_group_id not in self.album_data: self.album_data[event.media_group_id] = {"messages": []} - + self.album_data[event.media_group_id]["messages"].append(event) return len(self.album_data[event.media_group_id]["messages"]) async def _collect_album_background(self, media_group_id: str) -> None: """ Фоновая задача для сбора всех сообщений медиагруппы. - + Args: media_group_id: ID медиагруппы для сбора """ try: await asyncio.sleep(self.latency) - + if media_group_id not in self.album_data: return - + # Получаем текущий список сообщений album_messages = self.album_data[media_group_id]["messages"].copy() album_messages.sort(key=lambda x: x.message_id) - + # Сохраняем собранную медиагруппу и уведомляем через Event self.album_data[media_group_id]["collected_album"] = album_messages self.album_data[media_group_id]["event"].set() - + # Очищаем данные после небольшой задержки (чтобы handler успел получить album) await asyncio.sleep(1.0) if media_group_id in self.album_data: @@ -114,24 +116,24 @@ class AlbumMiddleware(BaseMiddleware): async def __call__(self, handler, event: Message, data: Dict[str, Any]) -> Any: """ Основная логика middleware. - + Для медиагрупп: сразу вызывает handler, передавая Event для получения полной медиагруппы. Для обычных сообщений: сразу вызывает handler. - + Args: handler: Обработчик события event: Событие (сообщение) data: Данные для передачи в обработчик - + Returns: Результат выполнения обработчика """ if not event.media_group_id: return await handler(event, data) - + media_group_id = event.media_group_id message_id = event.message_id - + # Если это первое сообщение медиагруппы - создаем структуру данных is_first_message = False if media_group_id not in self.album_data: @@ -141,27 +143,25 @@ class AlbumMiddleware(BaseMiddleware): "messages": [], "event": album_event, "task": None, - "first_message_id": message_id + "first_message_id": message_id, } # Запускаем фоновую задачу для сбора медиагруппы task = asyncio.create_task(self._collect_album_background(media_group_id)) self.album_data[media_group_id]["task"] = task - + # Добавляем сообщение в медиагруппу self.album_data[media_group_id]["messages"].append(event) - + # Обрабатываем только первое сообщение медиагруппы if not is_first_message: # Для остальных сообщений просто возвращаемся, не вызывая handler return - + # Передаем объект-геттер в data, чтобы handler мог получить полную медиагруппу album_getter = AlbumGetter( - self.album_data, - media_group_id, - self.album_data[media_group_id]["event"] + self.album_data, media_group_id, self.album_data[media_group_id]["event"] ) data["album_getter"] = album_getter - + # Сразу вызываем handler только для первого сообщения (не блокируем) return await handler(event, data) diff --git a/helper_bot/middlewares/blacklist_middleware.py b/helper_bot/middlewares/blacklist_middleware.py index 4bcb92d..32279e2 100644 --- a/helper_bot/middlewares/blacklist_middleware.py +++ b/helper_bot/middlewares/blacklist_middleware.py @@ -4,6 +4,7 @@ from typing import Any, Dict from aiogram import BaseMiddleware, types from aiogram.types import CallbackQuery, Message, TelegramObject + from helper_bot.utils.base_dependency_factory import get_global_instance from logs.custom_logger import logger @@ -12,47 +13,61 @@ BotDB = bdf.get_db() class BlacklistMiddleware(BaseMiddleware): - async def __call__(self, handler, event: TelegramObject, data: Dict[str, Any]) -> Any: + async def __call__( + self, handler, event: TelegramObject, data: Dict[str, Any] + ) -> Any: # Проверяем тип события и получаем пользователя user = None if isinstance(event, Message): user = event.from_user elif isinstance(event, CallbackQuery): user = event.from_user - + # Если это не сообщение или callback, пропускаем проверку if not user: return await handler(event, data) - - logger.info(f'Вызов BlacklistMiddleware для пользователя {user.username}') - + + logger.info(f"Вызов BlacklistMiddleware для пользователя {user.username}") + # Используем асинхронную версию для предотвращения блокировки if await BotDB.check_user_in_blacklist(user.id): - logger.info(f'BlacklistMiddleware результат для пользователя: {user.username} заблокирован!') + logger.info( + f"BlacklistMiddleware результат для пользователя: {user.username} заблокирован!" + ) user_info = await BotDB.get_blacklist_users_by_id(user.id) # Экранируем потенциально проблемные символы - reason = html.escape(str(user_info[1])) if user_info and user_info[1] else "Не указана" - + reason = ( + html.escape(str(user_info[1])) + if user_info and user_info[1] + else "Не указана" + ) + # Преобразуем timestamp в человекочитаемый формат if user_info and user_info[2]: try: timestamp = int(user_info[2]) - date_unban = datetime.fromtimestamp(timestamp).strftime("%d-%m-%Y %H:%M") + date_unban = datetime.fromtimestamp(timestamp).strftime( + "%d-%m-%Y %H:%M" + ) except (ValueError, TypeError): date_unban = "Не указана" else: date_unban = "Не указана" - + # Отправляем сообщение в зависимости от типа события if isinstance(event, Message): await event.answer( - f"Ты заблокирован.\nПричина блокировки: {reason}\nДата разбана: {date_unban}") + f"Ты заблокирован.\nПричина блокировки: {reason}\nДата разбана: {date_unban}" + ) elif isinstance(event, CallbackQuery): await event.answer( f"Ты заблокирован.\nПричина блокировки: {reason}\nДата разбана: {date_unban}", - show_alert=True) - + show_alert=True, + ) + return False - - logger.info(f'BlacklistMiddleware результат для пользователя: {user.username} доступ разрешен') + + logger.info( + f"BlacklistMiddleware результат для пользователя: {user.username} доступ разрешен" + ) return await handler(event, data) diff --git a/helper_bot/middlewares/dependencies_middleware.py b/helper_bot/middlewares/dependencies_middleware.py index a0329b2..d18c28c 100644 --- a/helper_bot/middlewares/dependencies_middleware.py +++ b/helper_bot/middlewares/dependencies_middleware.py @@ -2,30 +2,35 @@ from typing import Any, Dict from aiogram import BaseMiddleware from aiogram.types import TelegramObject + from helper_bot.utils.base_dependency_factory import get_global_instance from logs.custom_logger import logger class DependenciesMiddleware(BaseMiddleware): """Универсальная middleware для внедрения зависимостей во все хендлеры""" - - async def __call__(self, handler, event: TelegramObject, data: Dict[str, Any]) -> Any: + + async def __call__( + self, handler, event: TelegramObject, data: Dict[str, Any] + ) -> Any: try: # Получаем глобальные зависимости bdf = get_global_instance() - + # Внедряем зависимости в data для MagicData - if 'bot_db' not in data: - data['bot_db'] = bdf.get_db() - if 'settings' not in data: - data['settings'] = bdf.settings - data['bot'] = data.get('bot') - data['dp'] = data.get('dp') - - logger.debug(f"DependenciesMiddleware: внедрены зависимости для {type(event).__name__}") - + if "bot_db" not in data: + data["bot_db"] = bdf.get_db() + if "settings" not in data: + data["settings"] = bdf.settings + data["bot"] = data.get("bot") + data["dp"] = data.get("dp") + + logger.debug( + f"DependenciesMiddleware: внедрены зависимости для {type(event).__name__}" + ) + except Exception as e: logger.error(f"Ошибка в DependenciesMiddleware: {e}") # Не прерываем выполнение, продолжаем без зависимостей - + return await handler(event, data) diff --git a/helper_bot/middlewares/metrics_middleware.py b/helper_bot/middlewares/metrics_middleware.py index 6acc4a4..2564b86 100644 --- a/helper_bot/middlewares/metrics_middleware.py +++ b/helper_bot/middlewares/metrics_middleware.py @@ -16,16 +16,16 @@ from ..utils.metrics import metrics # Import button command mapping try: - from ..handlers.admin.constants import (ADMIN_BUTTON_COMMAND_MAPPING, - ADMIN_COMMANDS) + from ..handlers.admin.constants import ADMIN_BUTTON_COMMAND_MAPPING, ADMIN_COMMANDS from ..handlers.callback.constants import CALLBACK_COMMAND_MAPPING from ..handlers.private.constants import BUTTON_COMMAND_MAPPING - from ..handlers.voice.constants import \ - BUTTON_COMMAND_MAPPING as VOICE_BUTTON_COMMAND_MAPPING - from ..handlers.voice.constants import \ - CALLBACK_COMMAND_MAPPING as VOICE_CALLBACK_COMMAND_MAPPING - from ..handlers.voice.constants import \ - COMMAND_MAPPING as VOICE_COMMAND_MAPPING + from ..handlers.voice.constants import ( + BUTTON_COMMAND_MAPPING as VOICE_BUTTON_COMMAND_MAPPING, + ) + from ..handlers.voice.constants import ( + CALLBACK_COMMAND_MAPPING as VOICE_CALLBACK_COMMAND_MAPPING, + ) + from ..handlers.voice.constants import COMMAND_MAPPING as VOICE_COMMAND_MAPPING except ImportError: # Fallback if constants not available BUTTON_COMMAND_MAPPING = {} @@ -39,40 +39,49 @@ except ImportError: class MetricsMiddleware(BaseMiddleware): """Enhanced middleware for automatic collection of ALL available metrics.""" - + def __init__(self): super().__init__() self.logger = logging.getLogger(__name__) - + # Metrics update intervals self.last_active_users_update = 0 self.active_users_update_interval = 300 # 5 minutes - + async def __call__( self, handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], event: TelegramObject, - data: Dict[str, Any] + data: Dict[str, Any], ) -> Any: """Process event and collect comprehensive metrics.""" # Update active users periodically current_time = time.time() - if current_time - self.last_active_users_update > self.active_users_update_interval: + if ( + current_time - self.last_active_users_update + > self.active_users_update_interval + ): await self._update_active_users_metric() self.last_active_users_update = current_time - + # Extract command and event info command_info = None event_metrics = {} # Process event based on type - if hasattr(event, 'message') and event.message: - event_metrics = await self._record_comprehensive_message_metrics(event.message) + if hasattr(event, "message") and event.message: + event_metrics = await self._record_comprehensive_message_metrics( + event.message + ) command_info = self._extract_command_info_with_fallback(event.message) - elif hasattr(event, 'callback_query') and event.callback_query: - event_metrics = await self._record_comprehensive_callback_metrics(event.callback_query) - command_info = self._extract_callback_command_info_with_fallback(event.callback_query) + elif hasattr(event, "callback_query") and event.callback_query: + event_metrics = await self._record_comprehensive_callback_metrics( + event.callback_query + ) + command_info = self._extract_callback_command_info_with_fallback( + event.callback_query + ) elif isinstance(event, Message): event_metrics = await self._record_comprehensive_message_metrics(event) command_info = self._extract_command_info_with_fallback(event) @@ -81,107 +90,106 @@ class MetricsMiddleware(BaseMiddleware): command_info = self._extract_callback_command_info_with_fallback(event) else: event_metrics = await self._record_unknown_event_metrics(event) - + if command_info: self.logger.info(f"📊 Command info extracted: {command_info}") else: - self.logger.warning(f"📊 No command info extracted for event type: {type(event).__name__}") - + self.logger.warning( + f"📊 No command info extracted for event type: {type(event).__name__}" + ) + # Execute handler with comprehensive timing and metrics start_time = time.time() try: result = await handler(event, data) duration = time.time() - start_time - + # Record successful execution metrics handler_name = self._get_handler_name(handler) - - metrics.record_method_duration( - handler_name, - duration, - "handler", - "success" - ) - + + metrics.record_method_duration(handler_name, duration, "handler", "success") + if command_info: metrics.record_command( - command_info['command'], - command_info['handler_type'], - command_info['user_type'], - "success" + command_info["command"], + command_info["handler_type"], + command_info["user_type"], + "success", ) - - await self._record_additional_success_metrics(event, event_metrics, handler_name) - + + await self._record_additional_success_metrics( + event, event_metrics, handler_name + ) + return result - + except Exception as e: duration = time.time() - start_time - + # Record error metrics handler_name = self._get_handler_name(handler) error_type = type(e).__name__ - - metrics.record_method_duration( - handler_name, - duration, - "handler", - "error" - ) - - metrics.record_error( - error_type, - "handler", - handler_name - ) - + + metrics.record_method_duration(handler_name, duration, "handler", "error") + + metrics.record_error(error_type, "handler", handler_name) + if command_info: metrics.record_command( - command_info['command'], - command_info['handler_type'], - command_info['user_type'], - "error" + command_info["command"], + command_info["handler_type"], + command_info["user_type"], + "error", ) - - await self._record_additional_error_metrics(event, event_metrics, handler_name, error_type) - + + await self._record_additional_error_metrics( + event, event_metrics, handler_name, error_type + ) + raise finally: # Record middleware execution time middleware_duration = time.time() - start_time - metrics.record_middleware("MetricsMiddleware", middleware_duration, "success") - + metrics.record_middleware( + "MetricsMiddleware", middleware_duration, "success" + ) + async def _update_active_users_metric(self): """Periodically update active users metric from database.""" try: - #TODO: Должна подключаться к базе данных, а не к глобальному экземпляру + # TODO: Должна подключаться к базе данных, а не к глобальному экземпляру from ..utils.base_dependency_factory import get_global_instance + bdf = get_global_instance() bot_db = bdf.get_db() - + # Используем правильные методы AsyncBotDB для выполнения запросов # Простой подсчет всех пользователей в базе total_users_query = "SELECT COUNT(DISTINCT user_id) as total FROM our_users" total_users_result = await bot_db.fetch_one(total_users_query) - total_users = total_users_result['total'] if total_users_result else 1 - + total_users = total_users_result["total"] if total_users_result else 1 + # Подсчет активных за день пользователей (date_changed - это Unix timestamp) daily_users_query = "SELECT COUNT(DISTINCT user_id) as daily FROM our_users WHERE date_changed > (strftime('%s', 'now', '-1 day'))" daily_users_result = await bot_db.fetch_one(daily_users_query) - daily_users = daily_users_result['daily'] if daily_users_result else 1 - + daily_users = daily_users_result["daily"] if daily_users_result else 1 + # Устанавливаем метрики с правильными лейблами metrics.set_active_users(daily_users, "daily") metrics.set_total_users(total_users) - self.logger.info(f"📊 Active users metric updated: {daily_users} (daily), {total_users} (total)") - + self.logger.info( + f"📊 Active users metric updated: {daily_users} (daily), {total_users} (total)" + ) + except Exception as e: self.logger.error(f"❌ Failed to update users metric: {e}") # Устанавливаем 1 как fallback metrics.set_active_users(1, "daily") metrics.set_total_users(1) - - async def _record_comprehensive_message_metrics(self, message: Message) -> Dict[str, Any]: + + async def _record_comprehensive_message_metrics( + self, message: Message + ) -> Dict[str, Any]: """Record comprehensive message metrics.""" # Determine message type message_type = "text" @@ -199,7 +207,7 @@ class MetricsMiddleware(BaseMiddleware): message_type = "sticker" elif message.animation: message_type = "animation" - + # Determine chat type chat_type = "private" if message.chat.type == ChatType.GROUP: @@ -208,129 +216,139 @@ class MetricsMiddleware(BaseMiddleware): chat_type = "supergroup" elif message.chat.type == ChatType.CHANNEL: chat_type = "channel" - + # Record message processing metrics.record_message(message_type, chat_type, "message_handler") - + return { - 'message_type': message_type, - 'chat_type': chat_type, - 'user_id': message.from_user.id if message.from_user else None, - 'is_bot': message.from_user.is_bot if message.from_user else False + "message_type": message_type, + "chat_type": chat_type, + "user_id": message.from_user.id if message.from_user else None, + "is_bot": message.from_user.is_bot if message.from_user else False, } - - async def _record_comprehensive_callback_metrics(self, callback: CallbackQuery) -> Dict[str, Any]: + + async def _record_comprehensive_callback_metrics( + self, callback: CallbackQuery + ) -> Dict[str, Any]: """Record comprehensive callback metrics.""" # Record callback message metrics.record_message("callback_query", "callback", "callback_handler") - + return { - 'callback_data': callback.data, - 'user_id': callback.from_user.id if callback.from_user else None, - 'is_bot': callback.from_user.is_bot if callback.from_user else False + "callback_data": callback.data, + "user_id": callback.from_user.id if callback.from_user else None, + "is_bot": callback.from_user.is_bot if callback.from_user else False, } - - async def _record_unknown_event_metrics(self, event: TelegramObject) -> Dict[str, Any]: + + async def _record_unknown_event_metrics( + self, event: TelegramObject + ) -> Dict[str, Any]: """Record metrics for unknown event types.""" # Record unknown event metrics.record_message("unknown", "unknown", "unknown_handler") - + return { - 'event_type': type(event).__name__, - 'event_data': str(event)[:100] if hasattr(event, '__str__') else "unknown" + "event_type": type(event).__name__, + "event_data": str(event)[:100] if hasattr(event, "__str__") else "unknown", } - - def _extract_command_info_with_fallback(self, message: Message) -> Optional[Dict[str, str]]: + + def _extract_command_info_with_fallback( + self, message: Message + ) -> Optional[Dict[str, str]]: """Extract command information with fallback for unknown commands.""" if not message.text: return None - + # Check if it's a slash command - if message.text.startswith('/'): - command_name = message.text.split()[0][1:] # Remove '/' and get command name - + if message.text.startswith("/"): + command_name = message.text.split()[0][ + 1: + ] # Remove '/' and get command name + # Check if it's an admin command if command_name in ADMIN_COMMANDS: return { - 'command': ADMIN_COMMANDS[command_name], - 'user_type': "admin" if message.from_user else "unknown", - 'handler_type': "admin_handler" + "command": ADMIN_COMMANDS[command_name], + "user_type": "admin" if message.from_user else "unknown", + "handler_type": "admin_handler", } # Check if it's a voice bot command elif command_name in VOICE_COMMAND_MAPPING: return { - 'command': VOICE_COMMAND_MAPPING[command_name], - 'user_type': "user" if message.from_user else "unknown", - 'handler_type': "voice_command_handler" + "command": VOICE_COMMAND_MAPPING[command_name], + "user_type": "user" if message.from_user else "unknown", + "handler_type": "voice_command_handler", } else: # FALLBACK: Record unknown command return { - 'command': command_name, - 'user_type': "user" if message.from_user else "unknown", - 'handler_type': "unknown_command_handler" + "command": command_name, + "user_type": "user" if message.from_user else "unknown", + "handler_type": "unknown_command_handler", } - + # Check if it's an admin button click if message.text in ADMIN_BUTTON_COMMAND_MAPPING: return { - 'command': ADMIN_BUTTON_COMMAND_MAPPING[message.text], - 'user_type': "admin" if message.from_user else "unknown", - 'handler_type': "admin_button_handler" + "command": ADMIN_BUTTON_COMMAND_MAPPING[message.text], + "user_type": "admin" if message.from_user else "unknown", + "handler_type": "admin_button_handler", } - + # Check if it's a regular button click (text button) if message.text in BUTTON_COMMAND_MAPPING: return { - 'command': BUTTON_COMMAND_MAPPING[message.text], - 'user_type': "user" if message.from_user else "unknown", - 'handler_type': "button_handler" + "command": BUTTON_COMMAND_MAPPING[message.text], + "user_type": "user" if message.from_user else "unknown", + "handler_type": "button_handler", } - + # Check if it's a voice bot button click if message.text in VOICE_BUTTON_COMMAND_MAPPING: return { - 'command': VOICE_BUTTON_COMMAND_MAPPING[message.text], - 'user_type': "user" if message.from_user else "unknown", - 'handler_type': "voice_button_handler" + "command": VOICE_BUTTON_COMMAND_MAPPING[message.text], + "user_type": "user" if message.from_user else "unknown", + "handler_type": "voice_button_handler", } - + # FALLBACK: Record ANY text message as a command for metrics if message.text and len(message.text.strip()) > 0: return { - 'command': f"text", - 'user_type': "user" if message.from_user else "unknown", - 'handler_type': "text_message_handler" + "command": f"text", + "user_type": "user" if message.from_user else "unknown", + "handler_type": "text_message_handler", } - + return None - - def _extract_callback_command_info_with_fallback(self, callback: CallbackQuery) -> Optional[Dict[str, str]]: + + def _extract_callback_command_info_with_fallback( + self, callback: CallbackQuery + ) -> Optional[Dict[str, str]]: """Extract callback command information with fallback.""" if not callback.data: return None - + # Extract command from callback data - parts = callback.data.split(':', 1) + parts = callback.data.split(":", 1) if parts and parts[0] in CALLBACK_COMMAND_MAPPING: return { - 'command': CALLBACK_COMMAND_MAPPING[parts[0]], - 'user_type': "user" if callback.from_user else "unknown", - 'handler_type': "callback_handler" + "command": CALLBACK_COMMAND_MAPPING[parts[0]], + "user_type": "user" if callback.from_user else "unknown", + "handler_type": "callback_handler", } - + # Check if it's a voice bot callback if parts and parts[0] in VOICE_CALLBACK_COMMAND_MAPPING: return { - 'command': VOICE_CALLBACK_COMMAND_MAPPING[parts[0]], - 'user_type': "user" if callback.from_user else "unknown", - 'handler_type': "voice_callback_handler" + "command": VOICE_CALLBACK_COMMAND_MAPPING[parts[0]], + "user_type": "user" if callback.from_user else "unknown", + "handler_type": "voice_callback_handler", } - + # FALLBACK: Record unknown callback if parts: callback_data = parts[0] - + # Группируем похожие callback'и по паттернам if callback_data.startswith("ban_") and callback_data[4:].isdigit(): # callback_ban_123456 -> callback_ban @@ -341,60 +359,69 @@ class MetricsMiddleware(BaseMiddleware): else: # Для остальных неизвестных callback'ов оставляем как есть command = f"callback_{callback_data[:20]}" - + return { - 'command': command, - 'user_type': "user" if callback.from_user else "unknown", - 'handler_type': "unknown_callback_handler" + "command": command, + "user_type": "user" if callback.from_user else "unknown", + "handler_type": "unknown_callback_handler", } - + return None - - async def _record_additional_success_metrics(self, event: TelegramObject, event_metrics: Dict[str, Any], handler_name: str): + + async def _record_additional_success_metrics( + self, event: TelegramObject, event_metrics: Dict[str, Any], handler_name: str + ): """Record additional success metrics.""" try: # Record rate limiting metrics (if applicable) - if hasattr(event, 'from_user') and event.from_user: + if hasattr(event, "from_user") and event.from_user: # You can add rate limiting logic here pass - + # Record user activity metrics - if event_metrics.get('user_id'): + if event_metrics.get("user_id"): # This could trigger additional user activity tracking pass - + except Exception as e: self.logger.error(f"❌ Error recording additional success metrics: {e}") - - async def _record_additional_error_metrics(self, event: TelegramObject, event_metrics: Dict[str, Any], handler_name: str, error_type: str): + + async def _record_additional_error_metrics( + self, + event: TelegramObject, + event_metrics: Dict[str, Any], + handler_name: str, + error_type: str, + ): """Record additional error metrics.""" try: # Record specific error context - if event_metrics.get('user_id'): + if event_metrics.get("user_id"): # You can add user-specific error tracking here pass - + except Exception as e: self.logger.error(f"❌ Error recording additional error metrics: {e}") - + def _get_handler_name(self, handler: Callable) -> str: """Extract handler name efficiently.""" # Check various ways to get handler name - if hasattr(handler, '__name__') and handler.__name__ != '': + if hasattr(handler, "__name__") and handler.__name__ != "": return handler.__name__ - elif hasattr(handler, '__qualname__') and handler.__qualname__ != '': + elif hasattr(handler, "__qualname__") and handler.__qualname__ != "": return handler.__qualname__ - elif hasattr(handler, 'callback') and hasattr(handler.callback, '__name__'): + elif hasattr(handler, "callback") and hasattr(handler.callback, "__name__"): return handler.callback.__name__ - elif hasattr(handler, 'view') and hasattr(handler.view, '__name__'): + elif hasattr(handler, "view") and hasattr(handler.view, "__name__"): return handler.view.__name__ else: # Пытаемся получить имя из строкового представления handler_str = str(handler) - if 'function' in handler_str: + if "function" in handler_str: # Извлекаем имя функции из строки import re - match = re.search(r'function\s+(\w+)', handler_str) + + match = re.search(r"function\s+(\w+)", handler_str) if match: return match.group(1) return "unknown" @@ -402,83 +429,77 @@ class MetricsMiddleware(BaseMiddleware): class DatabaseMetricsMiddleware(BaseMiddleware): """Enhanced middleware for database operation metrics.""" - + def __init__(self): super().__init__() self.logger = logging.getLogger(__name__) - + async def __call__( self, handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], event: TelegramObject, - data: Dict[str, Any] + data: Dict[str, Any], ) -> Any: """Process event and collect database metrics.""" - + # Check if this handler involves database operations - handler_name = handler.__name__ if hasattr(handler, '__name__') else "unknown" - + handler_name = handler.__name__ if hasattr(handler, "__name__") else "unknown" + # Record middleware start start_time = time.time() - + try: result = await handler(event, data) - + # Record successful database operation duration = time.time() - start_time metrics.record_middleware("DatabaseMetricsMiddleware", duration, "success") - + return result - + except Exception as e: # Record failed database operation duration = time.time() - start_time metrics.record_middleware("DatabaseMetricsMiddleware", duration, "error") - metrics.record_error( - type(e).__name__, - "database_middleware", - handler_name - ) + metrics.record_error(type(e).__name__, "database_middleware", handler_name) raise class ErrorMetricsMiddleware(BaseMiddleware): """Enhanced middleware for error tracking and metrics.""" - + def __init__(self): super().__init__() self.logger = logging.getLogger(__name__) - + async def __call__( self, handler: Callable[[TelegramObject, Dict[str, Any]], Awaitable[Any]], event: TelegramObject, - data: Dict[str, Any] + data: Dict[str, Any], ) -> Any: """Process event and collect error metrics.""" - + # Record middleware start start_time = time.time() - + try: result = await handler(event, data) - + # Record successful error handling duration = time.time() - start_time metrics.record_middleware("ErrorMetricsMiddleware", duration, "success") - + return result - + except Exception as e: # Record error metrics duration = time.time() - start_time - handler_name = handler.__name__ if hasattr(handler, '__name__') else "unknown" - - metrics.record_middleware("ErrorMetricsMiddleware", duration, "error") - metrics.record_error( - type(e).__name__, - "error_middleware", - handler_name + handler_name = ( + handler.__name__ if hasattr(handler, "__name__") else "unknown" ) - + + metrics.record_middleware("ErrorMetricsMiddleware", duration, "error") + metrics.record_error(type(e).__name__, "error_middleware", handler_name) + raise diff --git a/helper_bot/middlewares/rate_limit_middleware.py b/helper_bot/middlewares/rate_limit_middleware.py index bc59898..c50ef88 100644 --- a/helper_bot/middlewares/rate_limit_middleware.py +++ b/helper_bot/middlewares/rate_limit_middleware.py @@ -1,42 +1,43 @@ """ Middleware для автоматического применения rate limiting ко всем входящим сообщениям """ + from typing import Any, Awaitable, Callable, Dict, Union from aiogram import BaseMiddleware from aiogram.exceptions import TelegramAPIError, TelegramRetryAfter -from aiogram.types import (CallbackQuery, ChatMemberUpdated, InlineQuery, - Message, Update) +from aiogram.types import CallbackQuery, ChatMemberUpdated, InlineQuery, Message, Update + from helper_bot.utils.rate_limiter import telegram_rate_limiter from logs.custom_logger import logger class RateLimitMiddleware(BaseMiddleware): """Middleware для автоматического rate limiting входящих сообщений""" - + def __init__(self): super().__init__() self.rate_limiter = telegram_rate_limiter - + async def __call__( self, handler: Callable[[Update, Dict[str, Any]], Awaitable[Any]], event: Union[Update, Message, CallbackQuery, InlineQuery, ChatMemberUpdated], - data: Dict[str, Any] + data: Dict[str, Any], ) -> Any: """Обрабатывает событие с rate limiting""" - + # Извлекаем сообщение из Update message = None if isinstance(event, Update): message = event.message elif isinstance(event, Message): message = event - + # Применяем rate limiting только к сообщениям if message is not None: chat_id = message.chat.id - + # Обертываем handler в rate limiting async def rate_limited_handler(): try: @@ -46,13 +47,11 @@ class RateLimitMiddleware(BaseMiddleware): # Middleware не должен перехватывать эти ошибки, # пусть их обрабатывает rate_limiter в функциях отправки raise - + # Применяем rate limiting к handler return await self.rate_limiter.send_with_rate_limit( - rate_limited_handler, - chat_id + rate_limited_handler, chat_id ) else: # Для других типов событий просто вызываем handler return await handler(event, data) - diff --git a/helper_bot/middlewares/text_middleware.py b/helper_bot/middlewares/text_middleware.py index b18ed6f..890dd8b 100644 --- a/helper_bot/middlewares/text_middleware.py +++ b/helper_bot/middlewares/text_middleware.py @@ -12,7 +12,6 @@ class BulkTextMiddleware(BaseMiddleware): self.latency = latency self.texts = defaultdict(list) - async def __call__(self, handler, event: Message, data: Dict[str, Any]) -> Any: """ Main middleware logic. @@ -37,10 +36,9 @@ class BulkTextMiddleware(BaseMiddleware): # # Sort the album messages by message_id and add to data msg_texts = self.texts[key] msg_texts.sort(key=lambda x: x.message_id) - data["texts"] = ''.join([msg.text for msg in msg_texts]) + data["texts"] = "".join([msg.text for msg in msg_texts]) # # Remove the media group from tracking to free up memory del self.texts[key] # # Call the original event handler return await handler(event, data) - diff --git a/helper_bot/server_prometheus.py b/helper_bot/server_prometheus.py index 7255fd5..4051a8c 100644 --- a/helper_bot/server_prometheus.py +++ b/helper_bot/server_prometheus.py @@ -1,4 +1,3 @@ - """ HTTP server for metrics endpoint integration with centralized Prometheus monitoring. Provides /metrics endpoint and health check for the bot. @@ -17,53 +16,48 @@ try: except ImportError: # Fallback для случаев, когда custom_logger недоступен import logging + logger = logging.getLogger(__name__) class MetricsServer: """HTTP server for Prometheus metrics and health checks.""" - - def __init__(self, host: str = '0.0.0.0', port: int = 8080): + + def __init__(self, host: str = "0.0.0.0", port: int = 8080): self.host = host self.port = port self.app = web.Application() self.runner: Optional[web.AppRunner] = None self.site: Optional[web.TCPSite] = None - + # Настраиваем роуты - self.app.router.add_get('/metrics', self.metrics_handler) - self.app.router.add_get('/health', self.health_handler) - + self.app.router.add_get("/metrics", self.metrics_handler) + self.app.router.add_get("/health", self.health_handler) + async def metrics_handler(self, request: web.Request) -> web.Response: """Handle /metrics endpoint for Prometheus scraping.""" try: logger.debug("Generating metrics...") - + # Проверяем, что metrics доступен if not metrics: logger.error("Metrics object is not available") - return web.Response( - text="Metrics not available", - status=500 - ) - + return web.Response(text="Metrics not available", status=500) + # Генерируем метрики в формате Prometheus metrics_data = metrics.get_metrics() logger.debug(f"Generated metrics: {len(metrics_data)} bytes") - + return web.Response( - body=metrics_data, - content_type='text/plain; version=0.0.4' + body=metrics_data, content_type="text/plain; version=0.0.4" ) except Exception as e: logger.error(f"Error generating metrics: {e}") import traceback + logger.error(f"Traceback: {traceback.format_exc()}") - return web.Response( - text=f"Error generating metrics: {e}", - status=500 - ) - + return web.Response(text=f"Error generating metrics: {e}", status=500) + async def health_handler(self, request: web.Request) -> web.Response: """Handle /health endpoint for health checks.""" try: @@ -71,77 +65,72 @@ class MetricsServer: if not metrics: return web.Response( text="ERROR: Metrics not available", - content_type='text/plain', - status=503 + content_type="text/plain", + status=503, ) - + # Проверяем, что можем получить метрики try: metrics_data = metrics.get_metrics() if not metrics_data: return web.Response( text="ERROR: Empty metrics", - content_type='text/plain', - status=503 + content_type="text/plain", + status=503, ) except Exception as e: return web.Response( text=f"ERROR: Metrics generation failed: {e}", - content_type='text/plain', - status=503 + content_type="text/plain", + status=503, ) - - return web.Response( - text="OK", - content_type='text/plain', - status=200 - ) + + return web.Response(text="OK", content_type="text/plain", status=200) except Exception as e: logger.error(f"Health check failed: {e}") return web.Response( text=f"ERROR: Health check failed: {e}", - content_type='text/plain', - status=500 + content_type="text/plain", + status=500, ) - - + async def start(self) -> None: """Start the HTTP server.""" try: self.runner = web.AppRunner(self.app) await self.runner.setup() - + self.site = web.TCPSite(self.runner, self.host, self.port) await self.site.start() - + logger.info(f"Metrics server started on {self.host}:{self.port}") logger.info("Available endpoints:") logger.info(f" - /metrics - Prometheus metrics") logger.info(f" - /health - Health check") - + except Exception as e: logger.error(f"Failed to start metrics server: {e}") raise - + async def stop(self) -> None: """Stop the HTTP server.""" try: if self.site: await self.site.stop() logger.info("Metrics server site stopped") - + if self.runner: await self.runner.cleanup() logger.info("Metrics server runner cleaned up") - + except Exception as e: logger.error(f"Error stopping metrics server: {e}") - + async def __aenter__(self): """Async context manager entry.""" await self.start() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self.stop() @@ -151,7 +140,9 @@ class MetricsServer: metrics_server: Optional[MetricsServer] = None -async def start_metrics_server(host: str = '0.0.0.0', port: int = 8080) -> MetricsServer: +async def start_metrics_server( + host: str = "0.0.0.0", port: int = 8080 +) -> MetricsServer: """Start metrics server and return instance.""" global metrics_server metrics_server = MetricsServer(host, port) diff --git a/helper_bot/services/scoring/__init__.py b/helper_bot/services/scoring/__init__.py index 6a1d156..33ee5f7 100644 --- a/helper_bot/services/scoring/__init__.py +++ b/helper_bot/services/scoring/__init__.py @@ -9,9 +9,14 @@ from .base import CombinedScore, ScoringResult, ScoringServiceProtocol from .deepseek_service import DeepSeekService -from .exceptions import (DeepSeekAPIError, InsufficientExamplesError, - ModelNotLoadedError, ScoringError, TextTooShortError, - VectorStoreError) +from .exceptions import ( + DeepSeekAPIError, + InsufficientExamplesError, + ModelNotLoadedError, + ScoringError, + TextTooShortError, + VectorStoreError, +) from .rag_client import RagApiClient from .scoring_manager import ScoringManager diff --git a/helper_bot/services/scoring/base.py b/helper_bot/services/scoring/base.py index 0848468..ba7e464 100644 --- a/helper_bot/services/scoring/base.py +++ b/helper_bot/services/scoring/base.py @@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Protocol class ScoringResult: """ Результат оценки поста от одного сервиса. - + Attributes: score: Оценка от 0.0 до 1.0 (вероятность публикации) source: Источник оценки ("deepseek", "rag", etc.) @@ -20,18 +20,21 @@ class ScoringResult: timestamp: Время получения оценки metadata: Дополнительные данные """ + score: float source: str model: str confidence: Optional[float] = None timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp())) metadata: Dict[str, Any] = field(default_factory=dict) - + def __post_init__(self): """Валидация score в диапазоне [0.0, 1.0].""" if not 0.0 <= self.score <= 1.0: - raise ValueError(f"Score должен быть в диапазоне [0.0, 1.0], получено: {self.score}") - + raise ValueError( + f"Score должен быть в диапазоне [0.0, 1.0], получено: {self.score}" + ) + def to_dict(self) -> Dict[str, Any]: """Преобразует результат в словарь для сохранения в JSON.""" result = { @@ -44,7 +47,7 @@ class ScoringResult: if self.metadata: result["metadata"] = self.metadata return result - + @classmethod def from_dict(cls, source: str, data: Dict[str, Any]) -> "ScoringResult": """Создает ScoringResult из словаря.""" @@ -62,30 +65,31 @@ class ScoringResult: class CombinedScore: """ Объединенный результат от всех сервисов скоринга. - + Attributes: deepseek: Результат от DeepSeek API (None если отключен/ошибка) rag: Результат от RAG сервиса (None если отключен/ошибка) errors: Словарь с ошибками по источникам """ + deepseek: Optional[ScoringResult] = None rag: Optional[ScoringResult] = None errors: Dict[str, str] = field(default_factory=dict) - + @property def deepseek_score(self) -> Optional[float]: """Возвращает только числовой скор от DeepSeek.""" return self.deepseek.score if self.deepseek else None - + @property def rag_score(self) -> Optional[float]: """Возвращает только числовой скор от RAG.""" return self.rag.score if self.rag else None - + def to_json_dict(self) -> Dict[str, Any]: """ Преобразует в словарь для сохранения в ml_scores колонку. - + Формат: { "deepseek": {"score": 0.75, "model": "...", "ts": ...}, @@ -98,7 +102,7 @@ class CombinedScore: if self.rag: result["rag"] = self.rag.to_dict() return result - + def has_any_score(self) -> bool: """Проверяет, есть ли хотя бы один успешный скор.""" return self.deepseek is not None or self.rag is not None @@ -107,48 +111,48 @@ class CombinedScore: class ScoringServiceProtocol(Protocol): """ Протокол для сервисов скоринга. - + Любой сервис скоринга должен реализовывать эти методы. """ - + @property def source_name(self) -> str: """Возвращает имя источника ("deepseek", "rag", etc.).""" ... - + @property def is_enabled(self) -> bool: """Проверяет, включен ли сервис.""" ... - + async def calculate_score(self, text: str) -> ScoringResult: """ Рассчитывает скор для текста поста. - + Args: text: Текст поста для оценки - + Returns: ScoringResult с оценкой - + Raises: ScoringError: При ошибке расчета """ ... - + async def add_positive_example(self, text: str) -> None: """ Добавляет текст как положительный пример (опубликованный пост). - + Args: text: Текст опубликованного поста """ ... - + async def add_negative_example(self, text: str) -> None: """ Добавляет текст как отрицательный пример (отклоненный пост). - + Args: text: Текст отклоненного поста """ diff --git a/helper_bot/services/scoring/deepseek_service.py b/helper_bot/services/scoring/deepseek_service.py index 3bd9ecf..4f9cc23 100644 --- a/helper_bot/services/scoring/deepseek_service.py +++ b/helper_bot/services/scoring/deepseek_service.py @@ -9,6 +9,7 @@ import json from typing import List, Optional import httpx + from helper_bot.utils.metrics import track_errors, track_time from logs.custom_logger import logger @@ -19,17 +20,17 @@ from .exceptions import DeepSeekAPIError, ScoringError, TextTooShortError class DeepSeekService: """ Сервис для оценки постов через DeepSeek API. - + Отправляет текст поста в DeepSeek с промптом для оценки и получает числовой скор релевантности. - + Attributes: api_key: API ключ DeepSeek api_url: URL API эндпоинта model: Название модели timeout: Таймаут запроса в секундах """ - + # Промпт для оценки поста SCORING_PROMPT = """Роль: Ты — строгий и внимательный модератор сообщества в социальной сети, ориентированного на знакомства между людьми. Твоя задача — оценить, можно ли опубликовать пост, основываясь на четких правилах. @@ -77,7 +78,7 @@ class DeepSeekService: DEFAULT_API_URL = "https://api.deepseek.com/v1/chat/completions" DEFAULT_MODEL = "deepseek-chat" - + def __init__( self, api_key: Optional[str] = None, @@ -90,7 +91,7 @@ class DeepSeekService: ): """ Инициализация DeepSeek сервиса. - + Args: api_key: API ключ DeepSeek api_url: URL API эндпоинта @@ -107,29 +108,29 @@ class DeepSeekService: self._enabled = enabled and bool(api_key) self.min_text_length = min_text_length self.max_retries = max_retries - + # HTTP клиент (создается лениво) self._client: Optional[httpx.AsyncClient] = None - + if not api_key and enabled: logger.warning("DeepSeekService: API ключ не указан, сервис отключен") self._enabled = False - + logger.info( f"DeepSeekService инициализирован " f"(model={self.model}, enabled={self._enabled})" ) - + @property def source_name(self) -> str: """Имя источника для результатов.""" return "deepseek" - + @property def is_enabled(self) -> bool: """Проверяет, включен ли сервис.""" return self._enabled - + async def _get_client(self) -> httpx.AsyncClient: """Получает или создает HTTP клиент.""" if self._client is None: @@ -141,101 +142,106 @@ class DeepSeekService: }, ) return self._client - + async def close(self) -> None: """Закрывает HTTP клиент.""" if self._client: await self._client.aclose() self._client = None - + def _clean_text(self, text: str) -> str: """Очищает текст от лишних символов.""" if not text: return "" - + # Удаляем лишние пробелы и переносы строк clean = " ".join(text.split()) - + # Удаляем служебные символы if clean == "^": return "" - + return clean.strip() - + def _parse_score_response(self, response_text: str) -> float: """ Парсит ответ от DeepSeek и извлекает скор. - + Args: response_text: Текст ответа от API - + Returns: Числовой скор от 0.0 до 1.0 - + Raises: DeepSeekAPIError: Если не удалось распарсить ответ """ try: # Пытаемся найти число в ответе text = response_text.strip() - + # Убираем возможные обрамления - text = text.strip('"\'`') - + text = text.strip("\"'`") + # Пробуем распарсить как число score = float(text) - + # Ограничиваем диапазон score = max(0.0, min(1.0, score)) - + return score - + except ValueError: # Пробуем найти число в тексте import re - matches = re.findall(r'0\.\d+|1\.0|0|1', text) + + matches = re.findall(r"0\.\d+|1\.0|0|1", text) if matches: score = float(matches[0]) return max(0.0, min(1.0, score)) - - logger.error(f"DeepSeekService: Не удалось распарсить ответ: {response_text}") - raise DeepSeekAPIError(f"Не удалось распарсить скор из ответа: {response_text}") - + + logger.error( + f"DeepSeekService: Не удалось распарсить ответ: {response_text}" + ) + raise DeepSeekAPIError( + f"Не удалось распарсить скор из ответа: {response_text}" + ) + @track_time("calculate_score", "deepseek_service") @track_errors("deepseek_service", "calculate_score") async def calculate_score(self, text: str) -> ScoringResult: """ Рассчитывает скор для текста поста через DeepSeek API. - + Args: text: Текст поста для оценки - + Returns: ScoringResult с оценкой - + Raises: ScoringError: При ошибке расчета """ if not self._enabled: raise ScoringError("DeepSeek сервис отключен") - + # Очищаем текст clean_text = self._clean_text(text) - + if len(clean_text) < self.min_text_length: raise TextTooShortError( f"Текст слишком короткий (минимум {self.min_text_length} символов)" ) - + # Формируем промпт prompt = self.SCORING_PROMPT.format(text=clean_text) - + # Выполняем запрос с повторными попытками last_error = None for attempt in range(self.max_retries): try: score = await self._make_api_request(prompt) - + return ScoringResult( score=score, source=self.source_name, @@ -245,7 +251,7 @@ class DeepSeekService: "attempt": attempt + 1, }, ) - + except DeepSeekAPIError as e: last_error = e logger.warning( @@ -254,25 +260,27 @@ class DeepSeekService: ) if attempt < self.max_retries - 1: # Экспоненциальная задержка - await asyncio.sleep(2 ** attempt) - - raise ScoringError(f"Все попытки запроса к DeepSeek API не удались: {last_error}") - + await asyncio.sleep(2**attempt) + + raise ScoringError( + f"Все попытки запроса к DeepSeek API не удались: {last_error}" + ) + async def _make_api_request(self, prompt: str) -> float: """ Выполняет запрос к DeepSeek API. - + Args: prompt: Промпт для отправки - + Returns: Числовой скор от 0.0 до 1.0 - + Raises: DeepSeekAPIError: При ошибке API """ client = await self._get_client() - + payload = { "model": self.model, "messages": [ @@ -282,27 +290,27 @@ class DeepSeekService: } ], "temperature": 0.1, # Низкая температура для детерминированности - "max_tokens": 10, # Ожидаем только число + "max_tokens": 10, # Ожидаем только число } - + try: response = await client.post(self.api_url, json=payload) response.raise_for_status() - + data = response.json() - + # Извлекаем ответ if "choices" not in data or not data["choices"]: raise DeepSeekAPIError("Пустой ответ от API") - + response_text = data["choices"][0]["message"]["content"] - + # Парсим скор score = self._parse_score_response(response_text) - + logger.debug(f"DeepSeekService: Получен скор {score} для текста") return score - + except httpx.HTTPStatusError as e: error_msg = f"HTTP ошибка {e.response.status_code}" try: @@ -312,40 +320,40 @@ class DeepSeekService: except Exception: pass raise DeepSeekAPIError(error_msg) - + except httpx.TimeoutException: raise DeepSeekAPIError(f"Таймаут запроса ({self.timeout}s)") - + except Exception as e: raise DeepSeekAPIError(f"Ошибка запроса: {e}") - + async def add_positive_example(self, text: str) -> None: """ Добавляет текст как положительный пример. - + Для DeepSeek не требуется хранить примеры - оценка выполняется на основе промпта. Метод существует для совместимости с протоколом. - + Args: text: Текст опубликованного поста """ # DeepSeek не использует примеры для обучения # Промпт уже содержит критерии оценки pass - + async def add_negative_example(self, text: str) -> None: """ Добавляет текст как отрицательный пример. - + Для DeepSeek не требуется хранить примеры - оценка выполняется на основе промпта. Метод существует для совместимости с протоколом. - + Args: text: Текст отклоненного поста """ # DeepSeek не использует примеры для обучения pass - + def get_stats(self) -> dict: """Возвращает статистику сервиса.""" return { diff --git a/helper_bot/services/scoring/exceptions.py b/helper_bot/services/scoring/exceptions.py index 8af309c..eb6219c 100644 --- a/helper_bot/services/scoring/exceptions.py +++ b/helper_bot/services/scoring/exceptions.py @@ -5,29 +5,35 @@ class ScoringError(Exception): """Базовое исключение для ошибок скоринга.""" + pass class ModelNotLoadedError(ScoringError): """Модель не загружена или недоступна.""" + pass class VectorStoreError(ScoringError): """Ошибка при работе с хранилищем векторов.""" + pass class DeepSeekAPIError(ScoringError): """Ошибка при обращении к DeepSeek API.""" + pass class InsufficientExamplesError(ScoringError): """Недостаточно примеров для расчета скора.""" + pass class TextTooShortError(ScoringError): """Текст слишком короткий для векторизации.""" + pass diff --git a/helper_bot/services/scoring/rag_client.py b/helper_bot/services/scoring/rag_client.py index 45f5f3a..fc2f362 100644 --- a/helper_bot/services/scoring/rag_client.py +++ b/helper_bot/services/scoring/rag_client.py @@ -7,24 +7,24 @@ HTTP клиент для взаимодействия с внешним RAG се from typing import Any, Dict, Optional import httpx + from helper_bot.utils.metrics import track_errors, track_time from logs.custom_logger import logger from .base import ScoringResult -from .exceptions import (InsufficientExamplesError, ScoringError, - TextTooShortError) +from .exceptions import InsufficientExamplesError, ScoringError, TextTooShortError class RagApiClient: """ HTTP клиент для взаимодействия с внешним RAG сервисом. - + Использует REST API для: - Получения скоров постов (POST /api/v1/score) - Отправки положительных примеров (POST /api/v1/examples/positive) - Отправки отрицательных примеров (POST /api/v1/examples/negative) - Получения статистики (GET /api/v1/stats) - + Attributes: api_url: Базовый URL API сервиса api_key: API ключ для аутентификации @@ -32,7 +32,7 @@ class RagApiClient: test_mode: Включен ли тестовый режим (добавляет заголовок X-Test-Mode: true) enabled: Включен ли клиент """ - + def __init__( self, api_url: str, @@ -43,7 +43,7 @@ class RagApiClient: ): """ Инициализация клиента. - + Args: api_url: Базовый URL API (например, http://хх.ххх.ххх.хх/api/v1) api_key: API ключ для аутентификации @@ -52,49 +52,51 @@ class RagApiClient: enabled: Включен ли клиент """ # Убираем trailing slash если есть - self.api_url = api_url.rstrip('/') + self.api_url = api_url.rstrip("/") self.api_key = api_key self.timeout = timeout self.test_mode = test_mode self._enabled = enabled - + # Создаем HTTP клиент self._client = httpx.AsyncClient( timeout=httpx.Timeout(timeout), headers={ "X-API-Key": api_key, "Content-Type": "application/json", - } + }, ) - - logger.info(f"RagApiClient инициализирован (url={self.api_url}, enabled={enabled})") - + + logger.info( + f"RagApiClient инициализирован (url={self.api_url}, enabled={enabled})" + ) + @property def source_name(self) -> str: """Имя источника для результатов.""" return "rag" - + @property def is_enabled(self) -> bool: """Проверяет, включен ли клиент.""" return self._enabled - + async def close(self) -> None: """Закрывает HTTP клиент.""" await self._client.aclose() - + @track_time("calculate_score", "rag_client") @track_errors("rag_client", "calculate_score") async def calculate_score(self, text: str) -> ScoringResult: """ Рассчитывает скор для текста поста через API. - + Args: text: Текст поста для оценки - + Returns: ScoringResult с оценкой - + Raises: ScoringError: При ошибке расчета InsufficientExamplesError: Если недостаточно примеров @@ -102,16 +104,15 @@ class RagApiClient: """ if not self._enabled: raise ScoringError("RAG API клиент отключен") - + if not text or not text.strip(): raise TextTooShortError("Текст пустой") - + try: response = await self._client.post( - f"{self.api_url}/score", - json={"text": text.strip()} + f"{self.api_url}/score", json={"text": text.strip()} ) - + # Обрабатываем различные статусы if response.status_code == 400: try: @@ -119,43 +120,52 @@ class RagApiClient: error_msg = error_data.get("detail", "Неизвестная ошибка") except Exception: error_msg = response.text or "Неизвестная ошибка" - + logger.warning(f"RagApiClient: Ошибка валидации запроса: {error_msg}") - - if "недостаточно" in error_msg.lower() or "insufficient" in error_msg.lower(): + + if ( + "недостаточно" in error_msg.lower() + or "insufficient" in error_msg.lower() + ): raise InsufficientExamplesError(error_msg) if "коротк" in error_msg.lower() or "short" in error_msg.lower(): raise TextTooShortError(error_msg) raise ScoringError(f"Ошибка валидации: {error_msg}") - + if response.status_code == 401: logger.error("RagApiClient: Ошибка аутентификации: неверный API ключ") raise ScoringError("Ошибка аутентификации: неверный API ключ") - + if response.status_code == 404: logger.error("RagApiClient: RAG API endpoint не найден") raise ScoringError("RAG API endpoint не найден") - + if response.status_code >= 500: - logger.error(f"RagApiClient: Ошибка сервера RAG API: {response.status_code}") + logger.error( + f"RagApiClient: Ошибка сервера RAG API: {response.status_code}" + ) raise ScoringError(f"Ошибка сервера RAG API: {response.status_code}") - + # Проверяем успешный статус if response.status_code != 200: response.raise_for_status() - + data = response.json() - + # Парсим ответ score = float(data.get("rag_score", 0.0)) - confidence = float(data.get("rag_confidence", 0.0)) if data.get("rag_confidence") is not None else None + confidence = ( + float(data.get("rag_confidence", 0.0)) + if data.get("rag_confidence") is not None + else None + ) rag_score_pos_only_raw = data.get("rag_score_pos_only") rag_score_pos_only = float(rag_score_pos_only_raw) if rag_score_pos_only_raw is not None else None - + # Форматируем confidence для логирования confidence_str = f"{confidence:.4f}" if confidence is not None else "None" rag_score_pos_only_str = f"{rag_score_pos_only:.4f}" if rag_score_pos_only is not None else "None" - + logger.info( f"RagApiClient: Скор успешно получен из API - " f"rag_score={score:.4f} (type: {type(score).__name__}), " @@ -164,19 +174,23 @@ class RagApiClient: f"raw_response_rag_score={data.get('rag_score')}, " f"raw_response_rag_score_pos_only={rag_score_pos_only_raw}" ) - + return ScoringResult( score=score, source=self.source_name, model=data.get("meta", {}).get("model", "rag-service"), confidence=confidence, metadata={ - "rag_score_pos_only": rag_score_pos_only, + "rag_score_pos_only": ( + float(data.get("rag_score_pos_only", 0.0)) + if data.get("rag_score_pos_only") is not None + else None + ), "positive_examples": data.get("meta", {}).get("positive_examples"), "negative_examples": data.get("meta", {}).get("negative_examples"), - } + }, ) - + except httpx.TimeoutException: logger.error(f"RagApiClient: Таймаут запроса к RAG API (>{self.timeout}с)") raise ScoringError(f"Таймаут запроса к RAG API (>{self.timeout}с)") @@ -184,7 +198,9 @@ class RagApiClient: logger.error(f"RagApiClient: Ошибка подключения к RAG API: {e}") raise ScoringError(f"Ошибка подключения к RAG API: {e}") except (KeyError, ValueError, TypeError) as e: - logger.error(f"RagApiClient: Ошибка парсинга ответа: {e}, response: {response.text if 'response' in locals() else 'N/A'}") + logger.error( + f"RagApiClient: Ошибка парсинга ответа: {e}, response: {response.text if 'response' in locals() else 'N/A'}" + ) raise ScoringError(f"Ошибка парсинга ответа от RAG API: {e}") except InsufficientExamplesError: raise @@ -195,122 +211,145 @@ class RagApiClient: raise except Exception as e: # Только действительно неожиданные ошибки логируем здесь - logger.error(f"RagApiClient: Неожиданная ошибка при расчете скора: {e}", exc_info=True) + logger.error( + f"RagApiClient: Неожиданная ошибка при расчете скора: {e}", + exc_info=True, + ) raise ScoringError(f"Неожиданная ошибка: {e}") - + @track_time("add_positive_example", "rag_client") async def add_positive_example(self, text: str) -> None: """ Добавляет текст как положительный пример (опубликованный пост). - + Args: text: Текст опубликованного поста """ if not self._enabled: return - + if not text or not text.strip(): return - + try: # Формируем заголовки (добавляем X-Test-Mode если включен тестовый режим) headers = {} if self.test_mode: headers["X-Test-Mode"] = "true" - + response = await self._client.post( f"{self.api_url}/examples/positive", json={"text": text.strip()}, - headers=headers + headers=headers, ) - + if response.status_code == 200 or response.status_code == 201: logger.info("RagApiClient: Положительный пример успешно добавлен") elif response.status_code == 400: - logger.warning(f"RagApiClient: Ошибка валидации при добавлении положительного примера: {response.text}") + logger.warning( + f"RagApiClient: Ошибка валидации при добавлении положительного примера: {response.text}" + ) else: - logger.warning(f"RagApiClient: Неожиданный статус при добавлении положительного примера: {response.status_code}") - + logger.warning( + f"RagApiClient: Неожиданный статус при добавлении положительного примера: {response.status_code}" + ) + except httpx.TimeoutException: - logger.warning(f"RagApiClient: Таймаут при добавлении положительного примера") + logger.warning( + f"RagApiClient: Таймаут при добавлении положительного примера" + ) except httpx.RequestError as e: - logger.warning(f"RagApiClient: Ошибка подключения при добавлении положительного примера: {e}") + logger.warning( + f"RagApiClient: Ошибка подключения при добавлении положительного примера: {e}" + ) except Exception as e: logger.error(f"RagApiClient: Ошибка добавления положительного примера: {e}") - + @track_time("add_negative_example", "rag_client") async def add_negative_example(self, text: str) -> None: """ Добавляет текст как отрицательный пример (отклоненный пост). - + Args: text: Текст отклоненного поста """ if not self._enabled: return - + if not text or not text.strip(): return - + try: # Формируем заголовки (добавляем X-Test-Mode если включен тестовый режим) headers = {} if self.test_mode: headers["X-Test-Mode"] = "true" - + response = await self._client.post( f"{self.api_url}/examples/negative", json={"text": text.strip()}, - headers=headers + headers=headers, ) - + if response.status_code == 200 or response.status_code == 201: logger.info("RagApiClient: Отрицательный пример успешно добавлен") elif response.status_code == 400: - logger.warning(f"RagApiClient: Ошибка валидации при добавлении отрицательного примера: {response.text}") + logger.warning( + f"RagApiClient: Ошибка валидации при добавлении отрицательного примера: {response.text}" + ) else: - logger.warning(f"RagApiClient: Неожиданный статус при добавлении отрицательного примера: {response.status_code}") - + logger.warning( + f"RagApiClient: Неожиданный статус при добавлении отрицательного примера: {response.status_code}" + ) + except httpx.TimeoutException: - logger.warning(f"RagApiClient: Таймаут при добавлении отрицательного примера") + logger.warning( + f"RagApiClient: Таймаут при добавлении отрицательного примера" + ) except httpx.RequestError as e: - logger.warning(f"RagApiClient: Ошибка подключения при добавлении отрицательного примера: {e}") + logger.warning( + f"RagApiClient: Ошибка подключения при добавлении отрицательного примера: {e}" + ) except Exception as e: logger.error(f"RagApiClient: Ошибка добавления отрицательного примера: {e}") - + async def get_stats(self) -> Dict[str, Any]: """ Получает статистику от RAG API через endpoint /stats. - + Returns: Словарь со статистикой или пустой словарь при ошибке """ if not self._enabled: return {} - + try: response = await self._client.get(f"{self.api_url}/stats") - + if response.status_code == 200: return response.json() else: - logger.warning(f"RagApiClient: Неожиданный статус при получении статистики: {response.status_code}") + logger.warning( + f"RagApiClient: Неожиданный статус при получении статистики: {response.status_code}" + ) return {} - + except httpx.TimeoutException: logger.warning(f"RagApiClient: Таймаут при получении статистики") return {} except httpx.RequestError as e: - logger.warning(f"RagApiClient: Ошибка подключения при получении статистики: {e}") + logger.warning( + f"RagApiClient: Ошибка подключения при получении статистики: {e}" + ) return {} except Exception as e: logger.error(f"RagApiClient: Ошибка получения статистики: {e}") return {} - + def get_stats_sync(self) -> Dict[str, Any]: """ Синхронная версия get_stats для использования в get_stats() ScoringManager. - + Внимание: Это заглушка, реальная статистика будет получена асинхронно. """ return { diff --git a/helper_bot/services/scoring/scoring_manager.py b/helper_bot/services/scoring/scoring_manager.py index 6c03035..6761176 100644 --- a/helper_bot/services/scoring/scoring_manager.py +++ b/helper_bot/services/scoring/scoring_manager.py @@ -13,23 +13,22 @@ from logs.custom_logger import logger from .base import CombinedScore, ScoringResult from .deepseek_service import DeepSeekService -from .exceptions import (InsufficientExamplesError, ScoringError, - TextTooShortError) +from .exceptions import InsufficientExamplesError, ScoringError, TextTooShortError from .rag_client import RagApiClient class ScoringManager: """ Менеджер для управления всеми сервисами скоринга. - + Объединяет RagApiClient и DeepSeekService, выполняет параллельные запросы и агрегирует результаты в единый CombinedScore. - + Attributes: rag_client: HTTP клиент для RAG API deepseek_service: Сервис DeepSeek API """ - + def __init__( self, rag_client: Optional[RagApiClient] = None, @@ -37,68 +36,70 @@ class ScoringManager: ): """ Инициализация менеджера. - + Args: rag_client: HTTP клиент для RAG API (создается автоматически если не передан) deepseek_service: Сервис DeepSeek (создается автоматически если не передан) """ self.rag_client = rag_client self.deepseek_service = deepseek_service - + logger.info( f"ScoringManager инициализирован " f"(rag={rag_client is not None and rag_client.is_enabled}, " f"deepseek={deepseek_service is not None and deepseek_service.is_enabled})" ) - + @property def is_any_enabled(self) -> bool: """Проверяет, включен ли хотя бы один сервис.""" rag_enabled = self.rag_client is not None and self.rag_client.is_enabled - deepseek_enabled = self.deepseek_service is not None and self.deepseek_service.is_enabled + deepseek_enabled = ( + self.deepseek_service is not None and self.deepseek_service.is_enabled + ) return rag_enabled or deepseek_enabled - + @track_time("score_post", "scoring_manager") @track_errors("scoring_manager", "score_post") async def score_post(self, text: str) -> CombinedScore: """ Рассчитывает скоры для текста поста от всех сервисов. - + Выполняет запросы параллельно для минимизации задержки. - + Args: text: Текст поста для оценки - + Returns: CombinedScore с результатами от всех сервисов """ result = CombinedScore() - + if not text or not text.strip(): logger.debug("ScoringManager: Пустой текст, пропускаем скоринг") return result - + # Собираем задачи для параллельного выполнения tasks = [] task_names = [] - + # RAG API клиент if self.rag_client and self.rag_client.is_enabled: tasks.append(self._get_rag_score(text)) task_names.append("rag") - + # DeepSeek сервис if self.deepseek_service and self.deepseek_service.is_enabled: tasks.append(self._get_deepseek_score(text)) task_names.append("deepseek") - + if not tasks: logger.debug("ScoringManager: Нет активных сервисов для скоринга") return result - + # Выполняем параллельно results = await asyncio.gather(*tasks, return_exceptions=True) - + # Обрабатываем результаты for name, res in zip(task_names, results): if isinstance(res, Exception): @@ -111,14 +112,14 @@ class ScoringManager: result.rag = res elif name == "deepseek": result.deepseek = res - + logger.info( f"ScoringManager: Скоринг завершен " f"(rag={result.rag_score}, deepseek={result.deepseek_score})" ) - + return result - + async def _get_rag_score(self, text: str) -> Optional[ScoringResult]: """Получает скор от RAG API.""" try: @@ -134,7 +135,7 @@ class ScoringManager: except Exception as e: # Ошибки уже залогированы в RagApiClient, здесь только пробрасываем raise - + async def _get_deepseek_score(self, text: str) -> Optional[ScoringResult]: """Получает скор от DeepSeek сервиса.""" try: @@ -146,78 +147,77 @@ class ScoringManager: except Exception as e: # Ошибки уже залогированы в DeepSeekService, здесь только пробрасываем raise - + @track_time("on_post_published", "scoring_manager") async def on_post_published(self, text: str) -> None: """ Вызывается при публикации поста. - + Добавляет текст как положительный пример для обучения RAG. - + Args: text: Текст опубликованного поста """ if not text or not text.strip(): return - + tasks = [] - + if self.rag_client and self.rag_client.is_enabled: tasks.append(self.rag_client.add_positive_example(text)) - + if self.deepseek_service and self.deepseek_service.is_enabled: tasks.append(self.deepseek_service.add_positive_example(text)) - + if tasks: await asyncio.gather(*tasks, return_exceptions=True) logger.info("ScoringManager: Добавлен положительный пример") - + @track_time("on_post_declined", "scoring_manager") async def on_post_declined(self, text: str) -> None: """ Вызывается при отклонении поста. - + Добавляет текст как отрицательный пример для обучения RAG. - + Args: text: Текст отклоненного поста """ if not text or not text.strip(): return - + tasks = [] - + if self.rag_client and self.rag_client.is_enabled: tasks.append(self.rag_client.add_negative_example(text)) - + if self.deepseek_service and self.deepseek_service.is_enabled: tasks.append(self.deepseek_service.add_negative_example(text)) - + if tasks: await asyncio.gather(*tasks, return_exceptions=True) logger.info("ScoringManager: Добавлен отрицательный пример") - - + async def close(self) -> None: """Закрывает ресурсы всех сервисов.""" if self.deepseek_service: await self.deepseek_service.close() - + if self.rag_client: await self.rag_client.close() - + async def get_stats(self) -> dict: """Возвращает статистику всех сервисов.""" stats = { "any_enabled": self.is_any_enabled, } - + if self.rag_client: # Получаем статистику асинхронно от API rag_stats = await self.rag_client.get_stats() stats["rag"] = rag_stats if rag_stats else self.rag_client.get_stats_sync() - + if self.deepseek_service: stats["deepseek"] = self.deepseek_service.get_stats() - + return stats diff --git a/helper_bot/utils/auto_unban_scheduler.py b/helper_bot/utils/auto_unban_scheduler.py index 25be88d..91550b5 100644 --- a/helper_bot/utils/auto_unban_scheduler.py +++ b/helper_bot/utils/auto_unban_scheduler.py @@ -4,6 +4,7 @@ from typing import Optional from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger + from helper_bot.utils.base_dependency_factory import get_global_instance from logs.custom_logger import logger @@ -15,17 +16,17 @@ class AutoUnbanScheduler: Класс для автоматического разбана пользователей по истечении срока блокировки. Запускается ежедневно в 5:00 по московскому времени. """ - + def __init__(self): self.bdf = get_global_instance() self.bot_db = self.bdf.get_db() self.scheduler = AsyncIOScheduler() self.bot = None # Будет установлен позже - + def set_bot(self, bot): """Устанавливает экземпляр бота для отправки уведомлений""" self.bot = bot - + @track_time("auto_unban_users", "auto_unban_scheduler") @track_errors("auto_unban_scheduler", "auto_unban_users") @db_query_time("auto_unban_users", "users", "mixed") @@ -37,26 +38,32 @@ class AutoUnbanScheduler: """ try: logger.info("Запуск автоматического разбана пользователей") - + # Получаем текущий UNIX timestamp current_timestamp = int(datetime.now().timestamp()) - - logger.info(f"Поиск пользователей для разблокировки на timestamp: {current_timestamp}") - + + logger.info( + f"Поиск пользователей для разблокировки на timestamp: {current_timestamp}" + ) + # Получаем список пользователей для разблокировки - users_to_unban = await self.bot_db.get_users_for_unblock_today(current_timestamp) - + users_to_unban = await self.bot_db.get_users_for_unblock_today( + current_timestamp + ) + if not users_to_unban: logger.info("Нет пользователей для разблокировки сегодня") return - - logger.info(f"Найдено {len(users_to_unban)} пользователей для разблокировки") - + + logger.info( + f"Найдено {len(users_to_unban)} пользователей для разблокировки" + ) + # Список для отслеживания результатов success_count = 0 failed_count = 0 failed_users = [] - + # Разблокируем каждого пользователя for user_id in users_to_unban: try: @@ -71,92 +78,99 @@ class AutoUnbanScheduler: except Exception as e: failed_count += 1 failed_users.append(f"{user_id}") - logger.error(f"Исключение при разблокировке пользователя {user_id}: {e}") - + logger.error( + f"Исключение при разблокировке пользователя {user_id}: {e}" + ) + # Формируем отчет - report = self._generate_report(success_count, failed_count, failed_users, users_to_unban) - + report = self._generate_report( + success_count, failed_count, failed_users, users_to_unban + ) + # Отправляем отчет в лог-канал await self._send_report(report) - - logger.info(f"Автоматический разбан завершен. Успешно: {success_count}, Ошибок: {failed_count}") - + + logger.info( + f"Автоматический разбан завершен. Успешно: {success_count}, Ошибок: {failed_count}" + ) + except Exception as e: error_msg = f"Критическая ошибка в автоматическом разбане: {e}" logger.error(error_msg) await self._send_error_report(error_msg) - - def _generate_report(self, success_count: int, failed_count: int, - failed_users: list, all_users: dict) -> str: + + def _generate_report( + self, success_count: int, failed_count: int, failed_users: list, all_users: dict + ) -> str: """Генерирует отчет о результатах автоматического разбана""" report = f"🤖 Отчет об автоматическом разбане\n\n" report += f"📅 Дата: {datetime.now().strftime('%d.%m.%Y %H:%M')}\n" report += f"✅ Успешно разблокировано: {success_count}\n" report += f"❌ Ошибок: {failed_count}\n\n" - + if success_count > 0: report += "✅ Разблокированные пользователи:\n" for user_id in all_users: if str(user_id) not in failed_users: report += f"• ID: {user_id}\n" report += "\n" - + if failed_users: report += "❌ Ошибки при разблокировке:\n" for user in failed_users: report += f"• {user}\n" - + return report - + @track_time("send_report", "auto_unban_scheduler") @track_errors("auto_unban_scheduler", "send_report") async def _send_report(self, report: str): """Отправляет отчет в лог-канал""" try: if self.bot: - group_for_logs = self.bdf.settings['Telegram']['group_for_logs'] + group_for_logs = self.bdf.settings["Telegram"]["group_for_logs"] await self.bot.send_message( - chat_id=group_for_logs, - text=report, - parse_mode='HTML' + chat_id=group_for_logs, text=report, parse_mode="HTML" ) except Exception as e: logger.error(f"Ошибка при отправке отчета: {e}") - + @track_time("send_error_report", "auto_unban_scheduler") @track_errors("auto_unban_scheduler", "send_error_report") async def _send_error_report(self, error_msg: str): """Отправляет отчет об ошибке в важный лог-канал""" try: if self.bot: - important_logs = self.bdf.settings['Telegram']['important_logs'] + important_logs = self.bdf.settings["Telegram"]["important_logs"] await self.bot.send_message( chat_id=important_logs, text=f"🚨 Ошибка автоматического разбана\n\n{error_msg}", - parse_mode='HTML' + parse_mode="HTML", ) except Exception as e: logger.error(f"Ошибка при отправке отчета об ошибке: {e}") - + def start_scheduler(self): """Запускает планировщик задач""" try: # Добавляем задачу на ежедневное выполнение в 5:00 по Москве self.scheduler.add_job( self.auto_unban_users, - CronTrigger(hour=5, minute=0, timezone='Europe/Moscow'), - id='auto_unban_users', - name='Автоматический разбан пользователей', - replace_existing=True + CronTrigger(hour=5, minute=0, timezone="Europe/Moscow"), + id="auto_unban_users", + name="Автоматический разбан пользователей", + replace_existing=True, ) - + # Запускаем планировщик self.scheduler.start() - logger.info("Планировщик автоматического разбана запущен. Задача запланирована на 5:00 по Москве") - + logger.info( + "Планировщик автоматического разбана запущен. Задача запланирована на 5:00 по Москве" + ) + except Exception as e: logger.error(f"Ошибка при запуске планировщика: {e}") - + def stop_scheduler(self): """Останавливает планировщик задач""" try: @@ -165,7 +179,7 @@ class AutoUnbanScheduler: logger.info("Планировщик автоматического разбана остановлен") except Exception as e: logger.error(f"Ошибка при остановке планировщика: {e}") - + async def run_manual_unban(self): """Запускает разбан вручную (для тестирования)""" logger.info("Запуск ручного разбана пользователей") diff --git a/helper_bot/utils/base_dependency_factory.py b/helper_bot/utils/base_dependency_factory.py index f50c079..e3e1971 100644 --- a/helper_bot/utils/base_dependency_factory.py +++ b/helper_bot/utils/base_dependency_factory.py @@ -2,23 +2,26 @@ import os import sys from typing import Optional -from database.async_db import AsyncBotDB from dotenv import load_dotenv + +from database.async_db import AsyncBotDB from helper_bot.utils.s3_storage import S3StorageService from logs.custom_logger import logger class BaseDependencyFactory: def __init__(self): - project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - env_path = os.path.join(project_dir, '.env') + project_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + env_path = os.path.join(project_dir, ".env") if os.path.exists(env_path): load_dotenv(env_path) self.settings = {} self._project_dir = project_dir - database_path = os.getenv('DATABASE_PATH', 'database/tg-bot-database.db') + database_path = os.getenv("DATABASE_PATH", "database/tg-bot-database.db") if not os.path.isabs(database_path): database_path = os.path.join(project_dir, database_path) @@ -26,78 +29,87 @@ class BaseDependencyFactory: self._load_settings_from_env() self._init_s3_storage() - + # ScoringManager инициализируется лениво self._scoring_manager = None def _load_settings_from_env(self): """Загружает настройки из переменных окружения.""" - self.settings['Telegram'] = { - 'bot_token': os.getenv('BOT_TOKEN', ''), - 'listen_bot_token': os.getenv('LISTEN_BOT_TOKEN', ''), - 'test_bot_token': os.getenv('TEST_BOT_TOKEN', ''), - 'preview_link': self._parse_bool(os.getenv('PREVIEW_LINK', 'false')), - 'main_public': os.getenv('MAIN_PUBLIC', ''), - 'group_for_posts': self._parse_int(os.getenv('GROUP_FOR_POSTS', '0')), - 'group_for_message': self._parse_int(os.getenv('GROUP_FOR_MESSAGE', '0')), - 'group_for_logs': self._parse_int(os.getenv('GROUP_FOR_LOGS', '0')), - 'important_logs': self._parse_int(os.getenv('IMPORTANT_LOGS', '0')), - 'archive': self._parse_int(os.getenv('ARCHIVE', '0')), - 'test_group': self._parse_int(os.getenv('TEST_GROUP', '0')) + self.settings["Telegram"] = { + "bot_token": os.getenv("BOT_TOKEN", ""), + "listen_bot_token": os.getenv("LISTEN_BOT_TOKEN", ""), + "test_bot_token": os.getenv("TEST_BOT_TOKEN", ""), + "preview_link": self._parse_bool(os.getenv("PREVIEW_LINK", "false")), + "main_public": os.getenv("MAIN_PUBLIC", ""), + "group_for_posts": self._parse_int(os.getenv("GROUP_FOR_POSTS", "0")), + "group_for_message": self._parse_int(os.getenv("GROUP_FOR_MESSAGE", "0")), + "group_for_logs": self._parse_int(os.getenv("GROUP_FOR_LOGS", "0")), + "important_logs": self._parse_int(os.getenv("IMPORTANT_LOGS", "0")), + "archive": self._parse_int(os.getenv("ARCHIVE", "0")), + "test_group": self._parse_int(os.getenv("TEST_GROUP", "0")), } - self.settings['Settings'] = { - 'logs': self._parse_bool(os.getenv('LOGS', 'false')), - 'test': self._parse_bool(os.getenv('TEST', 'false')) + self.settings["Settings"] = { + "logs": self._parse_bool(os.getenv("LOGS", "false")), + "test": self._parse_bool(os.getenv("TEST", "false")), } - self.settings['Metrics'] = { - 'host': os.getenv('METRICS_HOST', '0.0.0.0'), - 'port': self._parse_int(os.getenv('METRICS_PORT', '8080')) + self.settings["Metrics"] = { + "host": os.getenv("METRICS_HOST", "0.0.0.0"), + "port": self._parse_int(os.getenv("METRICS_PORT", "8080")), } - self.settings['S3'] = { - 'enabled': self._parse_bool(os.getenv('S3_ENABLED', 'false')), - 'endpoint_url': os.getenv('S3_ENDPOINT_URL', ''), - 'access_key': os.getenv('S3_ACCESS_KEY', ''), - 'secret_key': os.getenv('S3_SECRET_KEY', ''), - 'bucket_name': os.getenv('S3_BUCKET_NAME', ''), - 'region': os.getenv('S3_REGION', 'us-east-1') + self.settings["S3"] = { + "enabled": self._parse_bool(os.getenv("S3_ENABLED", "false")), + "endpoint_url": os.getenv("S3_ENDPOINT_URL", ""), + "access_key": os.getenv("S3_ACCESS_KEY", ""), + "secret_key": os.getenv("S3_SECRET_KEY", ""), + "bucket_name": os.getenv("S3_BUCKET_NAME", ""), + "region": os.getenv("S3_REGION", "us-east-1"), } - + # Настройки ML-скоринга - self.settings['Scoring'] = { + self.settings["Scoring"] = { # RAG API - 'rag_enabled': self._parse_bool(os.getenv('RAG_ENABLED', 'false')), - 'rag_api_url': os.getenv('RAG_API_URL', ''), - 'rag_api_key': os.getenv('RAG_API_KEY', ''), - 'rag_api_timeout': self._parse_int(os.getenv('RAG_API_TIMEOUT', '30')), - 'rag_test_mode': self._parse_bool(os.getenv('RAG_TEST_MODE', 'false')), + "rag_enabled": self._parse_bool(os.getenv("RAG_ENABLED", "false")), + "rag_api_url": os.getenv("RAG_API_URL", ""), + "rag_api_key": os.getenv("RAG_API_KEY", ""), + "rag_api_timeout": self._parse_int(os.getenv("RAG_API_TIMEOUT", "30")), + "rag_test_mode": self._parse_bool(os.getenv("RAG_TEST_MODE", "false")), # DeepSeek - 'deepseek_enabled': self._parse_bool(os.getenv('DEEPSEEK_ENABLED', 'false')), - 'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY', ''), - 'deepseek_api_url': os.getenv('DEEPSEEK_API_URL', 'https://api.deepseek.com/v1/chat/completions'), - 'deepseek_model': os.getenv('DEEPSEEK_MODEL', 'deepseek-chat'), - 'deepseek_timeout': self._parse_int(os.getenv('DEEPSEEK_TIMEOUT', '30')), + "deepseek_enabled": self._parse_bool( + os.getenv("DEEPSEEK_ENABLED", "false") + ), + "deepseek_api_key": os.getenv("DEEPSEEK_API_KEY", ""), + "deepseek_api_url": os.getenv( + "DEEPSEEK_API_URL", "https://api.deepseek.com/v1/chat/completions" + ), + "deepseek_model": os.getenv("DEEPSEEK_MODEL", "deepseek-chat"), + "deepseek_timeout": self._parse_int(os.getenv("DEEPSEEK_TIMEOUT", "30")), } - + def _init_s3_storage(self): """Инициализирует S3StorageService если S3 включен.""" self.s3_storage = None - if self.settings['S3']['enabled']: - s3_config = self.settings['S3'] - if s3_config['endpoint_url'] and s3_config['access_key'] and s3_config['secret_key'] and s3_config['bucket_name']: + if self.settings["S3"]["enabled"]: + s3_config = self.settings["S3"] + if ( + s3_config["endpoint_url"] + and s3_config["access_key"] + and s3_config["secret_key"] + and s3_config["bucket_name"] + ): self.s3_storage = S3StorageService( - endpoint_url=s3_config['endpoint_url'], - access_key=s3_config['access_key'], - secret_key=s3_config['secret_key'], - bucket_name=s3_config['bucket_name'], - region=s3_config['region'] + endpoint_url=s3_config["endpoint_url"], + access_key=s3_config["access_key"], + secret_key=s3_config["secret_key"], + bucket_name=s3_config["bucket_name"], + region=s3_config["region"], ) def _parse_bool(self, value: str) -> bool: """Парсит строковое значение в boolean.""" - return value.lower() in ('true', '1', 'yes', 'on') + return value.lower() in ("true", "1", "yes", "on") def _parse_int(self, value: str) -> int: """Парсит строковое значение в integer.""" @@ -105,7 +117,7 @@ class BaseDependencyFactory: return int(value) except (ValueError, TypeError): return 0 - + def _parse_float(self, value: str) -> float: """Парсит строковое значение в float.""" try: @@ -119,87 +131,95 @@ class BaseDependencyFactory: def get_db(self) -> AsyncBotDB: """Возвращает подключение к базе данных.""" return self.database - + def get_s3_storage(self) -> Optional[S3StorageService]: """Возвращает S3StorageService если S3 включен, иначе None.""" return self.s3_storage - + def _init_scoring_manager(self): """ Инициализирует ScoringManager с RAG API клиентом и DeepSeek сервисом. - + Вызывается лениво при первом обращении к get_scoring_manager(). """ - from helper_bot.services.scoring import (DeepSeekService, RagApiClient, - ScoringManager) - - scoring_config = self.settings['Scoring'] - + from helper_bot.services.scoring import ( + DeepSeekService, + RagApiClient, + ScoringManager, + ) + + scoring_config = self.settings["Scoring"] + # Инициализация RAG API клиента rag_client = None - if scoring_config['rag_enabled']: - api_url = scoring_config['rag_api_url'] - api_key = scoring_config['rag_api_key'] - + if scoring_config["rag_enabled"]: + api_url = scoring_config["rag_api_url"] + api_key = scoring_config["rag_api_key"] + if not api_url or not api_key: logger.warning("RAG включен, но не указаны RAG_API_URL или RAG_API_KEY") else: rag_client = RagApiClient( api_url=api_url, api_key=api_key, - timeout=scoring_config['rag_api_timeout'], - test_mode=scoring_config['rag_test_mode'], + timeout=scoring_config["rag_api_timeout"], + test_mode=scoring_config["rag_test_mode"], enabled=True, ) - logger.info(f"RagApiClient инициализирован: {api_url} (test_mode={scoring_config['rag_test_mode']})") - + logger.info( + f"RagApiClient инициализирован: {api_url} (test_mode={scoring_config['rag_test_mode']})" + ) + # Инициализация DeepSeek сервиса deepseek_service = None - if scoring_config['deepseek_enabled'] and scoring_config['deepseek_api_key']: + if scoring_config["deepseek_enabled"] and scoring_config["deepseek_api_key"]: deepseek_service = DeepSeekService( - api_key=scoring_config['deepseek_api_key'], - api_url=scoring_config['deepseek_api_url'], - model=scoring_config['deepseek_model'], - timeout=scoring_config['deepseek_timeout'], + api_key=scoring_config["deepseek_api_key"], + api_url=scoring_config["deepseek_api_url"], + model=scoring_config["deepseek_model"], + timeout=scoring_config["deepseek_timeout"], enabled=True, ) - logger.info(f"DeepSeekService инициализирован: {scoring_config['deepseek_model']}") - + logger.info( + f"DeepSeekService инициализирован: {scoring_config['deepseek_model']}" + ) + # Создаем менеджер self._scoring_manager = ScoringManager( rag_client=rag_client, deepseek_service=deepseek_service, ) - + return self._scoring_manager - + def get_scoring_manager(self): """ Возвращает ScoringManager для ML-скоринга постов. - + Инициализируется лениво при первом вызове. - + Returns: ScoringManager или None если скоринг полностью отключен """ if self._scoring_manager is None: - scoring_config = self.settings.get('Scoring', {}) - + scoring_config = self.settings.get("Scoring", {}) + # Проверяем, включен ли хотя бы один сервис - rag_enabled = scoring_config.get('rag_enabled', False) - deepseek_enabled = scoring_config.get('deepseek_enabled', False) - + rag_enabled = scoring_config.get("rag_enabled", False) + deepseek_enabled = scoring_config.get("deepseek_enabled", False) + if not rag_enabled and not deepseek_enabled: logger.info("Scoring полностью отключен (RAG и DeepSeek disabled)") return None - + self._init_scoring_manager() - + return self._scoring_manager _global_instance = None + def get_global_instance(): """Возвращает глобальный экземпляр BaseDependencyFactory.""" global _global_instance diff --git a/helper_bot/utils/helper_func.py b/helper_bot/utils/helper_func.py index 7d3e881..39aab76 100644 --- a/helper_bot/utils/helper_func.py +++ b/helper_bot/utils/helper_func.py @@ -10,45 +10,72 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union try: import emoji as _emoji_lib + _emoji_lib_available = True except ImportError: _emoji_lib = None _emoji_lib_available = False from aiogram import types -from aiogram.types import (FSInputFile, InputMediaAudio, InputMediaDocument, - InputMediaPhoto, InputMediaVideo) +from aiogram.types import ( + FSInputFile, + InputMediaAudio, + InputMediaDocument, + InputMediaPhoto, + InputMediaVideo, +) + from database.models import TelegramPost -from helper_bot.utils.base_dependency_factory import (BaseDependencyFactory, - get_global_instance) +from helper_bot.utils.base_dependency_factory import ( + BaseDependencyFactory, + get_global_instance, +) from logs.custom_logger import logger # Local imports - metrics -from .metrics import (db_query_time, track_errors, track_file_operations, - track_media_processing, track_time) +from .metrics import ( + db_query_time, + track_errors, + track_file_operations, + track_media_processing, + track_time, +) bdf = get_global_instance() -#TODO: поменять архитектуру и подключить правильный BotDB +# TODO: поменять архитектуру и подключить правильный BotDB BotDB = bdf.get_db() -GROUP_FOR_LOGS = bdf.settings['Telegram']['group_for_logs'] +GROUP_FOR_LOGS = bdf.settings["Telegram"]["group_for_logs"] if _emoji_lib_available and _emoji_lib is not None: emoji_list = list(_emoji_lib.EMOJI_DATA.keys()) else: # Fallback minimal emoji set for environments without the 'emoji' package (e.g., CI tests) emoji_list = [ - "🙂", "😀", "😉", "😎", "🤖", "🦄", "🐱", "🐶", "🍀", "🔥", - "🌟", "🎉", "💡", "🚀", "🌈" + "🙂", + "😀", + "😉", + "😎", + "🤖", + "🦄", + "🐱", + "🐶", + "🍀", + "🔥", + "🌟", + "🎉", + "💡", + "🚀", + "🌈", ] def safe_html_escape(text: str) -> str: """ Безопасно экранирует текст для использования в HTML разметке. - + Args: text: Текст для экранирования - + Returns: str: Экранированный текст """ @@ -62,10 +89,10 @@ def safe_html_escape(text: str) -> str: def get_first_name(message: types.Message) -> str: """ Безопасно получает и экранирует имя пользователя для использования в HTML разметке. - + Args: message: Сообщение от пользователя - + Returns: str: Экранированное имя пользователя или пустая строка если имя отсутствует """ @@ -76,8 +103,8 @@ def get_first_name(message: types.Message) -> str: # Дополнительная проверка на специальные символы, которые могут вызвать проблемы в HTML first_name = str(message.from_user.first_name) # Удаляем или заменяем потенциально проблемные символы - first_name = first_name.replace('\u0cc0', '') # Убираем символ "ೀ" (U+0CC0) - first_name = first_name.replace('\u0cc1', '') # Убираем символ "ೀ" (U+0CC1) + first_name = first_name.replace("\u0cc0", "") # Убираем символ "ೀ" (U+0CC0) + first_name = first_name.replace("\u0cc1", "") # Убираем символ "ೀ" (U+0CC1) first_name = html.escape(first_name) return first_name return "" @@ -96,25 +123,25 @@ def determine_anonymity(post_text: str) -> bool: """ if not post_text: return False - + post_text_lower = post_text.lower() - + # Сначала проверяем "неанон" или "не анон" (более специфичное условие) if "неанон" in post_text_lower or "не анон" in post_text_lower: return False - + # Проверяем "анон" if "анон" in post_text_lower: return True - + # По умолчанию False return False def get_text_message( - post_text: str, - first_name: str, - username: str = None, + post_text: str, + first_name: str, + username: str = None, is_anonymous: Optional[bool] = None, deepseek_score: Optional[float] = None, rag_score: Optional[float] = None, @@ -140,34 +167,38 @@ def get_text_message( """ # Экранируем post_text для безопасного использования в HTML safe_post_text = html.escape(str(post_text)) if post_text else "" - + # Экранируем username для безопасного использования в HTML safe_username = html.escape(username) if username else None - + # Формируем строку с информацией об авторе if safe_username: author_info = f"{first_name} @{safe_username}" else: author_info = f"{first_name} (Ник не указан)" - + # Формируем базовый текст # Если передан is_anonymous, используем его, иначе определяем по тексту (legacy) if is_anonymous is not None: if is_anonymous: - final_text = f'{safe_post_text}\n\nПост опубликован анонимно' + final_text = f"{safe_post_text}\n\nПост опубликован анонимно" else: - final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}' + final_text = f"{safe_post_text}\n\nАвтор поста: {author_info}" else: # Legacy: определяем по тексту if "неанон" in post_text or "не анон" in post_text: - final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}' + final_text = f"{safe_post_text}\n\nАвтор поста: {author_info}" elif "анон" in post_text: - final_text = f'{safe_post_text}\n\nПост опубликован анонимно' + final_text = f"{safe_post_text}\n\nПост опубликован анонимно" else: - final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}' - + final_text = f"{safe_post_text}\n\nАвтор поста: {author_info}" + # Добавляем блок со скорами если есть - if deepseek_score is not None or rag_score is not None or rag_score_pos_only is not None: + if ( + deepseek_score is not None + or rag_score is not None + or rag_score_pos_only is not None + ): scores_lines = ["\n📊 Уверенность в одобрении:"] if deepseek_score is not None: scores_lines.append(f"DeepSeek: {deepseek_score:.2f}") @@ -186,14 +217,16 @@ def get_text_message( if rag_score_pos_only is not None: scores_lines.append(f"RAG pos only: {rag_score_pos_only:.2f}") final_text += "\n" + "\n".join(scores_lines) - + return final_text + @track_time("download_file", "helper_func") @track_errors("helper_func", "download_file") @track_file_operations("unknown") -async def download_file(message: types.Message, file_id: str, content_type: str = None, - s3_storage = None) -> Optional[str]: +async def download_file( + message: types.Message, file_id: str, content_type: str = None, s3_storage=None +) -> Optional[str]: """ Скачивает файл по file_id из Telegram и сохраняет в S3 или на локальный диск. @@ -207,53 +240,63 @@ async def download_file(message: types.Message, file_id: str, content_type: str S3 ключ (если s3_storage указан) или локальный путь к файлу, иначе None """ start_time = time.time() - + try: # Валидация параметров if not file_id or not message or not message.bot: - logger.error("download_file: Неверные параметры - file_id, message или bot отсутствуют") + logger.error( + "download_file: Неверные параметры - file_id, message или bot отсутствуют" + ) return None - + # Получаем информацию о файле file = await message.bot.get_file(file_id) if not file or not file.file_path: - logger.error(f"download_file: Не удалось получить информацию о файле {file_id}") + logger.error( + f"download_file: Не удалось получить информацию о файле {file_id}" + ) return None - + # Определяем расширение original_filename = os.path.basename(file.file_path) - file_extension = os.path.splitext(original_filename)[1] or '.bin' - + file_extension = os.path.splitext(original_filename)[1] or ".bin" + if s3_storage: # Сохраняем в S3 # Скачиваем во временный файл temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) temp_path = temp_file.name temp_file.close() - + try: # Скачиваем из Telegram - await message.bot.download_file(file_path=file.file_path, destination=temp_path) - + await message.bot.download_file( + file_path=file.file_path, destination=temp_path + ) + # Генерируем S3 ключ s3_key = s3_storage.generate_s3_key(content_type, file_id) - + # Загружаем в S3 success = await s3_storage.upload_file(temp_path, s3_key) - + # Удаляем временный файл try: os.remove(temp_path) except: pass - + if success: - file_size = file.file_size if hasattr(file, 'file_size') else 0 + file_size = file.file_size if hasattr(file, "file_size") else 0 download_time = time.time() - start_time - logger.info(f"download_file: Файл загружен в S3 - {s3_key}, размер: {file_size} байт, время: {download_time:.2f}с") + logger.info( + f"download_file: Файл загружен в S3 - {s3_key}, размер: {file_size} байт, время: {download_time:.2f}с" + ) return s3_key else: - logger.error(f"download_file: Не удалось загрузить файл в S3: {s3_key}") + logger.error( + f"download_file: Не удалось загрузить файл в S3: {s3_key}" + ) return None except Exception as e: # Удаляем временный файл при ошибке @@ -262,57 +305,68 @@ async def download_file(message: types.Message, file_id: str, content_type: str except: pass download_time = time.time() - start_time - logger.error(f"download_file: Ошибка загрузки файла в S3 {file_id}: {e}, время: {download_time:.2f}с") + logger.error( + f"download_file: Ошибка загрузки файла в S3 {file_id}: {e}, время: {download_time:.2f}с" + ) return None else: # Старая логика - сохраняем на локальный диск # Определяем папку по типу контента type_folders = { - 'photo': 'photos', - 'video': 'videos', - 'audio': 'music', - 'voice': 'voice', - 'video_note': 'video_notes' + "photo": "photos", + "video": "videos", + "audio": "music", + "voice": "voice", + "video_note": "video_notes", } - - folder = type_folders.get(content_type, 'other') + + folder = type_folders.get(content_type, "other") base_path = "files" full_folder_path = os.path.join(base_path, folder) - + # Создаем необходимые папки os.makedirs(base_path, exist_ok=True) os.makedirs(full_folder_path, exist_ok=True) - - logger.debug(f"download_file: Начинаю скачивание файла {file_id} типа {content_type} в папку {folder}") - + + logger.debug( + f"download_file: Начинаю скачивание файла {file_id} типа {content_type} в папку {folder}" + ) + # Генерируем уникальное имя файла safe_filename = f"{file_id}{file_extension}" file_path = os.path.join(full_folder_path, safe_filename) - + # Скачиваем файл - await message.bot.download_file(file_path=file.file_path, destination=file_path) - + await message.bot.download_file( + file_path=file.file_path, destination=file_path + ) + # Проверяем, что файл действительно скачался if not os.path.exists(file_path): logger.error(f"download_file: Файл не был скачан - {file_path}") return None - + file_size = os.path.getsize(file_path) download_time = time.time() - start_time - - logger.info(f"download_file: Файл успешно скачан - {file_path}, размер: {file_size} байт, время: {download_time:.2f}с") - + + logger.info( + f"download_file: Файл успешно скачан - {file_path}, размер: {file_size} байт, время: {download_time:.2f}с" + ) + return file_path - + except Exception as e: download_time = time.time() - start_time - logger.error(f"download_file: Ошибка скачивания файла {file_id}: {e}, время: {download_time:.2f}с") + logger.error( + f"download_file: Ошибка скачивания файла {file_id}: {e}, время: {download_time:.2f}с" + ) return None + @track_time("prepare_media_group_from_middlewares", "helper_func") @track_errors("helper_func", "prepare_media_group_from_middlewares") @track_media_processing("media_group") -async def prepare_media_group_from_middlewares(album, post_caption: str = ''): +async def prepare_media_group_from_middlewares(album, post_caption: str = ""): """ Создает MediaGroup согласно best practices aiogram 3.x. @@ -325,7 +379,7 @@ async def prepare_media_group_from_middlewares(album, post_caption: str = ''): """ # Экранируем post_caption для безопасного использования в HTML safe_post_caption = html.escape(str(post_caption)) if post_caption else "" - + media_group = [] for i, message in enumerate(album): @@ -333,28 +387,36 @@ async def prepare_media_group_from_middlewares(album, post_caption: str = ''): file_id = message.photo[-1].file_id # Для фото используем InputMediaPhoto if i == 0: # Первое фото получает подпись - media_group.append(InputMediaPhoto(media=file_id, caption=safe_post_caption)) + media_group.append( + InputMediaPhoto(media=file_id, caption=safe_post_caption) + ) else: media_group.append(InputMediaPhoto(media=file_id)) elif message.video: file_id = message.video.file_id # Для видео используем InputMediaVideo if i == 0: # Первое видео получает подпись - media_group.append(InputMediaVideo(media=file_id, caption=safe_post_caption)) + media_group.append( + InputMediaVideo(media=file_id, caption=safe_post_caption) + ) else: media_group.append(InputMediaVideo(media=file_id)) elif message.audio: file_id = message.audio.file_id # Для аудио используем InputMediaAudio if i == 0: # Первое аудио получает подпись - media_group.append(InputMediaAudio(media=file_id, caption=safe_post_caption)) + media_group.append( + InputMediaAudio(media=file_id, caption=safe_post_caption) + ) else: media_group.append(InputMediaAudio(media=file_id)) elif message.document: file_id = message.document.file_id # Для документов используем InputMediaDocument (если поддерживается) if i == 0: # Первый документ получает подпись - media_group.append(InputMediaDocument(media=file_id, caption=safe_post_caption)) + media_group.append( + InputMediaDocument(media=file_id, caption=safe_post_caption) + ) else: media_group.append(InputMediaDocument(media=file_id)) else: @@ -363,21 +425,38 @@ async def prepare_media_group_from_middlewares(album, post_caption: str = ''): return media_group -async def _save_media_group_background(sent_message: List[types.Message], bot_db: Any, main_post_id: Optional[int], s3_storage) -> None: + +async def _save_media_group_background( + sent_message: List[types.Message], + bot_db: Any, + main_post_id: Optional[int], + s3_storage, +) -> None: """Сохраняет медиагруппу в фоне, чтобы не блокировать ответ пользователю""" try: - success = await add_in_db_media_mediagroup(sent_message, bot_db, main_post_id, s3_storage) + success = await add_in_db_media_mediagroup( + sent_message, bot_db, main_post_id, s3_storage + ) if not success: - logger.warning(f"_save_media_group_background: Не удалось сохранить медиа для медиагруппы {sent_message[-1].message_id}") + logger.warning( + f"_save_media_group_background: Не удалось сохранить медиа для медиагруппы {sent_message[-1].message_id}" + ) except Exception as e: - logger.error(f"_save_media_group_background: Ошибка при сохранении медиа для медиагруппы {sent_message[-1].message_id}: {e}") + logger.error( + f"_save_media_group_background: Ошибка при сохранении медиа для медиагруппы {sent_message[-1].message_id}: {e}" + ) + @track_time("add_in_db_media_mediagroup", "helper_func") @track_errors("helper_func", "add_in_db_media_mediagroup") @track_media_processing("media_group") @db_query_time("add_in_db_media_mediagroup", "posts", "insert") -async def add_in_db_media_mediagroup(sent_message: List[types.Message], bot_db: Any, - main_post_id: Optional[int] = None, s3_storage = None) -> bool: +async def add_in_db_media_mediagroup( + sent_message: List[types.Message], + bot_db: Any, + main_post_id: Optional[int] = None, + s3_storage=None, +) -> bool: """ Добавляет контент медиа-группы в базу данных @@ -390,100 +469,130 @@ async def add_in_db_media_mediagroup(sent_message: List[types.Message], bot_db: bool: True если весь контент успешно добавлен, False в случае ошибки """ start_time = time.time() - + try: # Валидация параметров if not sent_message or not bot_db or not isinstance(sent_message, list): - logger.error("add_in_db_media_mediagroup: Неверные параметры - sent_message, bot_db или sent_message не является списком") + logger.error( + "add_in_db_media_mediagroup: Неверные параметры - sent_message, bot_db или sent_message не является списком" + ) return False - + if len(sent_message) == 0: logger.warning("add_in_db_media_mediagroup: Пустая медиагруппа") return False - + post_id = main_post_id or sent_message[-1].message_id - + processed_count = 0 failed_count = 0 - + for i, message in enumerate(sent_message): try: content_type = None file_id = None - + if message.photo: - content_type = 'photo' + content_type = "photo" file_id = message.photo[-1].file_id elif message.video: - content_type = 'video' + content_type = "video" file_id = message.video.file_id elif message.audio: - content_type = 'audio' + content_type = "audio" file_id = message.audio.file_id elif message.voice: - content_type = 'voice' + content_type = "voice" file_id = message.voice.file_id elif message.video_note: - content_type = 'video_note' + content_type = "video_note" file_id = message.video_note.file_id else: - logger.warning(f"add_in_db_media_mediagroup: Неподдерживаемый тип контента в сообщении {i+1}/{len(sent_message)}") + logger.warning( + f"add_in_db_media_mediagroup: Неподдерживаемый тип контента в сообщении {i+1}/{len(sent_message)}" + ) failed_count += 1 continue - + if not file_id: - logger.error(f"add_in_db_media_mediagroup: file_id отсутствует в сообщении {i+1}/{len(sent_message)}") + logger.error( + f"add_in_db_media_mediagroup: file_id отсутствует в сообщении {i+1}/{len(sent_message)}" + ) failed_count += 1 continue - + if s3_storage is None: bdf = get_global_instance() s3_storage = bdf.get_s3_storage() - - file_path = await download_file(message, file_id=file_id, content_type=content_type, s3_storage=s3_storage) + + file_path = await download_file( + message, + file_id=file_id, + content_type=content_type, + s3_storage=s3_storage, + ) if not file_path: - logger.error(f"add_in_db_media_mediagroup: Не удалось скачать файл {file_id} в сообщении {i+1}/{len(sent_message)}") + logger.error( + f"add_in_db_media_mediagroup: Не удалось скачать файл {file_id} в сообщении {i+1}/{len(sent_message)}" + ) failed_count += 1 continue - - success = await bot_db.add_post_content(post_id, post_id, file_path, content_type) + + success = await bot_db.add_post_content( + post_id, post_id, file_path, content_type + ) if not success: - logger.error(f"add_in_db_media_mediagroup: Не удалось добавить контент в БД для сообщения {i+1}/{len(sent_message)}") - if file_path.startswith('files/'): + logger.error( + f"add_in_db_media_mediagroup: Не удалось добавить контент в БД для сообщения {i+1}/{len(sent_message)}" + ) + if file_path.startswith("files/"): try: os.remove(file_path) except Exception as e: - logger.warning(f"add_in_db_media_mediagroup: Не удалось удалить файл {file_path}: {e}") + logger.warning( + f"add_in_db_media_mediagroup: Не удалось удалить файл {file_path}: {e}" + ) failed_count += 1 continue - + processed_count += 1 - + except Exception as e: - logger.error(f"add_in_db_media_mediagroup: Ошибка обработки сообщения {i+1}/{len(sent_message)}: {e}") + logger.error( + f"add_in_db_media_mediagroup: Ошибка обработки сообщения {i+1}/{len(sent_message)}: {e}" + ) failed_count += 1 continue - + if processed_count == 0: - logger.error(f"add_in_db_media_mediagroup: Не удалось обработать ни одного сообщения из медиагруппы {post_id}") + logger.error( + f"add_in_db_media_mediagroup: Не удалось обработать ни одного сообщения из медиагруппы {post_id}" + ) return False - + if failed_count > 0: - logger.warning(f"add_in_db_media_mediagroup: Обработано {processed_count}/{len(sent_message)} сообщений медиагруппы {post_id}, ошибок: {failed_count}") - + logger.warning( + f"add_in_db_media_mediagroup: Обработано {processed_count}/{len(sent_message)} сообщений медиагруппы {post_id}, ошибок: {failed_count}" + ) + return failed_count == 0 - + except Exception as e: processing_time = time.time() - start_time - logger.error(f"add_in_db_media_mediagroup: Критическая ошибка обработки медиагруппы: {e}, время: {processing_time:.2f}с") + logger.error( + f"add_in_db_media_mediagroup: Критическая ошибка обработки медиагруппы: {e}, время: {processing_time:.2f}с" + ) return False + @track_time("add_in_db_media", "helper_func") @track_errors("helper_func", "add_in_db_media") @track_media_processing("media_group") @db_query_time("add_in_db_media", "posts", "insert") @track_file_operations("media") -async def add_in_db_media(sent_message: types.Message, bot_db: Any, s3_storage = None) -> bool: +async def add_in_db_media( + sent_message: types.Message, bot_db: Any, s3_storage=None +) -> bool: """ Добавляет контент одиночного сообщения в базу данных @@ -495,86 +604,120 @@ async def add_in_db_media(sent_message: types.Message, bot_db: Any, s3_storage = bool: True если контент успешно добавлен, False в случае ошибки """ start_time = time.time() - + try: # Валидация параметров if not sent_message or not bot_db: - logger.error("add_in_db_media: Неверные параметры - sent_message или bot_db отсутствуют") + logger.error( + "add_in_db_media: Неверные параметры - sent_message или bot_db отсутствуют" + ) return False - + post_id = sent_message.message_id # ID поста (это же сообщение) content_type = None file_id = None - + # Определяем тип контента и file_id if sent_message.photo: - content_type = 'photo' + content_type = "photo" file_id = sent_message.photo[-1].file_id elif sent_message.video: - content_type = 'video' + content_type = "video" file_id = sent_message.video.file_id elif sent_message.voice: - content_type = 'voice' + content_type = "voice" file_id = sent_message.voice.file_id elif sent_message.audio: - content_type = 'audio' + content_type = "audio" file_id = sent_message.audio.file_id elif sent_message.video_note: - content_type = 'video_note' + content_type = "video_note" file_id = sent_message.video_note.file_id else: - logger.warning(f"add_in_db_media: Неподдерживаемый тип контента для сообщения {post_id}") + logger.warning( + f"add_in_db_media: Неподдерживаемый тип контента для сообщения {post_id}" + ) return False - + if not file_id: - logger.error(f"add_in_db_media: file_id отсутствует для сообщения {post_id}") + logger.error( + f"add_in_db_media: file_id отсутствует для сообщения {post_id}" + ) return False - - logger.debug(f"add_in_db_media: Обрабатываю {content_type} для сообщения {post_id}") - + + logger.debug( + f"add_in_db_media: Обрабатываю {content_type} для сообщения {post_id}" + ) + # Получаем s3_storage если не передан if s3_storage is None: bdf = get_global_instance() s3_storage = bdf.get_s3_storage() - + # Скачиваем файл (в S3 или на локальный диск) - file_path = await download_file(sent_message, file_id=file_id, content_type=content_type, s3_storage=s3_storage) + file_path = await download_file( + sent_message, + file_id=file_id, + content_type=content_type, + s3_storage=s3_storage, + ) if not file_path: - logger.error(f"add_in_db_media: Не удалось скачать файл {file_id} для сообщения {post_id}") + logger.error( + f"add_in_db_media: Не удалось скачать файл {file_id} для сообщения {post_id}" + ) return False - + # Добавляем в базу данных - success = await bot_db.add_post_content(post_id, sent_message.message_id, file_path, content_type) + success = await bot_db.add_post_content( + post_id, sent_message.message_id, file_path, content_type + ) if not success: - logger.error(f"add_in_db_media: Не удалось добавить контент в БД для сообщения {post_id}") + logger.error( + f"add_in_db_media: Не удалось добавить контент в БД для сообщения {post_id}" + ) # Удаляем скачанный файл при ошибке БД (только если это локальный файл, не S3) - if file_path.startswith('files/'): + if file_path.startswith("files/"): try: os.remove(file_path) - logger.debug(f"add_in_db_media: Удален файл {file_path} после ошибки БД") + logger.debug( + f"add_in_db_media: Удален файл {file_path} после ошибки БД" + ) except Exception as e: - logger.warning(f"add_in_db_media: Не удалось удалить файл {file_path}: {e}") + logger.warning( + f"add_in_db_media: Не удалось удалить файл {file_path}: {e}" + ) return False - + processing_time = time.time() - start_time - logger.info(f"add_in_db_media: Контент успешно добавлен для сообщения {post_id}, тип: {content_type}, время: {processing_time:.2f}с") - + logger.info( + f"add_in_db_media: Контент успешно добавлен для сообщения {post_id}, тип: {content_type}, время: {processing_time:.2f}с" + ) + return True - + except Exception as e: processing_time = time.time() - start_time - logger.error(f"add_in_db_media: Ошибка обработки медиа для сообщения {post_id}: {e}, время: {processing_time:.2f}с") + logger.error( + f"add_in_db_media: Ошибка обработки медиа для сообщения {post_id}: {e}, время: {processing_time:.2f}с" + ) return False + @track_time("send_media_group_message_to_private_chat", "helper_func") @track_errors("helper_func", "send_media_group_message_to_private_chat") @track_media_processing("media_group") @db_query_time("send_media_group_message_to_private_chat", "posts", "insert") -async def send_media_group_message_to_private_chat(chat_id: int, message: types.Message, - media_group: List, bot_db: Any, main_post_id: Optional[int] = None, s3_storage=None) -> List[int]: +async def send_media_group_message_to_private_chat( + chat_id: int, + message: types.Message, + media_group: List, + bot_db: Any, + main_post_id: Optional[int] = None, + s3_storage=None, +) -> List[int]: """ Отправляет медиагруппу в чат и возвращает все message_id отправленных сообщений. - + Args: chat_id: ID чата для отправки message: Оригинальное сообщение от пользователя @@ -582,7 +725,7 @@ async def send_media_group_message_to_private_chat(chat_id: int, message: types. bot_db: Экземпляр базы данных main_post_id: ID основного поста в БД (опционально) s3_storage: S3StorageService для сохранения медиа - + Returns: List[int]: Список всех message_id отправленных сообщений медиагруппы """ @@ -590,18 +733,23 @@ async def send_media_group_message_to_private_chat(chat_id: int, message: types. chat_id=chat_id, media=media_group, ) - + sent_message_ids = [msg.message_id for msg in sent_messages] main_message_id = sent_message_ids[-1] - - asyncio.create_task(_save_media_group_background(sent_messages, bot_db, main_message_id, s3_storage)) - + + asyncio.create_task( + _save_media_group_background(sent_messages, bot_db, main_message_id, s3_storage) + ) + return sent_message_ids + @track_time("send_media_group_to_channel", "helper_func") @track_errors("helper_func", "send_media_group_to_channel") @track_media_processing("media_group") -async def send_media_group_to_channel(bot, chat_id: int, post_content: List, post_text: str, s3_storage = None): +async def send_media_group_to_channel( + bot, chat_id: int, post_content: List, post_text: str, s3_storage=None +): """ Отправляет медиа-группу с подписью к последнему файлу. @@ -612,25 +760,33 @@ async def send_media_group_to_channel(bot, chat_id: int, post_content: List, pos post_text: Текст подписи. s3_storage: опциональный S3StorageService для работы с S3. """ - logger.info(f"Начинаю отправку медиа-группы в чат {chat_id}, количество файлов: {len(post_content)}") - + logger.info( + f"Начинаю отправку медиа-группы в чат {chat_id}, количество файлов: {len(post_content)}" + ) + # Получаем s3_storage если не передан if s3_storage is None: bdf = get_global_instance() s3_storage = bdf.get_s3_storage() - + media = [] temp_files = [] # Для хранения путей к временным файлам - + try: for i, file_path_tuple in enumerate(post_content): try: file_path, content_type = file_path_tuple - logger.debug(f"Обрабатываю файл {i+1}/{len(post_content)}: {file_path} (тип: {content_type})") - + logger.debug( + f"Обрабатываю файл {i+1}/{len(post_content)}: {file_path} (тип: {content_type})" + ) + # Проверяем, это S3 ключ или локальный путь actual_path = file_path - if s3_storage and not file_path.startswith('files/') and not os.path.exists(file_path): + if ( + s3_storage + and not file_path.startswith("files/") + and not os.path.exists(file_path) + ): # Это S3 ключ, скачиваем во временный файл temp_path = await s3_storage.download_to_temp(file_path) if not temp_path: @@ -641,15 +797,17 @@ async def send_media_group_to_channel(bot, chat_id: int, post_content: List, pos elif not os.path.exists(file_path): logger.error(f"Файл не найден: {file_path}") continue - + file = FSInputFile(path=actual_path) - - if content_type == 'video': + + if content_type == "video": media.append(types.InputMediaVideo(media=file)) - elif content_type == 'photo': + elif content_type == "photo": media.append(types.InputMediaPhoto(media=file)) else: - logger.warning(f"Неизвестный тип файла: {content_type} для {file_path}") + logger.warning( + f"Неизвестный тип файла: {content_type} для {file_path}" + ) except FileNotFoundError: logger.error(f"Файл не найден: {file_path_tuple[0]}") continue @@ -664,11 +822,15 @@ async def send_media_group_to_channel(bot, chat_id: int, post_content: List, pos # Экранируем post_text для безопасного использования в HTML safe_post_text = html.escape(str(post_text)) if post_text else "" media[-1].caption = safe_post_text - logger.debug(f"Добавлена подпись к последнему файлу: {safe_post_text[:50]}{'...' if len(safe_post_text) > 50 else ''}") + logger.debug( + f"Добавлена подпись к последнему файлу: {safe_post_text[:50]}{'...' if len(safe_post_text) > 50 else ''}" + ) try: sent_messages = await bot.send_media_group(chat_id=chat_id, media=media) - logger.info(f"Медиа-группа успешно отправлена в чат {chat_id}, количество сообщений: {len(sent_messages)}") + logger.info( + f"Медиа-группа успешно отправлена в чат {chat_id}, количество сообщений: {len(sent_messages)}" + ) return sent_messages except Exception as e: logger.error(f"Ошибка при отправке медиа-группы в чат {chat_id}: {e}") @@ -681,141 +843,148 @@ async def send_media_group_to_channel(bot, chat_id: int, post_content: List, pos except: pass + @track_time("send_text_message", "helper_func") @track_errors("helper_func", "send_text_message") -async def send_text_message(chat_id, message: types.Message, post_text: str, markup: types.ReplyKeyboardMarkup = None): +async def send_text_message( + chat_id, + message: types.Message, + post_text: str, + markup: types.ReplyKeyboardMarkup = None, +): from .rate_limiter import send_with_rate_limit # Экранируем post_text для безопасного использования в HTML safe_post_text = html.escape(str(post_text)) if post_text else "" - + async def _send_message(): if markup is None: - return await message.bot.send_message( - chat_id=chat_id, - text=safe_post_text - ) + return await message.bot.send_message(chat_id=chat_id, text=safe_post_text) else: return await message.bot.send_message( - chat_id=chat_id, - text=safe_post_text, - reply_markup=markup + chat_id=chat_id, text=safe_post_text, reply_markup=markup ) - + sent_message = await send_with_rate_limit(_send_message, chat_id) return sent_message + @track_time("send_photo_message", "helper_func") @track_errors("helper_func", "send_photo_message") -async def send_photo_message(chat_id, message: types.Message, photo: str, post_text: str, - markup: types.ReplyKeyboardMarkup = None): +async def send_photo_message( + chat_id, + message: types.Message, + photo: str, + post_text: str, + markup: types.ReplyKeyboardMarkup = None, +): # Экранируем post_text для безопасного использования в HTML safe_post_text = html.escape(str(post_text)) if post_text else "" - + if markup is None: sent_message = await message.bot.send_photo( - chat_id=chat_id, - caption=safe_post_text, - photo=photo + chat_id=chat_id, caption=safe_post_text, photo=photo ) else: sent_message = await message.bot.send_photo( - chat_id=chat_id, - caption=safe_post_text, - photo=photo, - reply_markup=markup + chat_id=chat_id, caption=safe_post_text, photo=photo, reply_markup=markup ) return sent_message + @track_time("send_video_message", "helper_func") @track_errors("helper_func", "send_video_message") -async def send_video_message(chat_id, message: types.Message, video: str, post_text: str = "", - markup: types.ReplyKeyboardMarkup = None): +async def send_video_message( + chat_id, + message: types.Message, + video: str, + post_text: str = "", + markup: types.ReplyKeyboardMarkup = None, +): # Экранируем post_text для безопасного использования в HTML safe_post_text = html.escape(str(post_text)) if post_text else "" - + if markup is None: sent_message = await message.bot.send_video( - chat_id=chat_id, - caption=safe_post_text, - video=video + chat_id=chat_id, caption=safe_post_text, video=video ) else: sent_message = await message.bot.send_video( - chat_id=chat_id, - caption=safe_post_text, - video=video, - reply_markup=markup + chat_id=chat_id, caption=safe_post_text, video=video, reply_markup=markup ) return sent_message + @track_time("send_video_note_message", "helper_func") @track_errors("helper_func", "send_video_note_message") -async def send_video_note_message(chat_id, message: types.Message, video_note: str, - markup: types.ReplyKeyboardMarkup = None): +async def send_video_note_message( + chat_id, + message: types.Message, + video_note: str, + markup: types.ReplyKeyboardMarkup = None, +): if markup is None: sent_message = await message.bot.send_video_note( - chat_id=chat_id, - video_note=video_note + chat_id=chat_id, video_note=video_note ) else: sent_message = await message.bot.send_video_note( - chat_id=chat_id, - video_note=video_note, - reply_markup=markup + chat_id=chat_id, video_note=video_note, reply_markup=markup ) return sent_message + @track_time("send_audio_message", "helper_func") @track_errors("helper_func", "send_audio_message") -async def send_audio_message(chat_id, message: types.Message, audio: str, post_text: str, - markup: types.ReplyKeyboardMarkup = None): +async def send_audio_message( + chat_id, + message: types.Message, + audio: str, + post_text: str, + markup: types.ReplyKeyboardMarkup = None, +): # Экранируем post_text для безопасного использования в HTML safe_post_text = html.escape(str(post_text)) if post_text else "" - + if markup is None: sent_message = await message.bot.send_audio( - chat_id=chat_id, - caption=safe_post_text, - audio=audio + chat_id=chat_id, caption=safe_post_text, audio=audio ) else: sent_message = await message.bot.send_audio( - chat_id=chat_id, - caption=safe_post_text, - audio=audio, - reply_markup=markup + chat_id=chat_id, caption=safe_post_text, audio=audio, reply_markup=markup ) return sent_message @track_time("send_voice_message", "helper_func") @track_errors("helper_func", "send_voice_message") -async def send_voice_message(chat_id, message: types.Message, voice: str, - markup: types.ReplyKeyboardMarkup = None): +async def send_voice_message( + chat_id, + message: types.Message, + voice: str, + markup: types.ReplyKeyboardMarkup = None, +): from .rate_limiter import send_with_rate_limit - + async def _send_voice(): if markup is None: - return await message.bot.send_voice( - chat_id=chat_id, - voice=voice - ) + return await message.bot.send_voice(chat_id=chat_id, voice=voice) else: return await message.bot.send_voice( - chat_id=chat_id, - voice=voice, - reply_markup=markup + chat_id=chat_id, voice=voice, reply_markup=markup ) - + return await send_with_rate_limit(_send_voice, chat_id) + @track_time("check_access", "helper_func") @track_errors("helper_func", "check_access") @db_query_time("check_access", "users", "select") async def check_access(user_id: int, bot_db): """Проверка прав на совершение действий""" from logs.custom_logger import logger + result = await bot_db.is_admin(user_id) logger.info(f"check_access: пользователь {user_id} - результат: {result}") return result @@ -827,6 +996,7 @@ def add_days_to_date(days: int): future_date = current_date + timedelta(days=days) return int(future_date.timestamp()) + @track_time("get_banned_users_list", "helper_func") @track_errors("helper_func", "get_banned_users_list") @db_query_time("get_banned_users_list", "users", "select") @@ -851,11 +1021,13 @@ async def get_banned_users_list(offset: int, bot_db): username = await bot_db.get_username(user_id) full_name = await bot_db.get_full_name_by_id(user_id) safe_user_name = username or full_name or f"User_{user_id}" - + # Экранируем пользовательские данные для безопасного использования safe_user_name = html.escape(str(safe_user_name)) - safe_ban_reason = html.escape(str(ban_reason)) if ban_reason else "Причина не указана" - + safe_ban_reason = ( + html.escape(str(ban_reason)) if ban_reason else "Причина не указана" + ) + # Форматируем дату разбана в человекочитаемый формат if unban_date: try: @@ -873,7 +1045,7 @@ async def get_banned_users_list(offset: int, bot_db): except (ValueError, TypeError): # Если не удалось, показываем как есть safe_unban_date = html.escape(str(unban_date)) - elif hasattr(unban_date, 'strftime'): + elif hasattr(unban_date, "strftime"): # Если это datetime объект safe_unban_date = unban_date.strftime("%d-%m-%Y %H:%M") else: @@ -884,12 +1056,13 @@ async def get_banned_users_list(offset: int, bot_db): safe_unban_date = html.escape(str(unban_date)) else: safe_unban_date = "Дата не указана" - + message += f"**Пользователь:** {safe_user_name}\n" message += f"**Причина бана:** {safe_ban_reason}\n" message += f"**Дата разбана:** {safe_unban_date}\n\n" return message + @track_time("get_banned_users_buttons", "helper_func") @track_errors("helper_func", "get_banned_users_buttons") @db_query_time("get_banned_users_buttons", "users", "select") @@ -913,12 +1086,13 @@ async def get_banned_users_buttons(bot_db): username = await bot_db.get_username(user_id) full_name = await bot_db.get_full_name_by_id(user_id) safe_user_name = username or full_name or f"User_{user_id}" - + # Экранируем user_name для безопасного использования safe_user_name = html.escape(str(safe_user_name)) user_ids.append((safe_user_name, user_id)) return user_ids + @track_time("delete_user_blacklist", "helper_func") @track_errors("helper_func", "delete_user_blacklist") @db_query_time("delete_user_blacklist", "users", "delete") @@ -929,7 +1103,9 @@ async def delete_user_blacklist(user_id: int, bot_db): @track_time("check_username_and_full_name", "helper_func") @track_errors("helper_func", "check_username_and_full_name") @db_query_time("check_username_and_full_name", "users", "select") -async def check_username_and_full_name(user_id: int, username: str, full_name: str, bot_db): +async def check_username_and_full_name( + user_id: int, username: str, full_name: str, bot_db +): """Проверяет, изменились ли username или full_name пользователя""" try: username_db = await bot_db.get_username(user_id) @@ -939,6 +1115,7 @@ async def check_username_and_full_name(user_id: int, username: str, full_name: s logger.error(f"Ошибка при проверке username и full_name: {e}") return False + @track_time("unban_notifier", "helper_func") @track_errors("helper_func", "unban_notifier") @db_query_time("unban_notifier", "users", "select") @@ -973,13 +1150,14 @@ async def update_user_info(source: str, message: types.Message): is_bot = message.from_user.is_bot language_code = message.from_user.language_code user_id = message.from_user.id - + # Выбираем эмодзю, пробегаемся циклом и смотрим что в базе такого еще не было user_emoji = await get_random_emoji() - + if not await BotDB.user_exists(user_id): # Create User object with current timestamp from database.models import User + current_timestamp = int(datetime.now().timestamp()) user = User( user_id=user_id, @@ -992,18 +1170,23 @@ async def update_user_info(source: str, message: types.Message): has_stickers=False, date_added=current_timestamp, date_changed=current_timestamp, - voice_bot_welcome_received=False + voice_bot_welcome_received=False, ) await BotDB.add_user(user) else: - is_need_update = await check_username_and_full_name(user_id, username, full_name, BotDB) + is_need_update = await check_username_and_full_name( + user_id, username, full_name, BotDB + ) if is_need_update: await BotDB.update_user_info(user_id, username, full_name) - if source != 'voice': + if source != "voice": await message.answer( - f"Давно не виделись! Вижу что ты изменился;) Теперь буду звать тебя: {full_name}") - await message.bot.send_message(chat_id=GROUP_FOR_LOGS, - text=f'Для пользователя: {user_id} обновлены данные в БД.\nНовое имя: {full_name}\nНовый ник:{username}. Новый эмодзи:{user_emoji}') + f"Давно не виделись! Вижу что ты изменился;) Теперь буду звать тебя: {full_name}" + ) + await message.bot.send_message( + chat_id=GROUP_FOR_LOGS, + text=f"Для пользователя: {user_id} обновлены данные в БД.\nНовое имя: {full_name}\nНовый ник:{username}. Новый эмодзи:{user_emoji}", + ) sleep(1) await BotDB.update_user_date(user_id) @@ -1014,7 +1197,11 @@ async def update_user_info(source: str, message: types.Message): async def check_user_emoji(message: types.Message): user_id = message.from_user.id user_emoji = await BotDB.get_user_emoji(user_id=user_id) - if user_emoji is None or user_emoji in ("Смайл еще не определен", "Эмоджи не определен", ""): + if user_emoji is None or user_emoji in ( + "Смайл еще не определен", + "Эмоджи не определен", + "", + ): user_emoji = await get_random_emoji() await BotDB.update_user_emoji(user_id=user_id, emoji=user_emoji) return user_emoji diff --git a/helper_bot/utils/messages.py b/helper_bot/utils/messages.py index 462b570..94855b4 100644 --- a/helper_bot/utils/messages.py +++ b/helper_bot/utils/messages.py @@ -4,55 +4,55 @@ import html from .metrics import metrics, track_errors, track_time constants = { - 'HELLO_MESSAGE': "Привет, username!👋🏼&Меня зовут Виби, я бот канала 'Влюбленный Бийск'❤🤖" - "&Я был создан для того, чтобы помочь тебе выложить пост в наш канал и если это необходимо, связаться с админами ✍✉" - "&Так же я могу выдать тебе набор стикеров, где я буду главным героем🦸‍♂" - "&Наш бот голосового общения переехал ко мне! Доступен по кнопке 🎤Голосовой бот &Там можно послушать о чем говорит наш город🎧" - "&Предлагай свой пост мне и я обязательно его опубликую😉" - "&Для продолжения взаимодействия воспользуйся меню внизу твоего дисплея⬇" - "&&Если что-то пошло не так: введи в чат команду /start или /restart, это перезапустит сценарий сначала." - "Почитать инструкцию к боту можно по команде /help. Если есть вопросы, то пиши в личку: @Kerrad1" - "&Не жми кнопку несколько раз если я не ответил с первого раза. Возможно ведутся тех.работы и я отвечу позже" - "&&Группа в ВК: https://vk.com/love_bsk" - "&Канал в ТГ: https://t.me/love_bsk", - 'SUGGEST_NEWS': "username, окей, жду от тебя текст поста🙌🏼" - "&Обрати внимание, что я умный и смогу из твоего текста понять команды указанные ниже😉" - "&Если хочешь чтобы пост был опубликован анонимно, напиши в любом месте своего поста слово 'анон'." - "&Если хочешь опубликовать пост не анонимно, то напиши 'не анон', 'неанон' или не пиши ничего." - "&&❗️❗️Я обучен только на команды, указанные мной выше👆" - "&❗️❗️Проверь, чтобы указание авторства было выполнено так как я попросил, иначе пост будет выложен не корректно" - "&Пост будет опубликован только в группе ТГ📩", + "HELLO_MESSAGE": "Привет, username!👋🏼&Меня зовут Виби, я бот канала 'Влюбленный Бийск'❤🤖" + "&Я был создан для того, чтобы помочь тебе выложить пост в наш канал и если это необходимо, связаться с админами ✍✉" + "&Так же я могу выдать тебе набор стикеров, где я буду главным героем🦸‍♂" + "&Наш бот голосового общения переехал ко мне! Доступен по кнопке 🎤Голосовой бот &Там можно послушать о чем говорит наш город🎧" + "&Предлагай свой пост мне и я обязательно его опубликую😉" + "&Для продолжения взаимодействия воспользуйся меню внизу твоего дисплея⬇" + "&&Если что-то пошло не так: введи в чат команду /start или /restart, это перезапустит сценарий сначала." + "Почитать инструкцию к боту можно по команде /help. Если есть вопросы, то пиши в личку: @Kerrad1" + "&Не жми кнопку несколько раз если я не ответил с первого раза. Возможно ведутся тех.работы и я отвечу позже" + "&&Группа в ВК: https://vk.com/love_bsk" + "&Канал в ТГ: https://t.me/love_bsk", + "SUGGEST_NEWS": "username, окей, жду от тебя текст поста🙌🏼" + "&Обрати внимание, что я умный и смогу из твоего текста понять команды указанные ниже😉" + "&Если хочешь чтобы пост был опубликован анонимно, напиши в любом месте своего поста слово 'анон'." + "&Если хочешь опубликовать пост не анонимно, то напиши 'не анон', 'неанон' или не пиши ничего." + "&&❗️❗️Я обучен только на команды, указанные мной выше👆" + "&❗️❗️Проверь, чтобы указание авторства было выполнено так как я попросил, иначе пост будет выложен не корректно" + "&Пост будет опубликован только в группе ТГ📩", "CONNECT_WITH_ADMIN": "username, напиши свое обращение или предложение✍️" - "&Мы рассмотрим и ответим тебе в ближайшее время☺️❤️", + "&Мы рассмотрим и ответим тебе в ближайшее время☺️❤️", "DEL_MESSAGE": "username, напиши свое обращение или предложение✍" - "&Мы рассмотрим и ответим тебе в ближайшее время☺❤", + "&Мы рассмотрим и ответим тебе в ближайшее время☺❤", "BYE_MESSAGE": "Если позднее захочешь предложить еще один пост или обратиться к админам с вопросом, то просто пришли в чат команду 👉 /restart" - "&&И тебе пока!👋🏼❤️", + "&&И тебе пока!👋🏼❤️", "USER_ERROR": "Увы, я не понимаю тебя😐💔 Выбери один из пунктов в нижнем меню, а затем пиши.", "QUESTION": "Сообщение успешно отправлено❤️ Ответим, как только сможем😉", "SUCCESS_SEND_MESSAGE": "Пост успешно отправлен❤️ Ожидай одобрения😊", # Voice handler messages "MESSAGE_FOR_STANDUP": "Отлично, ты вошел в режим стендапа 📣" - "&Это свободное пространство, в котором может высказаться каждый житель нашего города, и он будет услышан🙌🏼" - "&Для того чтобы высказаться, нажми кнопку: 'Высказаться' и запиши голосовое сообщение, оно выпадет анонимно кому-то другому🗣" - "&Для того чтобы послушать о чем говорит наш город, нажми кнопку: 'Послушать'👂" - "&Ты можешь анонимно пообщаться, поделиться чем-то важным, обратиться напрямую к жителям🤝 Также можешь выступить перед аудиторией (спеть песню, рассказать стихотворение, шутку)🎤" - "&❗️Но пожалуйста не оскорбляй никого, и будь вежлив.", - 'WELCOME_MESSAGE': "Привет.", - 'DESCRIPTION_MESSAGE': "Здесь можно послушать голосовые сообщения от совершенно незнакомых людей из Бийска", - 'ANALOGY_MESSAGE': "Это почти как написать письмо, положить его в бутылку и швырнуть в океан. Никогда не узнаешь, послушал его кто-то или нет и ответить тоже не получится..", - 'RULES_MESSAGE': "Записывать можно всё что угодно — никаких правил нет. Главное — твой голос, хотя бы на 5-10 секунд", - 'ANONYMITY_MESSAGE': "Здесь всё анонимно: тот, кому я отправлю твое сообщение, не узнает ни твое имя, ни твой аккаунт (так что можно не стесняться говорить то, что не стал(а) бы выкладывать в собственные соцсети)", - 'SUGGESTION_MESSAGE': "Если не знаешь, что сказать, можешь просто прочитать любое текстовое сообщение из недавно полученных или отправленных (или спеть, рассказать стихотворенье)", - 'EMOJI_INFO_MESSAGE': "Любые войсы будут помечены эмоджи. Твой эмоджи - {emoji}Таким эмоджи будут помечены твои сообщения для других Но другие люди не узнают кто за каким эмоджи скрывается:)", - 'HELP_INFO_MESSAGE': "Так же можешь ознакомиться с инструкцией к боту по команде /help", - 'FINAL_MESSAGE': "Ну всё, достаточно инструкций. записывайся! Микрофон твой - 🎤", - 'HELP_MESSAGE': "Когда-нибудь здесь будет инструкция к боту. А пока по вопросам пиши в личку: @Kerrad1 или в Связаться с админами", - 'VOICE_SAVED_MESSAGE': "Окей, сохранил!👌", - 'LISTENINGS_CLEARED_MESSAGE': "Прослушивания очищены. Можешь начать слушать заново🤗", - 'NO_AUDIO_MESSAGE': "Прости, ты прослушал все аудио😔. Возвращайся позже, возможно наша база пополнится", - 'UNKNOWN_CONTENT_MESSAGE': "Я тебя не понимаю🤷‍♀️ запиши голосовое", - 'RECORD_VOICE_MESSAGE': "Хорошо, теперь пришли мне свое голосовое сообщение" + "&Это свободное пространство, в котором может высказаться каждый житель нашего города, и он будет услышан🙌🏼" + "&Для того чтобы высказаться, нажми кнопку: 'Высказаться' и запиши голосовое сообщение, оно выпадет анонимно кому-то другому🗣" + "&Для того чтобы послушать о чем говорит наш город, нажми кнопку: 'Послушать'👂" + "&Ты можешь анонимно пообщаться, поделиться чем-то важным, обратиться напрямую к жителям🤝 Также можешь выступить перед аудиторией (спеть песню, рассказать стихотворение, шутку)🎤" + "&❗️Но пожалуйста не оскорбляй никого, и будь вежлив.", + "WELCOME_MESSAGE": "Привет.", + "DESCRIPTION_MESSAGE": "Здесь можно послушать голосовые сообщения от совершенно незнакомых людей из Бийска", + "ANALOGY_MESSAGE": "Это почти как написать письмо, положить его в бутылку и швырнуть в океан. Никогда не узнаешь, послушал его кто-то или нет и ответить тоже не получится..", + "RULES_MESSAGE": "Записывать можно всё что угодно — никаких правил нет. Главное — твой голос, хотя бы на 5-10 секунд", + "ANONYMITY_MESSAGE": "Здесь всё анонимно: тот, кому я отправлю твое сообщение, не узнает ни твое имя, ни твой аккаунт (так что можно не стесняться говорить то, что не стал(а) бы выкладывать в собственные соцсети)", + "SUGGESTION_MESSAGE": "Если не знаешь, что сказать, можешь просто прочитать любое текстовое сообщение из недавно полученных или отправленных (или спеть, рассказать стихотворенье)", + "EMOJI_INFO_MESSAGE": "Любые войсы будут помечены эмоджи. Твой эмоджи - {emoji}Таким эмоджи будут помечены твои сообщения для других Но другие люди не узнают кто за каким эмоджи скрывается:)", + "HELP_INFO_MESSAGE": "Так же можешь ознакомиться с инструкцией к боту по команде /help", + "FINAL_MESSAGE": "Ну всё, достаточно инструкций. записывайся! Микрофон твой - 🎤", + "HELP_MESSAGE": "Когда-нибудь здесь будет инструкция к боту. А пока по вопросам пиши в личку: @Kerrad1 или в Связаться с админами", + "VOICE_SAVED_MESSAGE": "Окей, сохранил!👌", + "LISTENINGS_CLEARED_MESSAGE": "Прослушивания очищены. Можешь начать слушать заново🤗", + "NO_AUDIO_MESSAGE": "Прости, ты прослушал все аудио😔. Возвращайся позже, возможно наша база пополнится", + "UNKNOWN_CONTENT_MESSAGE": "Я тебя не понимаю🤷‍♀️ запиши голосовое", + "RECORD_VOICE_MESSAGE": "Хорошо, теперь пришли мне свое голосовое сообщение", } @@ -64,5 +64,5 @@ def get_message(username: str, type_message: str): raise TypeError("username is None") message = constants[type_message] # Экранируем потенциально проблемные символы для HTML - message = message.replace('username', html.escape(username)).replace('&', '\n') + message = message.replace("username", html.escape(username)).replace("&", "\n") return message diff --git a/helper_bot/utils/metrics.py b/helper_bot/utils/metrics.py index 60a4336..b416b8c 100644 --- a/helper_bot/utils/metrics.py +++ b/helper_bot/utils/metrics.py @@ -10,8 +10,13 @@ from contextlib import asynccontextmanager from functools import wraps from typing import Any, Dict, Optional -from prometheus_client import (CONTENT_TYPE_LATEST, Counter, Gauge, Histogram, - generate_latest) +from prometheus_client import ( + CONTENT_TYPE_LATEST, + Counter, + Gauge, + Histogram, + generate_latest, +) from prometheus_client.core import CollectorRegistry # Метрики rate limiter теперь создаются в основном классе @@ -19,372 +24,399 @@ from prometheus_client.core import CollectorRegistry class BotMetrics: """Central class for managing all bot metrics.""" - + def __init__(self): self.registry = CollectorRegistry() - + # Создаем метрики rate limiter в том же registry self._create_rate_limit_metrics() - + # Bot commands counter self.bot_commands_total = Counter( - 'bot_commands_total', - 'Total number of bot commands processed', - ['command', 'status', 'handler_type', 'user_type'], - registry=self.registry + "bot_commands_total", + "Total number of bot commands processed", + ["command", "status", "handler_type", "user_type"], + registry=self.registry, ) - + # Method execution time histogram self.method_duration_seconds = Histogram( - 'method_duration_seconds', - 'Time spent executing methods', - ['method_name', 'handler_type', 'status'], + "method_duration_seconds", + "Time spent executing methods", + ["method_name", "handler_type", "status"], # Оптимизированные buckets для Telegram API (обычно < 1 сек) buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0], - registry=self.registry + registry=self.registry, ) - + # Errors counter self.errors_total = Counter( - 'errors_total', - 'Total number of errors', - ['error_type', 'handler_type', 'method_name'], - registry=self.registry + "errors_total", + "Total number of errors", + ["error_type", "handler_type", "method_name"], + registry=self.registry, ) - + # Active users gauge self.active_users = Gauge( - 'active_users', - 'Number of currently active users', - ['user_type'], - registry=self.registry + "active_users", + "Number of currently active users", + ["user_type"], + registry=self.registry, ) - + # Total users gauge (отдельная метрика) self.total_users = Gauge( - 'total_users', - 'Total number of users in database', - registry=self.registry + "total_users", "Total number of users in database", registry=self.registry ) - + # Database query metrics self.db_query_duration_seconds = Histogram( - 'db_query_duration_seconds', - 'Time spent executing database queries', - ['query_type', 'table_name', 'operation'], + "db_query_duration_seconds", + "Time spent executing database queries", + ["query_type", "table_name", "operation"], # Оптимизированные buckets для SQLite/PostgreSQL buckets=[0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5], - registry=self.registry + registry=self.registry, ) - + # Database queries counter self.db_queries_total = Counter( - 'db_queries_total', - 'Total number of database queries executed', - ['query_type', 'table_name', 'operation'], - registry=self.registry + "db_queries_total", + "Total number of database queries executed", + ["query_type", "table_name", "operation"], + registry=self.registry, ) - + # Database errors counter self.db_errors_total = Counter( - 'db_errors_total', - 'Total number of database errors', - ['error_type', 'query_type', 'table_name', 'operation'], - registry=self.registry + "db_errors_total", + "Total number of database errors", + ["error_type", "query_type", "table_name", "operation"], + registry=self.registry, ) - + # Message processing metrics self.messages_processed_total = Counter( - 'messages_processed_total', - 'Total number of messages processed', - ['message_type', 'chat_type', 'handler_type'], - registry=self.registry + "messages_processed_total", + "Total number of messages processed", + ["message_type", "chat_type", "handler_type"], + registry=self.registry, ) - + # Middleware execution metrics self.middleware_duration_seconds = Histogram( - 'middleware_duration_seconds', - 'Time spent in middleware execution', - ['middleware_name', 'status'], + "middleware_duration_seconds", + "Time spent in middleware execution", + ["middleware_name", "status"], # Middleware должен быть быстрым buckets=[0.001, 0.005, 0.01, 0.05, 0.1, 0.25], - registry=self.registry + registry=self.registry, ) - + # Rate limiting metrics self.rate_limit_hits_total = Counter( - 'rate_limit_hits_total', - 'Total number of rate limit hits', - ['limit_type', 'user_id', 'action'], - registry=self.registry + "rate_limit_hits_total", + "Total number of rate limit hits", + ["limit_type", "user_id", "action"], + registry=self.registry, ) # User activity metrics self.user_activity_total = Counter( - 'user_activity_total', - 'Total user activity events', - ['activity_type', 'user_type', 'chat_type'], - registry=self.registry + "user_activity_total", + "Total user activity events", + ["activity_type", "user_type", "chat_type"], + registry=self.registry, ) - + # File download metrics self.file_downloads_total = Counter( - 'file_downloads_total', - 'Total number of file downloads', - ['content_type', 'status'], - registry=self.registry + "file_downloads_total", + "Total number of file downloads", + ["content_type", "status"], + registry=self.registry, ) - + self.file_download_duration_seconds = Histogram( - 'file_download_duration_seconds', - 'Time spent downloading files', - ['content_type'], + "file_download_duration_seconds", + "Time spent downloading files", + ["content_type"], buckets=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0], - registry=self.registry + registry=self.registry, ) - + self.file_download_size_bytes = Histogram( - 'file_download_size_bytes', - 'Size of downloaded files in bytes', - ['content_type'], + "file_download_size_bytes", + "Size of downloaded files in bytes", + ["content_type"], buckets=[1024, 10240, 102400, 1048576, 10485760, 104857600, 1073741824], - registry=self.registry + registry=self.registry, ) - + # Media processing metrics self.media_processing_total = Counter( - 'media_processing_total', - 'Total number of media processing operations', - ['content_type', 'status'], - registry=self.registry + "media_processing_total", + "Total number of media processing operations", + ["content_type", "status"], + registry=self.registry, ) - + self.media_processing_duration_seconds = Histogram( - 'media_processing_duration_seconds', - 'Time spent processing media', - ['content_type'], + "media_processing_duration_seconds", + "Time spent processing media", + ["content_type"], buckets=[0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0], - registry=self.registry + registry=self.registry, ) - + def _create_rate_limit_metrics(self): """Создает метрики rate limiter в основном registry""" try: # Создаем метрики rate limiter в том же registry self.rate_limit_requests_total = Counter( - 'rate_limit_requests_total', - 'Total number of rate limited requests', - ['chat_id', 'status', 'error_type'], - registry=self.registry + "rate_limit_requests_total", + "Total number of rate limited requests", + ["chat_id", "status", "error_type"], + registry=self.registry, ) - + self.rate_limit_errors_total = Counter( - 'rate_limit_errors_total', - 'Total number of rate limit errors', - ['error_type', 'chat_id'], - registry=self.registry + "rate_limit_errors_total", + "Total number of rate limit errors", + ["error_type", "chat_id"], + registry=self.registry, ) - + self.rate_limit_wait_duration_seconds = Histogram( - 'rate_limit_wait_duration_seconds', - 'Time spent waiting due to rate limiting', - ['chat_id'], + "rate_limit_wait_duration_seconds", + "Time spent waiting due to rate limiting", + ["chat_id"], buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0], - registry=self.registry + registry=self.registry, ) - + self.rate_limit_active_chats = Gauge( - 'rate_limit_active_chats', - 'Number of active chats with rate limiting', - registry=self.registry + "rate_limit_active_chats", + "Number of active chats with rate limiting", + registry=self.registry, ) - + self.rate_limit_success_rate = Gauge( - 'rate_limit_success_rate', - 'Success rate of rate limited requests', - ['chat_id'], - registry=self.registry + "rate_limit_success_rate", + "Success rate of rate limited requests", + ["chat_id"], + registry=self.registry, ) - + self.rate_limit_requests_per_minute = Gauge( - 'rate_limit_requests_per_minute', - 'Requests per minute', - ['chat_id'], - registry=self.registry + "rate_limit_requests_per_minute", + "Requests per minute", + ["chat_id"], + registry=self.registry, ) - + self.rate_limit_total_requests = Gauge( - 'rate_limit_total_requests', - 'Total number of requests', - ['chat_id'], - registry=self.registry + "rate_limit_total_requests", + "Total number of requests", + ["chat_id"], + registry=self.registry, ) - + self.rate_limit_total_errors = Gauge( - 'rate_limit_total_errors', - 'Total number of errors', - ['chat_id', 'error_type'], - registry=self.registry + "rate_limit_total_errors", + "Total number of errors", + ["chat_id", "error_type"], + registry=self.registry, ) - + self.rate_limit_avg_wait_time_seconds = Gauge( - 'rate_limit_avg_wait_time_seconds', - 'Average wait time in seconds', - ['chat_id'], - registry=self.registry + "rate_limit_avg_wait_time_seconds", + "Average wait time in seconds", + ["chat_id"], + registry=self.registry, ) - + except Exception as e: # Логируем ошибку, но не прерываем инициализацию import logging + logging.warning(f"Failed to create rate limit metrics: {e}") - - def record_command(self, command_type: str, handler_type: str = "unknown", user_type: str = "unknown", status: str = "success"): + + def record_command( + self, + command_type: str, + handler_type: str = "unknown", + user_type: str = "unknown", + status: str = "success", + ): """Record a bot command execution.""" self.bot_commands_total.labels( command=command_type, status=status, handler_type=handler_type, - user_type=user_type + user_type=user_type, ).inc() - - def record_error(self, error_type: str, handler_type: str = "unknown", method_name: str = "unknown"): + + def record_error( + self, + error_type: str, + handler_type: str = "unknown", + method_name: str = "unknown", + ): """Record an error occurrence.""" self.errors_total.labels( - error_type=error_type, - handler_type=handler_type, - method_name=method_name + error_type=error_type, handler_type=handler_type, method_name=method_name ).inc() - - def record_method_duration(self, method_name: str, duration: float, handler_type: str = "unknown", status: str = "success"): + + def record_method_duration( + self, + method_name: str, + duration: float, + handler_type: str = "unknown", + status: str = "success", + ): """Record method execution duration.""" self.method_duration_seconds.labels( - method_name=method_name, - handler_type=handler_type, - status=status + method_name=method_name, handler_type=handler_type, status=status ).observe(duration) - + def set_active_users(self, count: int, user_type: str = "daily"): """Set the number of active users for a specific type.""" self.active_users.labels(user_type=user_type).set(count) - + def set_total_users(self, count: int): """Set the total number of users in database.""" self.total_users.set(count) - - def record_db_query(self, query_type: str, duration: float, table_name: str = "unknown", operation: str = "unknown"): + + def record_db_query( + self, + query_type: str, + duration: float, + table_name: str = "unknown", + operation: str = "unknown", + ): """Record database query duration.""" self.db_query_duration_seconds.labels( - query_type=query_type, - table_name=table_name, - operation=operation + query_type=query_type, table_name=table_name, operation=operation ).observe(duration) self.db_queries_total.labels( - query_type=query_type, - table_name=table_name, - operation=operation + query_type=query_type, table_name=table_name, operation=operation ).inc() - - def record_message(self, message_type: str, chat_type: str = "unknown", handler_type: str = "unknown"): + + def record_message( + self, + message_type: str, + chat_type: str = "unknown", + handler_type: str = "unknown", + ): """Record a processed message.""" self.messages_processed_total.labels( - message_type=message_type, - chat_type=chat_type, - handler_type=handler_type + message_type=message_type, chat_type=chat_type, handler_type=handler_type ).inc() - - def record_middleware(self, middleware_name: str, duration: float, status: str = "success"): + + def record_middleware( + self, middleware_name: str, duration: float, status: str = "success" + ): """Record middleware execution duration.""" self.middleware_duration_seconds.labels( - middleware_name=middleware_name, - status=status + middleware_name=middleware_name, status=status ).observe(duration) - + def record_file_download(self, content_type: str, file_size: int, duration: float): """Record file download metrics.""" self.file_downloads_total.labels( - content_type=content_type, - status="success" + content_type=content_type, status="success" ).inc() - - self.file_download_duration_seconds.labels( - content_type=content_type - ).observe(duration) - - self.file_download_size_bytes.labels( - content_type=content_type - ).observe(file_size) - + + self.file_download_duration_seconds.labels(content_type=content_type).observe( + duration + ) + + self.file_download_size_bytes.labels(content_type=content_type).observe( + file_size + ) + def record_file_download_error(self, content_type: str, error_message: str): """Record file download error metrics.""" self.file_downloads_total.labels( - content_type=content_type, - status="error" + content_type=content_type, status="error" ).inc() - + self.errors_total.labels( error_type="file_download_error", handler_type="media_processing", - method_name="download_file" + method_name="download_file", ).inc() - - def record_media_processing(self, content_type: str, duration: float, success: bool): + + def record_media_processing( + self, content_type: str, duration: float, success: bool + ): """Record media processing metrics.""" status = "success" if success else "error" - + self.media_processing_total.labels( - content_type=content_type, - status=status + content_type=content_type, status=status ).inc() - + self.media_processing_duration_seconds.labels( content_type=content_type ).observe(duration) - + if not success: self.errors_total.labels( error_type="media_processing_error", handler_type="media_processing", - method_name="add_in_db_media" + method_name="add_in_db_media", ).inc() - - def record_db_error(self, error_type: str, query_type: str = "unknown", table_name: str = "unknown", operation: str = "unknown"): + + def record_db_error( + self, + error_type: str, + query_type: str = "unknown", + table_name: str = "unknown", + operation: str = "unknown", + ): """Record database error occurrence.""" self.db_errors_total.labels( error_type=error_type, query_type=query_type, table_name=table_name, - operation=operation + operation=operation, ).inc() - - def record_rate_limit_request(self, chat_id: int, success: bool, wait_time: float = 0.0, error_type: str = None): + + def record_rate_limit_request( + self, + chat_id: int, + success: bool, + wait_time: float = 0.0, + error_type: str = None, + ): """Record rate limit request metrics.""" try: # Определяем статус status = "success" if success else "error" - + # Записываем счетчик запросов self.rate_limit_requests_total.labels( - chat_id=str(chat_id), - status=status, - error_type=error_type or "none" + chat_id=str(chat_id), status=status, error_type=error_type or "none" ).inc() - + # Записываем время ожидания if wait_time > 0: self.rate_limit_wait_duration_seconds.labels( chat_id=str(chat_id) ).observe(wait_time) - + # Записываем ошибки if not success and error_type: self.rate_limit_errors_total.labels( - error_type=error_type, - chat_id=str(chat_id) + error_type=error_type, chat_id=str(chat_id) ).inc() except Exception as e: import logging + logging.warning(f"Failed to record rate limit request: {e}") - + def update_rate_limit_gauges(self): """Update rate limit gauge metrics.""" try: @@ -392,52 +424,51 @@ class BotMetrics: # Обновляем количество активных чатов self.rate_limit_active_chats.set(len(rate_limit_monitor.stats)) - + # Обновляем метрики для каждого чата for chat_id, chat_stats in rate_limit_monitor.stats.items(): chat_id_str = str(chat_id) - + # Процент успеха - self.rate_limit_success_rate.labels( - chat_id=chat_id_str - ).set(chat_stats.success_rate) - + self.rate_limit_success_rate.labels(chat_id=chat_id_str).set( + chat_stats.success_rate + ) + # Запросов в минуту - self.rate_limit_requests_per_minute.labels( - chat_id=chat_id_str - ).set(chat_stats.requests_per_minute) - + self.rate_limit_requests_per_minute.labels(chat_id=chat_id_str).set( + chat_stats.requests_per_minute + ) + # Общее количество запросов - self.rate_limit_total_requests.labels( - chat_id=chat_id_str - ).set(chat_stats.total_requests) - + self.rate_limit_total_requests.labels(chat_id=chat_id_str).set( + chat_stats.total_requests + ) + # Среднее время ожидания - self.rate_limit_avg_wait_time_seconds.labels( - chat_id=chat_id_str - ).set(chat_stats.average_wait_time) - + self.rate_limit_avg_wait_time_seconds.labels(chat_id=chat_id_str).set( + chat_stats.average_wait_time + ) + # Количество ошибок по типам if chat_stats.retry_after_errors > 0: self.rate_limit_total_errors.labels( - chat_id=chat_id_str, - error_type="RetryAfter" + chat_id=chat_id_str, error_type="RetryAfter" ).set(chat_stats.retry_after_errors) - + if chat_stats.other_errors > 0: self.rate_limit_total_errors.labels( - chat_id=chat_id_str, - error_type="Other" + chat_id=chat_id_str, error_type="Other" ).set(chat_stats.other_errors) except Exception as e: import logging + logging.warning(f"Failed to update rate limit gauges: {e}") - + def get_metrics(self) -> bytes: """Generate metrics in Prometheus format.""" # Обновляем gauge метрики rate limiter перед генерацией self.update_rate_limit_gauges() - + return generate_latest(self.registry) @@ -448,6 +479,7 @@ metrics = BotMetrics() # Decorators for easy metric collection def track_time(method_name: str = None, handler_type: str = "unknown"): """Decorator to track execution time of functions.""" + def decorator(func): @wraps(func) async def async_wrapper(*args, **kwargs): @@ -456,27 +488,19 @@ def track_time(method_name: str = None, handler_type: str = "unknown"): result = await func(*args, **kwargs) duration = time.time() - start_time metrics.record_method_duration( - method_name or func.__name__, - duration, - handler_type, - "success" + method_name or func.__name__, duration, handler_type, "success" ) return result except Exception as e: duration = time.time() - start_time metrics.record_method_duration( - method_name or func.__name__, - duration, - handler_type, - "error" + method_name or func.__name__, duration, handler_type, "error" ) metrics.record_error( - type(e).__name__, - handler_type, - method_name or func.__name__ + type(e).__name__, handler_type, method_name or func.__name__ ) raise - + @wraps(func) def sync_wrapper(*args, **kwargs): start_time = time.time() @@ -484,35 +508,29 @@ def track_time(method_name: str = None, handler_type: str = "unknown"): result = func(*args, **kwargs) duration = time.time() - start_time metrics.record_method_duration( - method_name or func.__name__, - duration, - handler_type, - "success" + method_name or func.__name__, duration, handler_type, "success" ) return result except Exception as e: duration = time.time() - start_time metrics.record_method_duration( - method_name or func.__name__, - duration, - handler_type, - "error" + method_name or func.__name__, duration, handler_type, "error" ) metrics.record_error( - type(e).__name__, - handler_type, - method_name or func.__name__ + type(e).__name__, handler_type, method_name or func.__name__ ) raise - + if asyncio.iscoroutinefunction(func): return async_wrapper return sync_wrapper + return decorator def track_errors(handler_type: str = "unknown", method_name: str = None): """Decorator to track errors in functions.""" + def decorator(func): @wraps(func) async def async_wrapper(*args, **kwargs): @@ -520,32 +538,32 @@ def track_errors(handler_type: str = "unknown", method_name: str = None): return await func(*args, **kwargs) except Exception as e: metrics.record_error( - type(e).__name__, - handler_type, - method_name or func.__name__ + type(e).__name__, handler_type, method_name or func.__name__ ) raise - + @wraps(func) def sync_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: metrics.record_error( - type(e).__name__, - handler_type, - method_name or func.__name__ + type(e).__name__, handler_type, method_name or func.__name__ ) raise - + if asyncio.iscoroutinefunction(func): return async_wrapper return sync_wrapper + return decorator -def db_query_time(query_type: str = "unknown", table_name: str = "unknown", operation: str = "unknown"): +def db_query_time( + query_type: str = "unknown", table_name: str = "unknown", operation: str = "unknown" +): """Decorator to track database query execution time.""" + def decorator(func): @wraps(func) async def async_wrapper(*args, **kwargs): @@ -559,18 +577,11 @@ def db_query_time(query_type: str = "unknown", table_name: str = "unknown", oper duration = time.time() - start_time metrics.record_db_query(query_type, duration, table_name, operation) metrics.record_db_error( - type(e).__name__, - query_type, - table_name, - operation - ) - metrics.record_error( - type(e).__name__, - "database", - func.__name__ + type(e).__name__, query_type, table_name, operation ) + metrics.record_error(type(e).__name__, "database", func.__name__) raise - + @wraps(func) def sync_wrapper(*args, **kwargs): start_time = time.time() @@ -583,21 +594,15 @@ def db_query_time(query_type: str = "unknown", table_name: str = "unknown", oper duration = time.time() - start_time metrics.record_db_query(query_type, duration, table_name, operation) metrics.record_db_error( - type(e).__name__, - query_type, - table_name, - operation - ) - metrics.record_error( - type(e).__name__, - "database", - func.__name__ + type(e).__name__, query_type, table_name, operation ) + metrics.record_error(type(e).__name__, "database", func.__name__) raise - + if asyncio.iscoroutinefunction(func): return async_wrapper return sync_wrapper + return decorator @@ -612,16 +617,13 @@ async def track_middleware(middleware_name: str): except Exception as e: duration = time.time() - start_time metrics.record_middleware(middleware_name, duration, "error") - metrics.record_error( - type(e).__name__, - "middleware", - middleware_name - ) + metrics.record_error(type(e).__name__, "middleware", middleware_name) raise def track_media_processing(content_type: str = "unknown"): """Decorator to track media processing operations.""" + def decorator(func): @wraps(func) async def async_wrapper(*args, **kwargs): @@ -635,7 +637,7 @@ def track_media_processing(content_type: str = "unknown"): duration = time.time() - start_time metrics.record_media_processing(content_type, duration, False) raise - + @wraps(func) def sync_wrapper(*args, **kwargs): start_time = time.time() @@ -648,15 +650,17 @@ def track_media_processing(content_type: str = "unknown"): duration = time.time() - start_time metrics.record_media_processing(content_type, duration, False) raise - + if asyncio.iscoroutinefunction(func): return async_wrapper return sync_wrapper + return decorator def track_file_operations(content_type: str = "unknown"): """Decorator to track file download/upload operations.""" + def decorator(func): @wraps(func) async def async_wrapper(*args, **kwargs): @@ -664,43 +668,44 @@ def track_file_operations(content_type: str = "unknown"): try: result = await func(*args, **kwargs) duration = time.time() - start_time - + # Получаем размер файла из результата file_size = 0 if result and isinstance(result, str) and os.path.exists(result): file_size = os.path.getsize(result) - + # Записываем метрики metrics.record_file_download(content_type, file_size, duration) - + return result except Exception as e: duration = time.time() - start_time metrics.record_file_download_error(content_type, str(e)) raise - + @wraps(func) def sync_wrapper(*args, **kwargs): start_time = time.time() try: result = func(*args, **kwargs) duration = time.time() - start_time - + # Получаем размер файла из результата file_size = 0 if result and isinstance(result, str) and os.path.exists(result): file_size = os.path.getsize(result) - + # Записываем метрики metrics.record_file_download(content_type, file_size, duration) - + return result except Exception as e: duration = time.time() - start_time metrics.record_file_download_error(content_type, str(e)) raise - + if asyncio.iscoroutinefunction(func): return async_wrapper return sync_wrapper + return decorator diff --git a/helper_bot/utils/rate_limit_monitor.py b/helper_bot/utils/rate_limit_monitor.py index 4a9c6b9..103b154 100644 --- a/helper_bot/utils/rate_limit_monitor.py +++ b/helper_bot/utils/rate_limit_monitor.py @@ -1,6 +1,7 @@ """ Мониторинг и статистика rate limiting """ + import time from collections import defaultdict, deque from dataclasses import dataclass, field @@ -12,6 +13,7 @@ from logs.custom_logger import logger @dataclass class RateLimitStats: """Статистика rate limiting для чата""" + chat_id: int total_requests: int = 0 successful_requests: int = 0 @@ -21,53 +23,61 @@ class RateLimitStats: total_wait_time: float = 0.0 last_request_time: float = 0.0 request_times: deque = field(default_factory=lambda: deque(maxlen=100)) - + @property def success_rate(self) -> float: """Процент успешных запросов""" if self.total_requests == 0: return 1.0 return self.successful_requests / self.total_requests - + @property def error_rate(self) -> float: """Процент ошибок""" return 1.0 - self.success_rate - + @property def average_wait_time(self) -> float: """Среднее время ожидания""" if self.total_requests == 0: return 0.0 return self.total_wait_time / self.total_requests - + @property def requests_per_minute(self) -> float: """Запросов в минуту""" if not self.request_times: return 0.0 - + current_time = time.time() minute_ago = current_time - 60 - + # Подсчитываем запросы за последнюю минуту - recent_requests = sum(1 for req_time in self.request_times if req_time > minute_ago) + recent_requests = sum( + 1 for req_time in self.request_times if req_time > minute_ago + ) return recent_requests class RateLimitMonitor: """Монитор для отслеживания статистики rate limiting""" - + def __init__(self, max_history_size: int = 1000): self.stats: Dict[int, RateLimitStats] = defaultdict(lambda: RateLimitStats(0)) self.global_stats = RateLimitStats(0) self.max_history_size = max_history_size self.error_history: deque = deque(maxlen=max_history_size) - - def record_request(self, chat_id: int, success: bool, wait_time: float = 0.0, error_type: Optional[str] = None): + + def record_request( + self, + chat_id: int, + success: bool, + wait_time: float = 0.0, + error_type: Optional[str] = None, + ): """Записывает информацию о запросе""" current_time = time.time() - + # Обновляем статистику для чата chat_stats = self.stats[chat_id] chat_stats.chat_id = chat_id @@ -75,7 +85,7 @@ class RateLimitMonitor: chat_stats.total_wait_time += wait_time chat_stats.last_request_time = current_time chat_stats.request_times.append(current_time) - + if success: chat_stats.successful_requests += 1 else: @@ -84,21 +94,23 @@ class RateLimitMonitor: chat_stats.retry_after_errors += 1 else: chat_stats.other_errors += 1 - + # Записываем ошибку в историю - self.error_history.append({ - 'chat_id': chat_id, - 'error_type': error_type, - 'timestamp': current_time, - 'wait_time': wait_time - }) - + self.error_history.append( + { + "chat_id": chat_id, + "error_type": error_type, + "timestamp": current_time, + "wait_time": wait_time, + } + ) + # Обновляем глобальную статистику self.global_stats.total_requests += 1 self.global_stats.total_wait_time += wait_time self.global_stats.last_request_time = current_time self.global_stats.request_times.append(current_time) - + if success: self.global_stats.successful_requests += 1 else: @@ -107,56 +119,54 @@ class RateLimitMonitor: self.global_stats.retry_after_errors += 1 else: self.global_stats.other_errors += 1 - + def get_chat_stats(self, chat_id: int) -> Optional[RateLimitStats]: """Получает статистику для конкретного чата""" return self.stats.get(chat_id) - + def get_global_stats(self) -> RateLimitStats: """Получает глобальную статистику""" return self.global_stats - + def get_top_chats_by_requests(self, limit: int = 10) -> List[tuple]: """Получает топ чатов по количеству запросов""" sorted_chats = sorted( - self.stats.items(), - key=lambda x: x[1].total_requests, - reverse=True + self.stats.items(), key=lambda x: x[1].total_requests, reverse=True ) return sorted_chats[:limit] - + def get_chats_with_high_error_rate(self, threshold: float = 0.1) -> List[tuple]: """Получает чаты с высоким процентом ошибок""" high_error_chats = [ - (chat_id, stats) for chat_id, stats in self.stats.items() + (chat_id, stats) + for chat_id, stats in self.stats.items() if stats.error_rate > threshold and stats.total_requests > 5 ] return sorted(high_error_chats, key=lambda x: x[1].error_rate, reverse=True) - + def get_recent_errors(self, minutes: int = 60) -> List[dict]: """Получает недавние ошибки""" current_time = time.time() cutoff_time = current_time - (minutes * 60) - + return [ - error for error in self.error_history - if error['timestamp'] > cutoff_time + error for error in self.error_history if error["timestamp"] > cutoff_time ] - + def get_error_summary(self, minutes: int = 60) -> Dict[str, int]: """Получает сводку ошибок за указанный период""" recent_errors = self.get_recent_errors(minutes) error_summary = defaultdict(int) - + for error in recent_errors: - error_summary[error['error_type']] += 1 - + error_summary[error["error_type"]] += 1 + return dict(error_summary) - + def log_statistics(self, log_level: str = "info"): """Логирует текущую статистику""" global_stats = self.get_global_stats() - + log_message = ( f"Rate Limit Statistics:\n" f" Total requests: {global_stats.total_requests}\n" @@ -168,21 +178,25 @@ class RateLimitMonitor: f" Requests per minute: {global_stats.requests_per_minute:.1f}\n" f" Active chats: {len(self.stats)}" ) - + if log_level == "error": logger.error(log_message) elif log_level == "warning": logger.warning(log_message) else: logger.info(log_message) - + # Логируем чаты с высоким процентом ошибок high_error_chats = self.get_chats_with_high_error_rate(0.2) if high_error_chats: - logger.warning(f"Chats with high error rate (>20%): {len(high_error_chats)}") + logger.warning( + f"Chats with high error rate (>20%): {len(high_error_chats)}" + ) for chat_id, stats in high_error_chats[:5]: # Показываем только первые 5 - logger.warning(f" Chat {chat_id}: {stats.error_rate:.2%} error rate ({stats.failed_requests}/{stats.total_requests})") - + logger.warning( + f" Chat {chat_id}: {stats.error_rate:.2%} error rate ({stats.failed_requests}/{stats.total_requests})" + ) + def reset_stats(self, chat_id: Optional[int] = None): """Сбрасывает статистику""" if chat_id is None: @@ -200,7 +214,12 @@ class RateLimitMonitor: rate_limit_monitor = RateLimitMonitor() -def record_rate_limit_request(chat_id: int, success: bool, wait_time: float = 0.0, error_type: Optional[str] = None): +def record_rate_limit_request( + chat_id: int, + success: bool, + wait_time: float = 0.0, + error_type: Optional[str] = None, +): """Удобная функция для записи информации о запросе""" rate_limit_monitor.record_request(chat_id, success, wait_time, error_type) @@ -209,13 +228,13 @@ def get_rate_limit_summary() -> Dict: """Получает краткую сводку по rate limiting""" global_stats = rate_limit_monitor.get_global_stats() recent_errors = rate_limit_monitor.get_recent_errors(60) # За последний час - + return { - 'total_requests': global_stats.total_requests, - 'success_rate': global_stats.success_rate, - 'error_rate': global_stats.error_rate, - 'recent_errors_count': len(recent_errors), - 'active_chats': len(rate_limit_monitor.stats), - 'requests_per_minute': global_stats.requests_per_minute, - 'average_wait_time': global_stats.average_wait_time + "total_requests": global_stats.total_requests, + "success_rate": global_stats.success_rate, + "error_rate": global_stats.error_rate, + "recent_errors_count": len(recent_errors), + "active_chats": len(rate_limit_monitor.stats), + "requests_per_minute": global_stats.requests_per_minute, + "average_wait_time": global_stats.average_wait_time, } diff --git a/helper_bot/utils/rate_limiter.py b/helper_bot/utils/rate_limiter.py index 25a8891..78d891f 100644 --- a/helper_bot/utils/rate_limiter.py +++ b/helper_bot/utils/rate_limiter.py @@ -1,20 +1,23 @@ """ Rate limiter для предотвращения Flood control ошибок в Telegram Bot API """ + import asyncio import time from dataclasses import dataclass from typing import Any, Callable, Dict, Optional from aiogram.exceptions import TelegramAPIError, TelegramRetryAfter + from logs.custom_logger import logger from .metrics import metrics -@dataclass +@dataclass class RateLimitConfig: """Конфигурация для rate limiting""" + messages_per_second: float = 0.5 # Максимум 0.5 сообщений в секунду на чат burst_limit: int = 3 # Максимум 3 сообщения подряд retry_after_multiplier: float = 1.2 # Множитель для увеличения задержки при retry @@ -23,23 +26,23 @@ class RateLimitConfig: class ChatRateLimiter: """Rate limiter для конкретного чата""" - + def __init__(self, config: RateLimitConfig): self.config = config self.last_send_time = 0.0 self.burst_count = 0 self.burst_reset_time = 0.0 self.retry_delay = 1.0 - + async def wait_if_needed(self) -> None: """Ждет если необходимо для соблюдения rate limit""" current_time = time.time() - + # Сбрасываем счетчик burst если прошло достаточно времени if current_time >= self.burst_reset_time: self.burst_count = 0 self.burst_reset_time = current_time + 1.0 - + # Проверяем burst limit if self.burst_count >= self.config.burst_limit: wait_time = self.burst_reset_time - current_time @@ -49,16 +52,16 @@ class ChatRateLimiter: current_time = time.time() self.burst_count = 0 self.burst_reset_time = current_time + 1.0 - + # Проверяем минимальный интервал между сообщениями time_since_last = current_time - self.last_send_time min_interval = 1.0 / self.config.messages_per_second - + if time_since_last < min_interval: wait_time = min_interval - time_since_last logger.debug(f"Rate limiting: waiting {wait_time:.2f}s") await asyncio.sleep(wait_time) - + # Обновляем время последней отправки self.last_send_time = time.time() self.burst_count += 1 @@ -66,126 +69,126 @@ class ChatRateLimiter: class GlobalRateLimiter: """Глобальный rate limiter для всех чатов""" - + def __init__(self, config: RateLimitConfig): self.config = config self.chat_limiters: Dict[int, ChatRateLimiter] = {} self.global_last_send = 0.0 self.global_min_interval = 0.1 # Минимум 100ms между любыми сообщениями - + def get_chat_limiter(self, chat_id: int) -> ChatRateLimiter: """Получает rate limiter для конкретного чата""" if chat_id not in self.chat_limiters: self.chat_limiters[chat_id] = ChatRateLimiter(self.config) return self.chat_limiters[chat_id] - + async def wait_if_needed(self, chat_id: int) -> None: """Ждет если необходимо для соблюдения глобального и чат-специфичного rate limit""" current_time = time.time() - + # Глобальный rate limit time_since_global = current_time - self.global_last_send if time_since_global < self.global_min_interval: wait_time = self.global_min_interval - time_since_global await asyncio.sleep(wait_time) current_time = time.time() - + # Чат-специфичный rate limit chat_limiter = self.get_chat_limiter(chat_id) await chat_limiter.wait_if_needed() - + self.global_last_send = time.time() class RetryHandler: """Обработчик повторных попыток с экспоненциальной задержкой""" - + def __init__(self, config: RateLimitConfig): self.config = config - + async def execute_with_retry( - self, - func: Callable, - chat_id: int, - *args, - max_retries: int = 3, - **kwargs + self, func: Callable, chat_id: int, *args, max_retries: int = 3, **kwargs ) -> Any: """Выполняет функцию с повторными попытками при ошибках""" retry_count = 0 current_delay = self.config.retry_after_multiplier total_wait_time = 0.0 - + while retry_count <= max_retries: try: result = await func(*args, **kwargs) # Записываем успешный запрос metrics.record_rate_limit_request(chat_id, True, total_wait_time) return result - + except TelegramRetryAfter as e: retry_count += 1 if retry_count > max_retries: logger.error(f"Max retries exceeded for RetryAfter: {e}") - metrics.record_rate_limit_request(chat_id, False, total_wait_time, "RetryAfter") + metrics.record_rate_limit_request( + chat_id, False, total_wait_time, "RetryAfter" + ) raise - + # Используем время ожидания от Telegram или наше увеличенное wait_time = max(e.retry_after, current_delay) wait_time = min(wait_time, self.config.max_retry_delay) total_wait_time += wait_time - - logger.warning(f"RetryAfter error, waiting {wait_time:.2f}s (attempt {retry_count}/{max_retries})") + + logger.warning( + f"RetryAfter error, waiting {wait_time:.2f}s (attempt {retry_count}/{max_retries})" + ) await asyncio.sleep(wait_time) current_delay *= self.config.retry_after_multiplier - + except TelegramAPIError as e: retry_count += 1 if retry_count > max_retries: logger.error(f"Max retries exceeded for TelegramAPIError: {e}") - metrics.record_rate_limit_request(chat_id, False, total_wait_time, "TelegramAPIError") + metrics.record_rate_limit_request( + chat_id, False, total_wait_time, "TelegramAPIError" + ) raise - + wait_time = min(current_delay, self.config.max_retry_delay) total_wait_time += wait_time - logger.warning(f"TelegramAPIError, waiting {wait_time:.2f}s (attempt {retry_count}/{max_retries}): {e}") + logger.warning( + f"TelegramAPIError, waiting {wait_time:.2f}s (attempt {retry_count}/{max_retries}): {e}" + ) await asyncio.sleep(wait_time) current_delay *= self.config.retry_after_multiplier - + except Exception as e: # Для других ошибок не делаем retry logger.error(f"Non-retryable error: {e}") - metrics.record_rate_limit_request(chat_id, False, total_wait_time, "Other") + metrics.record_rate_limit_request( + chat_id, False, total_wait_time, "Other" + ) raise class TelegramRateLimiter: """Основной класс для rate limiting в Telegram боте""" - + def __init__(self, config: Optional[RateLimitConfig] = None): self.config = config or RateLimitConfig() self.global_limiter = GlobalRateLimiter(self.config) self.retry_handler = RetryHandler(self.config) - + async def send_with_rate_limit( - self, - send_func: Callable, - chat_id: int, - *args, - **kwargs + self, send_func: Callable, chat_id: int, *args, **kwargs ) -> Any: """Отправляет сообщение с соблюдением rate limit и retry логики""" - + async def _send(): await self.global_limiter.wait_if_needed(chat_id) return await send_func(*args, **kwargs) - + return await self.retry_handler.execute_with_retry(_send, chat_id) # Глобальный экземпляр rate limiter -from helper_bot.config.rate_limit_config import (RateLimitSettings, - get_rate_limit_config) +from helper_bot.config.rate_limit_config import RateLimitSettings, get_rate_limit_config def _create_rate_limit_config(settings: RateLimitSettings) -> RateLimitConfig: @@ -194,9 +197,10 @@ def _create_rate_limit_config(settings: RateLimitSettings) -> RateLimitConfig: messages_per_second=settings.messages_per_second, burst_limit=settings.burst_limit, retry_after_multiplier=settings.retry_after_multiplier, - max_retry_delay=settings.max_retry_delay + max_retry_delay=settings.max_retry_delay, ) + # Получаем конфигурацию из настроек _rate_limit_settings = get_rate_limit_config("production") _default_config = _create_rate_limit_config(_rate_limit_settings) @@ -204,16 +208,20 @@ _default_config = _create_rate_limit_config(_rate_limit_settings) telegram_rate_limiter = TelegramRateLimiter(_default_config) -async def send_with_rate_limit(send_func: Callable, chat_id: int, *args, **kwargs) -> Any: +async def send_with_rate_limit( + send_func: Callable, chat_id: int, *args, **kwargs +) -> Any: """ Удобная функция для отправки сообщений с rate limiting - + Args: send_func: Функция отправки (например, bot.send_message) chat_id: ID чата *args, **kwargs: Аргументы для функции отправки - + Returns: Результат выполнения функции отправки """ - return await telegram_rate_limiter.send_with_rate_limit(send_func, chat_id, *args, **kwargs) + return await telegram_rate_limiter.send_with_rate_limit( + send_func, chat_id, *args, **kwargs + ) diff --git a/helper_bot/utils/s3_storage.py b/helper_bot/utils/s3_storage.py index 9a7512f..dbbf2d6 100644 --- a/helper_bot/utils/s3_storage.py +++ b/helper_bot/utils/s3_storage.py @@ -1,114 +1,114 @@ """ Сервис для работы с S3 хранилищем. """ + import os import tempfile from pathlib import Path from typing import Optional import aioboto3 + from logs.custom_logger import logger class S3StorageService: """Сервис для работы с S3 хранилищем.""" - - def __init__(self, endpoint_url: str, access_key: str, secret_key: str, - bucket_name: str, region: str = "us-east-1"): + + def __init__( + self, + endpoint_url: str, + access_key: str, + secret_key: str, + bucket_name: str, + region: str = "us-east-1", + ): self.endpoint_url = endpoint_url self.access_key = access_key self.secret_key = secret_key self.bucket_name = bucket_name self.region = region self.session = aioboto3.Session() - - async def upload_file(self, file_path: str, s3_key: str, - content_type: Optional[str] = None) -> bool: + + async def upload_file( + self, file_path: str, s3_key: str, content_type: Optional[str] = None + ) -> bool: """Загружает файл в S3.""" try: async with self.session.client( - 's3', + "s3", endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, - region_name=self.region + region_name=self.region, ) as s3: extra_args = {} if content_type: - extra_args['ContentType'] = content_type - + extra_args["ContentType"] = content_type + await s3.upload_file( - file_path, - self.bucket_name, - s3_key, - ExtraArgs=extra_args + file_path, self.bucket_name, s3_key, ExtraArgs=extra_args ) logger.info(f"Файл загружен в S3: {s3_key}") return True except Exception as e: logger.error(f"Ошибка загрузки файла в S3 {s3_key}: {e}") return False - - async def upload_fileobj(self, file_obj, s3_key: str, - content_type: Optional[str] = None) -> bool: + + async def upload_fileobj( + self, file_obj, s3_key: str, content_type: Optional[str] = None + ) -> bool: """Загружает файл из объекта в S3.""" try: async with self.session.client( - 's3', + "s3", endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, - region_name=self.region + region_name=self.region, ) as s3: extra_args = {} if content_type: - extra_args['ContentType'] = content_type - + extra_args["ContentType"] = content_type + await s3.upload_fileobj( - file_obj, - self.bucket_name, - s3_key, - ExtraArgs=extra_args + file_obj, self.bucket_name, s3_key, ExtraArgs=extra_args ) logger.info(f"Файл загружен в S3 из объекта: {s3_key}") return True except Exception as e: logger.error(f"Ошибка загрузки файла в S3 из объекта {s3_key}: {e}") return False - + async def download_file(self, s3_key: str, local_path: str) -> bool: """Скачивает файл из S3 на локальный диск.""" try: async with self.session.client( - 's3', + "s3", endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, - region_name=self.region + region_name=self.region, ) as s3: # Создаем директорию если её нет os.makedirs(os.path.dirname(local_path), exist_ok=True) - - await s3.download_file( - self.bucket_name, - s3_key, - local_path - ) + + await s3.download_file(self.bucket_name, s3_key, local_path) logger.info(f"Файл скачан из S3: {s3_key} -> {local_path}") return True except Exception as e: logger.error(f"Ошибка скачивания файла из S3 {s3_key}: {e}") return False - + async def download_to_temp(self, s3_key: str) -> Optional[str]: """Скачивает файл из S3 во временный файл. Возвращает путь к временному файлу.""" try: # Определяем расширение из ключа - ext = Path(s3_key).suffix or '.bin' + ext = Path(s3_key).suffix or ".bin" temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=ext) temp_path = temp_file.name temp_file.close() - + success = await self.download_file(s3_key, temp_path) if success: return temp_path @@ -120,33 +120,35 @@ class S3StorageService: pass return None except Exception as e: - logger.error(f"Ошибка скачивания файла из S3 во временный файл {s3_key}: {e}") + logger.error( + f"Ошибка скачивания файла из S3 во временный файл {s3_key}: {e}" + ) return None - + async def file_exists(self, s3_key: str) -> bool: """Проверяет существование файла в S3.""" try: async with self.session.client( - 's3', + "s3", endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, - region_name=self.region + region_name=self.region, ) as s3: await s3.head_object(Bucket=self.bucket_name, Key=s3_key) return True except: return False - + async def delete_file(self, s3_key: str) -> bool: """Удаляет файл из S3.""" try: async with self.session.client( - 's3', + "s3", endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key, aws_secret_access_key=self.secret_key, - region_name=self.region + region_name=self.region, ) as s3: await s3.delete_object(Bucket=self.bucket_name, Key=s3_key) logger.info(f"Файл удален из S3: {s3_key}") @@ -154,23 +156,35 @@ class S3StorageService: except Exception as e: logger.error(f"Ошибка удаления файла из S3 {s3_key}: {e}") return False - + def generate_s3_key(self, content_type: str, file_id: str) -> str: """Генерирует S3 ключ для файла. Один и тот же для всех постов с этим file_id.""" type_folders = { - 'photo': 'photos', - 'video': 'videos', - 'audio': 'music', - 'voice': 'voice', - 'video_note': 'video_notes' + "photo": "photos", + "video": "videos", + "audio": "music", + "voice": "voice", + "video_note": "video_notes", } - - folder = type_folders.get(content_type, 'other') + + folder = type_folders.get(content_type, "other") # Определяем расширение из file_id или используем дефолтное - ext = '.jpg' if content_type == 'photo' else \ - '.mp4' if content_type == 'video' else \ - '.mp3' if content_type == 'audio' else \ - '.ogg' if content_type == 'voice' else \ - '.mp4' if content_type == 'video_note' else '.bin' - + ext = ( + ".jpg" + if content_type == "photo" + else ( + ".mp4" + if content_type == "video" + else ( + ".mp3" + if content_type == "audio" + else ( + ".ogg" + if content_type == "voice" + else ".mp4" if content_type == "video_note" else ".bin" + ) + ) + ) + ) + return f"{folder}/{file_id}{ext}" diff --git a/logs/custom_logger.py b/logs/custom_logger.py index 03a57f3..2f6ca95 100644 --- a/logs/custom_logger.py +++ b/logs/custom_logger.py @@ -8,7 +8,7 @@ from loguru import logger logger.remove() # Check if running in Docker/container -is_container = os.path.exists('/.dockerenv') or os.getenv('DOCKER_CONTAINER') == 'true' +is_container = os.path.exists("/.dockerenv") or os.getenv("DOCKER_CONTAINER") == "true" if is_container: # In container: log to stdout/stderr @@ -16,23 +16,23 @@ if is_container: sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name} | {line} | {message}", level=os.getenv("LOG_LEVEL", "INFO"), - colorize=True + colorize=True, ) logger.add( sys.stderr, format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name} | {line} | {message}", level="ERROR", - colorize=True + colorize=True, ) else: # Local development: log to files current_dir = os.path.dirname(os.path.abspath(__file__)) if not os.path.exists(current_dir): os.makedirs(current_dir) - - today = datetime.date.today().strftime('%Y-%m-%d') - filename = f'{current_dir}/helper_bot_{today}.log' - + + today = datetime.date.today().strftime("%Y-%m-%d") + filename = f"{current_dir}/helper_bot_{today}.log" + logger.add( filename, rotation="00:00", @@ -42,4 +42,4 @@ else: ) # Bind logger name -logger = logger.bind(name='main_log') +logger = logger.bind(name="main_log") diff --git a/pyproject.toml b/pyproject.toml index 8689105..75f42a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,13 @@ version = "1.0.0" description = "Telegram bot with monitoring and metrics" requires-python = ">=3.11" +[tool.black] +line-length = 88 + +[tool.isort] +profile = "black" +line_length = 88 + [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] diff --git a/requirements-dev.txt b/requirements-dev.txt index 1b4d9a4..411601d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,5 +9,6 @@ coverage>=7.0.0 # Development tools black>=23.0.0 +isort>=5.12.0 flake8>=6.0.0 mypy>=1.0.0 diff --git a/run_helper.py b/run_helper.py index 178857f..8715db1 100644 --- a/run_helper.py +++ b/run_helper.py @@ -17,71 +17,70 @@ from logs.custom_logger import logger async def main(): """Основная функция запуска""" - + bdf = get_global_instance() - + # Создаем бота для автоматического разбана from aiogram import Bot from aiogram.client.default import DefaultBotProperties - + auto_unban_bot = Bot( - token=bdf.settings['Telegram']['bot_token'], - default=DefaultBotProperties(parse_mode='HTML'), - timeout=30.0 + token=bdf.settings["Telegram"]["bot_token"], + default=DefaultBotProperties(parse_mode="HTML"), + timeout=30.0, ) - + # Инициализируем планировщик автоматического разбана auto_unban_scheduler = get_auto_unban_scheduler() auto_unban_scheduler.set_bot(auto_unban_bot) auto_unban_scheduler.start_scheduler() - + # Метрики запускаются в main.py через server_prometheus.py # Здесь не нужно дублировать функциональность - + # Флаг для корректного завершения shutdown_event = asyncio.Event() - + def signal_handler(signum, frame): """Обработчик сигналов для корректного завершения""" logger.info(f"Получен сигнал {signum}, завершаем работу...") shutdown_event.set() - + # Регистрируем обработчики сигналов signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - + # Запускаем бота (метрики запускаются внутри start_bot) bot_task = asyncio.create_task(start_bot(bdf)) - + main_bot = None - + try: # Ждем сигнала завершения await shutdown_event.wait() logger.info("Начинаем корректное завершение...") - + except KeyboardInterrupt: logger.info("Получен сигнал завершения...") finally: logger.info("Останавливаем планировщик автоматического разбана...") auto_unban_scheduler.stop_scheduler() - + # Останавливаем планировщик метрик try: - from helper_bot.utils.metrics_scheduler import \ - stop_metrics_scheduler + from helper_bot.utils.metrics_scheduler import stop_metrics_scheduler + stop_metrics_scheduler() logger.info("Планировщик метрик остановлен") except Exception as e: logger.error(f"Ошибка при остановке планировщика метрик: {e}") - + # Метрики останавливаются в main.py - + logger.info("Останавливаем задачи...") # Отменяем задачу бота bot_task.cancel() - - + # Ждем завершения задачи бота и получаем результат main bot try: results = await asyncio.gather(bot_task, return_exceptions=True) @@ -90,42 +89,46 @@ async def main(): main_bot = results[0] except Exception as e: logger.error(f"Ошибка при остановке задач: {e}") - + # Закрываем сессию основного бота (если она еще не закрыта) - if main_bot and hasattr(main_bot, 'session') and not main_bot.session.closed: + if main_bot and hasattr(main_bot, "session") and not main_bot.session.closed: try: await main_bot.session.close() logger.info("Сессия основного бота корректно закрыта") except Exception as e: logger.error(f"Ошибка при закрытии сессии основного бота: {e}") - + # Закрываем сессию бота для автоматического разбана if not auto_unban_bot.session.closed: try: await auto_unban_bot.session.close() logger.info("Сессия бота автоматического разбана корректно закрыта") except Exception as e: - logger.error(f"Ошибка при закрытии сессии бота автоматического разбана: {e}") - + logger.error( + f"Ошибка при закрытии сессии бота автоматического разбана: {e}" + ) + # Даем время на завершение всех aiohttp соединений await asyncio.sleep(0.2) - + logger.info("Бот корректно остановлен") + def init_db(): - db_path = '/app/database/tg-bot-database.db' - schema_path = '/app/database/schema.sql' - + db_path = "/app/database/tg-bot-database.db" + schema_path = "/app/database/schema.sql" + if not os.path.exists(db_path): print("Initializing database...") - with open(schema_path, 'r') as f: + with open(schema_path, "r") as f: schema = f.read() - + with sqlite3.connect(db_path) as conn: conn.executescript(schema) print("Database initialized successfully") -if __name__ == '__main__': + +if __name__ == "__main__": try: init_db() asyncio.run(main()) @@ -139,9 +142,11 @@ if __name__ == '__main__': pending = asyncio.all_tasks(loop) for task in pending: task.cancel() - + # Ждем завершения всех задач if pending: - loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + loop.close() diff --git a/scripts/add_ml_scores_columns.py b/scripts/add_ml_scores_columns.py index 78e237e..17caa40 100644 --- a/scripts/add_ml_scores_columns.py +++ b/scripts/add_ml_scores_columns.py @@ -11,6 +11,7 @@ "rag": {"score": 0.90, "model": "rubert-base-cased", "ts": 1706198400} } """ + import argparse import asyncio import os @@ -28,7 +29,10 @@ try: from logs.custom_logger import logger except ImportError: import logging - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) logger = logging.getLogger(__name__) DEFAULT_DB_PATH = "database/tg-bot-database.db" @@ -44,19 +48,19 @@ async def column_exists(conn: aiosqlite.Connection, table: str, column: str) -> async def main(db_path: str) -> None: """ Основная функция миграции. - + Добавляет колонку ml_scores в таблицу post_from_telegram_suggest. Миграция идемпотентна - можно запускать повторно без ошибок. """ db_path = os.path.abspath(db_path) - + if not os.path.exists(db_path): logger.error(f"База данных не найдена: {db_path}") return - + async with aiosqlite.connect(db_path) as conn: await conn.execute("PRAGMA foreign_keys = ON") - + # Проверяем и добавляем колонку ml_scores if not await column_exists(conn, "post_from_telegram_suggest", "ml_scores"): await conn.execute( @@ -65,7 +69,7 @@ async def main(db_path: str) -> None: logger.info("Колонка ml_scores добавлена в post_from_telegram_suggest") else: logger.info("Колонка ml_scores уже существует") - + await conn.commit() logger.info("Миграция add_ml_scores_columns завершена успешно") diff --git a/scripts/apply_migrations.py b/scripts/apply_migrations.py index aff54d3..42d8a9b 100644 --- a/scripts/apply_migrations.py +++ b/scripts/apply_migrations.py @@ -4,6 +4,7 @@ Сканирует папку scripts/ и применяет все новые миграции, которые еще не были применены. """ + import argparse import asyncio import importlib.util @@ -15,9 +16,9 @@ from typing import List, Tuple # Исключаем служебные скрипты из миграций EXCLUDED_SCRIPTS = { - 'apply_migrations.py', - 'test_s3_connection.py', - 'voice_cleanup.py', + "apply_migrations.py", + "test_s3_connection.py", + "voice_cleanup.py", } DEFAULT_DB_PATH = "database/tg-bot-database.db" @@ -26,7 +27,7 @@ DEFAULT_DB_PATH = "database/tg-bot-database.db" def get_migration_scripts(scripts_dir: Path) -> List[Tuple[str, Path]]: """ Получает список скриптов миграций из папки scripts. - + Возвращает список кортежей (имя_файла, путь_к_файлу), отсортированный по имени файла. """ scripts = [] @@ -39,24 +40,25 @@ def get_migration_scripts(scripts_dir: Path) -> List[Tuple[str, Path]]: async def is_migration_script(script_path: Path) -> bool: """ Проверяет, является ли скрипт миграцией. - + Миграция должна иметь функцию main() с параметром db_path. """ try: spec = importlib.util.spec_from_file_location("migration_script", script_path) if spec is None or spec.loader is None: return False - + module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Проверяем наличие функции main - if hasattr(module, 'main'): + if hasattr(module, "main"): import inspect + sig = inspect.signature(module.main) # Проверяем, что функция принимает db_path params = list(sig.parameters.keys()) - return 'db_path' in params + return "db_path" in params return False except Exception: # Если не удалось проверить, считаем что это не миграция @@ -66,12 +68,12 @@ async def is_migration_script(script_path: Path) -> bool: async def apply_migration(script_path: Path, db_path: str) -> bool: """ Применяет миграцию, запуская скрипт. - + Returns: True если миграция применена успешно, False в противном случае. """ script_name = script_path.name - + try: # Запускаем скрипт как отдельный процесс result = subprocess.run( @@ -79,9 +81,9 @@ async def apply_migration(script_path: Path, db_path: str) -> bool: cwd=script_path.parent.parent, capture_output=True, text=True, - timeout=300 # 5 минут максимум на миграцию + timeout=300, # 5 минут максимум на миграцию ) - + if result.returncode == 0: if result.stdout: print(f" {result.stdout.strip()}") @@ -93,7 +95,7 @@ async def apply_migration(script_path: Path, db_path: str) -> bool: if result.stderr: print(f" STDERR: {result.stderr}") return False - + except subprocess.TimeoutExpired: print(f" ❌ Превышен лимит времени (5 минут)") return False @@ -105,7 +107,7 @@ async def apply_migration(script_path: Path, db_path: str) -> bool: async def main(db_path: str, dry_run: bool = False) -> None: """ Основная функция для применения миграций. - + Args: db_path: Путь к базе данных dry_run: Если True, только показывает какие миграции будут применены @@ -113,7 +115,7 @@ async def main(db_path: str, dry_run: bool = False) -> None: # Импортируем зависимости только когда они действительно нужны project_root = Path(__file__).resolve().parent.parent sys.path.insert(0, str(project_root)) - + # Проверяем наличие необходимых зависимостей try: import aiosqlite @@ -121,53 +123,60 @@ async def main(db_path: str, dry_run: bool = False) -> None: print("❌ Ошибка: модуль aiosqlite не установлен.") print("💡 Установите зависимости: pip install -r requirements.txt") sys.exit(1) - + # Импортируем logger try: from logs.custom_logger import logger except ImportError: import logging - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) logger = logging.getLogger(__name__) - + # Импортируем MigrationRepository напрямую из файла - migration_repo_path = project_root / "database" / "repositories" / "migration_repository.py" + migration_repo_path = ( + project_root / "database" / "repositories" / "migration_repository.py" + ) if not migration_repo_path.exists(): print(f"❌ Файл migration_repository.py не найден: {migration_repo_path}") sys.exit(1) - - spec = importlib.util.spec_from_file_location("migration_repository", migration_repo_path) + + spec = importlib.util.spec_from_file_location( + "migration_repository", migration_repo_path + ) if spec is None or spec.loader is None: print("❌ Не удалось загрузить модуль migration_repository") sys.exit(1) - + migration_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(migration_module) MigrationRepository = migration_module.MigrationRepository - + db_path = os.path.abspath(db_path) if not os.path.exists(db_path): logger.error(f"База данных не найдена: {db_path}") print(f"❌ Ошибка: база данных не найдена: {db_path}") return - + scripts_dir = project_root / "scripts" if not scripts_dir.exists(): logger.error(f"Папка scripts не найдена: {scripts_dir}") print(f"❌ Ошибка: папка scripts не найдена: {scripts_dir}") return - + # Инициализируем репозиторий миграций напрямую migration_repo = MigrationRepository(db_path) await migration_repo.create_table() - + # Получаем список примененных миграций applied_migrations = await migration_repo.get_applied_migrations() logger.info(f"Примененных миграций: {len(applied_migrations)}") - + # Получаем все скрипты миграций all_scripts = get_migration_scripts(scripts_dir) - + # Фильтруем только миграции migration_scripts = [] for script_name, script_path in all_scripts: @@ -175,30 +184,31 @@ async def main(db_path: str, dry_run: bool = False) -> None: migration_scripts.append((script_name, script_path)) else: logger.debug(f"Скрипт {script_name} не является миграцией, пропускаем") - + # Находим новые миграции new_migrations = [ - (name, path) for name, path in migration_scripts + (name, path) + for name, path in migration_scripts if name not in applied_migrations ] - + if not new_migrations: print("✅ Все миграции уже применены") logger.info("Новых миграций не найдено") return - + print(f"📋 Найдено новых миграций: {len(new_migrations)}") for name, _ in new_migrations: print(f" - {name}") - + if dry_run: print("\n🔍 DRY RUN: миграции не будут применены") return - + # Применяем миграции по порядку print("\n🚀 Применение миграций...") failed_migrations = [] - + for script_name, script_path in new_migrations: print(f"📝 {script_name}...", end=" ", flush=True) success = await apply_migration(script_path, db_path) @@ -213,7 +223,7 @@ async def main(db_path: str, dry_run: bool = False) -> None: # Прерываем выполнение при ошибке print(f"\n⚠️ Прерывание: миграция {script_name} завершилась с ошибкой") break - + if failed_migrations: print(f"\n❌ Не удалось применить {len(failed_migrations)} миграций:") for name in failed_migrations: @@ -224,9 +234,7 @@ async def main(db_path: str, dry_run: bool = False) -> None: if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Применение миграций базы данных" - ) + parser = argparse.ArgumentParser(description="Применение миграций базы данных") parser.add_argument( "--db", default=os.environ.get("DATABASE_PATH", DEFAULT_DB_PATH), diff --git a/scripts/drop_vector_hash_column.py b/scripts/drop_vector_hash_column.py index 4d0bdd5..d91a42e 100644 --- a/scripts/drop_vector_hash_column.py +++ b/scripts/drop_vector_hash_column.py @@ -8,6 +8,7 @@ SQLite не поддерживает DROP COLUMN напрямую (до версии 3.35.0), поэтому используем пересоздание таблицы. """ + import argparse import asyncio import os @@ -25,7 +26,10 @@ try: from logs.custom_logger import logger except ImportError: import logging - logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) logger = logging.getLogger(__name__) DEFAULT_DB_PATH = "database/tg-bot-database.db" @@ -42,7 +46,7 @@ async def get_sqlite_version(conn: aiosqlite.Connection) -> tuple: """Возвращает версию SQLite.""" cursor = await conn.execute("SELECT sqlite_version()") version_str = (await cursor.fetchone())[0] - return tuple(map(int, version_str.split('.'))) + return tuple(map(int, version_str.split("."))) async def main(db_path: str) -> None: @@ -50,21 +54,21 @@ async def main(db_path: str) -> None: Удаляет колонку vector_hash из таблицы post_from_telegram_suggest. """ db_path = os.path.abspath(db_path) - + if not os.path.exists(db_path): logger.error(f"База данных не найдена: {db_path}") return - + async with aiosqlite.connect(db_path) as conn: # Проверяем существует ли колонка if not await column_exists(conn, "post_from_telegram_suggest", "vector_hash"): logger.info("Колонка vector_hash не существует, миграция не требуется") return - + # Проверяем версию SQLite version = await get_sqlite_version(conn) logger.info(f"Версия SQLite: {'.'.join(map(str, version))}") - + # SQLite 3.35.0+ поддерживает DROP COLUMN if version >= (3, 35, 0): logger.info("Используем ALTER TABLE DROP COLUMN") @@ -74,15 +78,15 @@ async def main(db_path: str) -> None: else: # Для старых версий пересоздаём таблицу logger.info("Используем пересоздание таблицы (SQLite < 3.35.0)") - + # Получаем список колонок без vector_hash cursor = await conn.execute("PRAGMA table_info(post_from_telegram_suggest)") columns = await cursor.fetchall() column_names = [col[1] for col in columns if col[1] != "vector_hash"] columns_str = ", ".join(column_names) - + logger.info(f"Колонки для сохранения: {columns_str}") - + # Пересоздаём таблицу await conn.execute("BEGIN TRANSACTION") try: @@ -91,21 +95,21 @@ async def main(db_path: str) -> None: f"CREATE TABLE post_from_telegram_suggest_backup AS " f"SELECT {columns_str} FROM post_from_telegram_suggest" ) - + # Удаляем старую таблицу await conn.execute("DROP TABLE post_from_telegram_suggest") - + # Переименовываем временную await conn.execute( "ALTER TABLE post_from_telegram_suggest_backup " "RENAME TO post_from_telegram_suggest" ) - + await conn.execute("COMMIT") except Exception as e: await conn.execute("ROLLBACK") raise e - + await conn.commit() logger.info("Колонка vector_hash успешно удалена") diff --git a/scripts/test_s3_connection.py b/scripts/test_s3_connection.py index 66abe03..de957a0 100755 --- a/scripts/test_s3_connection.py +++ b/scripts/test_s3_connection.py @@ -3,6 +3,7 @@ Скрипт для проверки подключения к S3 хранилищу. Читает настройки из .env файла или переменных окружения. """ + import asyncio import os import sys @@ -14,7 +15,7 @@ sys.path.insert(0, str(project_root)) # Загружаем .env файл from dotenv import load_dotenv -env_path = os.path.join(project_root, '.env') +env_path = os.path.join(project_root, ".env") if os.path.exists(env_path): load_dotenv(env_path) @@ -26,11 +27,12 @@ except ImportError: sys.exit(1) # Данные для подключения из .env или переменных окружения -S3_ACCESS_KEY = os.getenv('S3_ACCESS_KEY', 'j3tears100@gmail.com') -S3_SECRET_KEY = os.getenv('S3_SECRET_KEY', 'wQ1-6sZEPs92sbZTSf96') -S3_ENDPOINT_URL = os.getenv('S3_ENDPOINT_URL', 'https://api.s3.miran.ru:443') -S3_BUCKET_NAME = os.getenv('S3_BUCKET_NAME', 'telegram-helper-bot') -S3_REGION = os.getenv('S3_REGION', 'us-east-1') +S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY", "j3tears100@gmail.com") +S3_SECRET_KEY = os.getenv("S3_SECRET_KEY", "wQ1-6sZEPs92sbZTSf96") +S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL", "https://api.s3.miran.ru:443") +S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME", "telegram-helper-bot") +S3_REGION = os.getenv("S3_REGION", "us-east-1") + async def test_s3_connection(): """Тестирует подключение к S3 хранилищу.""" @@ -40,50 +42,54 @@ async def test_s3_connection(): print(f"Region: {S3_REGION}") print(f"Access Key: {S3_ACCESS_KEY}") print() - + session = aioboto3.Session() - + try: async with session.client( - 's3', + "s3", endpoint_url=S3_ENDPOINT_URL, aws_access_key_id=S3_ACCESS_KEY, aws_secret_access_key=S3_SECRET_KEY, - region_name=S3_REGION + region_name=S3_REGION, ) as s3: # Пытаемся получить список бакетов (может не иметь прав, пропускаем если ошибка) print("📦 Получение списка бакетов...") try: response = await s3.list_buckets() - buckets = response.get('Buckets', []) + buckets = response.get("Buckets", []) print(f"✅ Подключение успешно! Найдено бакетов: {len(buckets)}") - + if buckets: print("\n📋 Список бакетов:") for bucket in buckets: - print(f" - {bucket['Name']} (создан: {bucket.get('CreationDate', 'неизвестно')})") + print( + f" - {bucket['Name']} (создан: {bucket.get('CreationDate', 'неизвестно')})" + ) else: print("\n⚠️ Бакеты не найдены.") except Exception as list_error: print(f"⚠️ Не удалось получить список бакетов: {list_error}") print(" Это нормально, если нет прав на list_buckets") print(" Продолжаем тестирование с указанным бакетом...") - + # Пытаемся создать тестовый файл в указанном бакете print("\n🧪 Тестирование записи файла...") # Используем первый найденный бакет, если указанный не найден test_bucket = S3_BUCKET_NAME if buckets: # Проверяем, есть ли указанный бакет в списке - bucket_names = [b['Name'] for b in buckets] + bucket_names = [b["Name"] for b in buckets] if test_bucket not in bucket_names: print(f"⚠️ Бакет '{test_bucket}' не найден в списке.") - print(f" Используем первый найденный бакет: '{buckets[0]['Name']}'") - test_bucket = buckets[0]['Name'] - - test_key = 'test-connection.txt' - test_content = b'Test connection to S3 storage' - + print( + f" Используем первый найденный бакет: '{buckets[0]['Name']}'" + ) + test_bucket = buckets[0]["Name"] + + test_key = "test-connection.txt" + test_content = b"Test connection to S3 storage" + try: # Проверяем существование бакета try: @@ -93,33 +99,32 @@ async def test_s3_connection(): print(f"❌ Бакет '{test_bucket}' недоступен: {head_error}") print(" Проверьте права доступа к бакету") return False - - await s3.put_object( - Bucket=test_bucket, - Key=test_key, - Body=test_content + + await s3.put_object(Bucket=test_bucket, Key=test_key, Body=test_content) + print( + f"✅ Файл успешно записан в бакет '{test_bucket}' с ключом '{test_key}'" ) - print(f"✅ Файл успешно записан в бакет '{test_bucket}' с ключом '{test_key}'") - + # Пытаемся прочитать файл print("🧪 Тестирование чтения файла...") response = await s3.get_object(Bucket=test_bucket, Key=test_key) - content = await response['Body'].read() - + content = await response["Body"].read() + if content == test_content: print("✅ Файл успешно прочитан, содержимое совпадает") else: print("⚠️ Файл прочитан, но содержимое не совпадает") - + # Удаляем тестовый файл print("🧹 Удаление тестового файла...") await s3.delete_object(Bucket=test_bucket, Key=test_key) print("✅ Тестовый файл удален") - + except Exception as e: print(f"❌ Ошибка при тестировании записи/чтения: {e}") print(f" Тип ошибки: {type(e).__name__}") import traceback + print(f" Полный traceback:") traceback.print_exc() print("\nВозможные причины:") @@ -127,9 +132,9 @@ async def test_s3_connection(): print(" 2. Нет прав на запись в бакет") print(" 3. Неверный endpoint URL или регион") print(" 4. Проблемы с форматом endpoint (попробуйте без :443)") - + return True - + except Exception as e: print(f"❌ Ошибка подключения к S3: {e}") print("\nВозможные причины:") diff --git a/scripts/voice_cleanup.py b/scripts/voice_cleanup.py index 7bb89d4..09f626e 100644 --- a/scripts/voice_cleanup.py +++ b/scripts/voice_cleanup.py @@ -2,6 +2,7 @@ """ Скрипт для диагностики и очистки проблем с голосовыми файлами """ + import asyncio import os import sys @@ -24,15 +25,15 @@ async def main(): if not os.path.exists(db_path): logger.error(f"База данных не найдена: {db_path}") return - + bot_db = AsyncBotDB(db_path) cleanup_utils = VoiceFileCleanupUtils(bot_db) - + print("=== Диагностика голосовых файлов ===") - + # Запускаем полную диагностику diagnostic_result = await cleanup_utils.run_full_diagnostic() - + print(f"\n📊 Статистика диска:") if "error" in diagnostic_result["disk_stats"]: print(f" ❌ Ошибка: {diagnostic_result['disk_stats']['error']}") @@ -41,59 +42,65 @@ async def main(): print(f" 📁 Директория: {stats['directory']}") print(f" 📄 Всего файлов: {stats['total_files']}") print(f" 💾 Размер: {stats['total_size_mb']} MB") - + print(f"\n🗄️ База данных:") print(f" 📝 Записей в БД: {diagnostic_result['db_records_count']}") - print(f" 🔍 Записей без файлов: {diagnostic_result['orphaned_db_records_count']}") + print( + f" 🔍 Записей без файлов: {diagnostic_result['orphaned_db_records_count']}" + ) print(f" 📁 Файлов без записей: {diagnostic_result['orphaned_files_count']}") - + print(f"\n📋 Статус: {diagnostic_result['status']}") - - if diagnostic_result['status'] == 'issues_found': + + if diagnostic_result["status"] == "issues_found": print("\n⚠️ Найдены проблемы!") - - if diagnostic_result['orphaned_db_records_count'] > 0: + + if diagnostic_result["orphaned_db_records_count"] > 0: print(f"\n🗑️ Записи в БД без файлов (первые 10):") - for file_name, user_id in diagnostic_result['orphaned_db_records']: + for file_name, user_id in diagnostic_result["orphaned_db_records"]: print(f" - {file_name} (user_id: {user_id})") - - if diagnostic_result['orphaned_files_count'] > 0: + + if diagnostic_result["orphaned_files_count"] > 0: print(f"\n📁 Файлы без записей в БД (первые 10):") - for file_path in diagnostic_result['orphaned_files']: + for file_path in diagnostic_result["orphaned_files"]: print(f" - {file_path}") - + # Предлагаем очистку print("\n🧹 Хотите выполнить очистку?") print("1. Удалить записи в БД без файлов") print("2. Удалить файлы без записей в БД") print("3. Выполнить полную очистку") print("4. Выход") - + choice = input("\nВыберите действие (1-4): ").strip() - + if choice == "1": print("\n🗑️ Удаление записей в БД без файлов...") deleted = await cleanup_utils.cleanup_orphaned_db_records(dry_run=False) print(f"✅ Удалено {deleted} записей") - + elif choice == "2": print("\n📁 Удаление файлов без записей в БД...") deleted = await cleanup_utils.cleanup_orphaned_files(dry_run=False) print(f"✅ Удалено {deleted} файлов") - + elif choice == "3": print("\n🧹 Полная очистка...") - db_deleted = await cleanup_utils.cleanup_orphaned_db_records(dry_run=False) - files_deleted = await cleanup_utils.cleanup_orphaned_files(dry_run=False) + db_deleted = await cleanup_utils.cleanup_orphaned_db_records( + dry_run=False + ) + files_deleted = await cleanup_utils.cleanup_orphaned_files( + dry_run=False + ) print(f"✅ Удалено {db_deleted} записей в БД и {files_deleted} файлов") - + elif choice == "4": print("👋 Выход...") else: print("❌ Неверный выбор") else: print("\n✅ Проблем не найдено!") - + except Exception as e: logger.error(f"Ошибка в скрипте: {e}") print(f"❌ Ошибка: {e}") diff --git a/tests/conftest.py b/tests/conftest.py index c9fa084..96cb8d7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,13 +13,13 @@ if str(_project_root) not in sys.path: import pytest from aiogram.fsm.context import FSMContext from aiogram.types import Chat, Message, User -from database.async_db import AsyncBotDB # Импортируем моки в самом начале import tests.mocks +from database.async_db import AsyncBotDB # Настройка pytest-asyncio -pytest_plugins = ('pytest_asyncio',) +pytest_plugins = ("pytest_asyncio",) @pytest.fixture(scope="session") @@ -100,19 +100,16 @@ def mock_dispatcher(): def test_settings(): """Возвращает тестовые настройки""" return { - 'Telegram': { - 'bot_token': 'test_token_123', - 'preview_link': False, - 'group_for_posts': '-1001234567890', - 'group_for_message': '-1001234567891', - 'main_public': '-1001234567892', - 'group_for_logs': '-1001234567893', - 'important_logs': '-1001234567894' + "Telegram": { + "bot_token": "test_token_123", + "preview_link": False, + "group_for_posts": "-1001234567890", + "group_for_message": "-1001234567891", + "main_public": "-1001234567892", + "group_for_logs": "-1001234567893", + "important_logs": "-1001234567894", }, - 'Settings': { - 'logs': True, - 'test': False - } + "Settings": {"logs": True, "test": False}, } @@ -129,71 +126,71 @@ def mock_factory(test_settings, mock_db): @pytest.fixture def sample_photo_message(mock_message): """Создает сообщение с фото для тестов""" - mock_message.content_type = 'photo' - mock_message.caption = 'Тестовое фото' + mock_message.content_type = "photo" + mock_message.caption = "Тестовое фото" mock_message.media_group_id = None mock_message.photo = [Mock()] - mock_message.photo[-1].file_id = 'photo_file_id' + mock_message.photo[-1].file_id = "photo_file_id" return mock_message @pytest.fixture def sample_video_message(mock_message): """Создает сообщение с видео для тестов""" - mock_message.content_type = 'video' - mock_message.caption = 'Тестовое видео' + mock_message.content_type = "video" + mock_message.caption = "Тестовое видео" mock_message.media_group_id = None mock_message.video = Mock() - mock_message.video.file_id = 'video_file_id' + mock_message.video.file_id = "video_file_id" return mock_message @pytest.fixture def sample_audio_message(mock_message): """Создает сообщение с аудио для тестов""" - mock_message.content_type = 'audio' - mock_message.caption = 'Тестовое аудио' + mock_message.content_type = "audio" + mock_message.caption = "Тестовое аудио" mock_message.media_group_id = None mock_message.audio = Mock() - mock_message.audio.file_id = 'audio_file_id' + mock_message.audio.file_id = "audio_file_id" return mock_message @pytest.fixture def sample_voice_message(mock_message): """Создает голосовое сообщение для тестов""" - mock_message.content_type = 'voice' + mock_message.content_type = "voice" mock_message.media_group_id = None mock_message.voice = Mock() - mock_message.voice.file_id = 'voice_file_id' + mock_message.voice.file_id = "voice_file_id" return mock_message @pytest.fixture def sample_video_note_message(mock_message): """Создает видеокружок для тестов""" - mock_message.content_type = 'video_note' + mock_message.content_type = "video_note" mock_message.media_group_id = None mock_message.video_note = Mock() - mock_message.video_note.file_id = 'video_note_file_id' + mock_message.video_note.file_id = "video_note_file_id" return mock_message @pytest.fixture def sample_media_group(mock_message): """Создает медиагруппу для тестов""" - mock_message.media_group_id = 'group_123' - mock_message.content_type = 'photo' + mock_message.media_group_id = "group_123" + mock_message.content_type = "photo" album = [mock_message] - album[0].caption = 'Подпись к медиагруппе' + album[0].caption = "Подпись к медиагруппе" return album @pytest.fixture def sample_text_message(mock_message): """Создает текстовое сообщение для тестов""" - mock_message.content_type = 'text' - mock_message.text = 'Тестовое текстовое сообщение' + mock_message.content_type = "text" + mock_message.text = "Тестовое текстовое сообщение" mock_message.media_group_id = None return mock_message @@ -201,7 +198,7 @@ def sample_text_message(mock_message): @pytest.fixture def sample_document_message(mock_message): """Создает сообщение с документом для тестов""" - mock_message.content_type = 'document' + mock_message.content_type = "document" mock_message.media_group_id = None return mock_message @@ -209,18 +206,10 @@ def sample_document_message(mock_message): # Маркеры для категоризации тестов def pytest_configure(config): """Настройка маркеров pytest""" - config.addinivalue_line( - "markers", "asyncio: mark test as async" - ) - config.addinivalue_line( - "markers", "slow: mark test as slow" - ) - config.addinivalue_line( - "markers", "integration: mark test as integration test" - ) - config.addinivalue_line( - "markers", "unit: mark test as unit test" - ) + config.addinivalue_line("markers", "asyncio: mark test as async") + config.addinivalue_line("markers", "slow: mark test as slow") + config.addinivalue_line("markers", "integration: mark test as integration test") + config.addinivalue_line("markers", "unit: mark test as unit test") # Автоматическая маркировка тестов @@ -230,15 +219,15 @@ def pytest_collection_modifyitems(config, items): # Маркируем асинхронные тесты if "async" in item.name or "Async" in item.name: item.add_marker(pytest.mark.asyncio) - + # Маркируем интеграционные тесты if "integration" in item.name.lower() or "Integration" in str(item.cls): item.add_marker(pytest.mark.integration) - + # Маркируем unit тесты if "unit" in item.name.lower() or "Unit" in str(item.cls): item.add_marker(pytest.mark.unit) - + # Маркируем медленные тесты if "slow" in item.name.lower() or "Slow" in str(item.cls): item.add_marker(pytest.mark.slow) diff --git a/tests/conftest_message_repository.py b/tests/conftest_message_repository.py index 573943f..90f7b8b 100644 --- a/tests/conftest_message_repository.py +++ b/tests/conftest_message_repository.py @@ -3,6 +3,7 @@ import tempfile from datetime import datetime import pytest + from database.models import UserMessage from database.repositories.message_repository import MessageRepository @@ -10,11 +11,11 @@ from database.repositories.message_repository import MessageRepository @pytest.fixture(scope="session") def test_db_path(): """Фикстура для пути к тестовой БД (сессионная область).""" - with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: temp_path = f.name - + yield temp_path - + # Очистка после всех тестов try: os.unlink(temp_path) @@ -32,26 +33,26 @@ def message_repository(test_db_path): def sample_messages(): """Фикстура для набора тестовых сообщений.""" base_timestamp = int(datetime.now().timestamp()) - + return [ UserMessage( message_text="Первое тестовое сообщение", user_id=1001, telegram_message_id=2001, - date=base_timestamp + date=base_timestamp, ), UserMessage( message_text="Второе тестовое сообщение", user_id=1002, telegram_message_id=2002, - date=base_timestamp + 1 + date=base_timestamp + 1, ), UserMessage( message_text="Третье тестовое сообщение", user_id=1003, telegram_message_id=2003, - date=base_timestamp + 2 - ) + date=base_timestamp + 2, + ), ] @@ -62,7 +63,7 @@ def message_without_date(): message_text="Сообщение без даты", user_id=1004, telegram_message_id=2004, - date=None + date=None, ) @@ -73,7 +74,7 @@ def message_with_zero_date(): message_text="Сообщение с нулевой датой", user_id=1005, telegram_message_id=2005, - date=0 + date=0, ) @@ -84,7 +85,7 @@ def message_with_special_chars(): message_text="Сообщение с 'кавычками', \"двойными кавычками\" и эмодзи 😊\nНовая строка", user_id=1006, telegram_message_id=2006, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) @@ -96,7 +97,7 @@ def long_message(): message_text=long_text, user_id=1007, telegram_message_id=2007, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) @@ -107,7 +108,7 @@ def message_with_unicode(): message_text="Сообщение с Unicode: 你好世界 🌍 Привет мир", user_id=1008, telegram_message_id=2008, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) diff --git a/tests/conftest_post_repository.py b/tests/conftest_post_repository.py index fca1784..8c660ce 100644 --- a/tests/conftest_post_repository.py +++ b/tests/conftest_post_repository.py @@ -5,6 +5,7 @@ from datetime import datetime from unittest.mock import AsyncMock, Mock import pytest + from database.models import MessageContentLink, PostContent, TelegramPost from database.repositories.post_repository import PostRepository @@ -37,7 +38,7 @@ def sample_telegram_post(): text="Тестовый пост для unit тестов", author_id=67890, helper_text_message_id=None, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) @@ -49,7 +50,7 @@ def sample_telegram_post_with_helper(): text="Тестовый пост с helper сообщением", author_id=67890, helper_text_message_id=99999, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) @@ -61,7 +62,7 @@ def sample_telegram_post_no_date(): text="Тестовый пост без даты", author_id=67890, helper_text_message_id=None, - created_at=None + created_at=None, ) @@ -69,19 +70,14 @@ def sample_telegram_post_no_date(): def sample_post_content(): """Создает тестовый объект PostContent""" return PostContent( - message_id=12345, - content_name="/path/to/test/file.jpg", - content_type="photo" + message_id=12345, content_name="/path/to/test/file.jpg", content_type="photo" ) @pytest.fixture def sample_message_content_link(): """Создает тестовый объект MessageContentLink""" - return MessageContentLink( - post_id=12345, - message_id=67890 - ) + return MessageContentLink(post_id=12345, message_id=67890) @pytest.fixture @@ -105,11 +101,11 @@ def mock_logger(): @pytest.fixture def temp_db_file(): """Создает временный файл БД для интеграционных тестов""" - with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file: db_path = tmp_file.name - + yield db_path - + # Очищаем временный файл после тестов try: os.unlink(db_path) @@ -132,22 +128,22 @@ def sample_posts_batch(): text="Первый тестовый пост", author_id=11111, helper_text_message_id=None, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ), TelegramPost( message_id=10002, text="Второй тестовый пост", author_id=22222, helper_text_message_id=None, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ), TelegramPost( message_id=10003, text="Третий тестовый пост", author_id=33333, helper_text_message_id=88888, - created_at=int(datetime.now().timestamp()) - ) + created_at=int(datetime.now().timestamp()), + ), ] @@ -159,7 +155,7 @@ def sample_content_batch(): (10002, "/path/to/video1.mp4", "video"), (10003, "/path/to/audio1.mp3", "audio"), (10004, "/path/to/photo2.jpg", "photo"), - (10005, "/path/to/video2.mp4", "video") + (10005, "/path/to/video2.mp4", "video"), ] @@ -195,19 +191,19 @@ def sample_author_ids(): def mock_sql_queries(): """Создает мок для SQL запросов""" return { - 'create_tables': [ + "create_tables": [ "CREATE TABLE IF NOT EXISTS post_from_telegram_suggest", "CREATE TABLE IF NOT EXISTS content_post_from_telegram", - "CREATE TABLE IF NOT EXISTS message_link_to_content" + "CREATE TABLE IF NOT EXISTS message_link_to_content", ], - 'add_post': "INSERT INTO post_from_telegram_suggest", - 'add_post_status': "status", - 'update_helper': "UPDATE post_from_telegram_suggest SET helper_text_message_id", - 'update_status': "UPDATE post_from_telegram_suggest SET status = ?", - 'add_content': "INSERT OR IGNORE INTO content_post_from_telegram", - 'add_link': "INSERT OR IGNORE INTO message_link_to_content", - 'get_content': "SELECT cpft.content_name, cpft.content_type", - 'get_text': "SELECT text FROM post_from_telegram_suggest", - 'get_ids': "SELECT mltc.message_id", - 'get_author': "SELECT author_id FROM post_from_telegram_suggest" + "add_post": "INSERT INTO post_from_telegram_suggest", + "add_post_status": "status", + "update_helper": "UPDATE post_from_telegram_suggest SET helper_text_message_id", + "update_status": "UPDATE post_from_telegram_suggest SET status = ?", + "add_content": "INSERT OR IGNORE INTO content_post_from_telegram", + "add_link": "INSERT OR IGNORE INTO message_link_to_content", + "get_content": "SELECT cpft.content_name, cpft.content_type", + "get_text": "SELECT text FROM post_from_telegram_suggest", + "get_ids": "SELECT mltc.message_id", + "get_author": "SELECT author_id FROM post_from_telegram_suggest", } diff --git a/tests/mocks.py b/tests/mocks.py index 4833698..821e66a 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,6 +1,7 @@ """ Моки для тестового окружения """ + import os import sys from unittest.mock import Mock, patch @@ -11,33 +12,34 @@ def setup_test_mocks(): """Настройка моков для тестов""" # Мокаем os.getenv mock_env_vars = { - 'BOT_TOKEN': 'test_token_123', - 'LISTEN_BOT_TOKEN': '', - 'TEST_BOT_TOKEN': '', - 'PREVIEW_LINK': 'False', - 'MAIN_PUBLIC': '@test', - 'GROUP_FOR_POSTS': '-1001234567890', - 'GROUP_FOR_MESSAGE': '-1001234567891', - 'GROUP_FOR_LOGS': '-1001234567893', - 'IMPORTANT_LOGS': '-1001234567894', - 'TEST_GROUP': '-1001234567895', - 'LOGS': 'True', - 'TEST': 'False', - 'DATABASE_PATH': 'database/test.db' + "BOT_TOKEN": "test_token_123", + "LISTEN_BOT_TOKEN": "", + "TEST_BOT_TOKEN": "", + "PREVIEW_LINK": "False", + "MAIN_PUBLIC": "@test", + "GROUP_FOR_POSTS": "-1001234567890", + "GROUP_FOR_MESSAGE": "-1001234567891", + "GROUP_FOR_LOGS": "-1001234567893", + "IMPORTANT_LOGS": "-1001234567894", + "TEST_GROUP": "-1001234567895", + "LOGS": "True", + "TEST": "False", + "DATABASE_PATH": "database/test.db", } def mock_getenv(key, default=None): return mock_env_vars.get(key, default) - env_patcher = patch('os.getenv', side_effect=mock_getenv) + env_patcher = patch("os.getenv", side_effect=mock_getenv) env_patcher.start() # Мокаем AsyncBotDB mock_db = Mock() - db_patcher = patch('helper_bot.utils.base_dependency_factory.AsyncBotDB', mock_db) + db_patcher = patch("helper_bot.utils.base_dependency_factory.AsyncBotDB", mock_db) db_patcher.start() - + return env_patcher, db_patcher + # Настраиваем моки при импорте модуля -env_patcher, db_patcher = setup_test_mocks() \ No newline at end of file +env_patcher, db_patcher = setup_test_mocks() diff --git a/tests/test_admin_repository.py b/tests/test_admin_repository.py index 033be96..1eee060 100644 --- a/tests/test_admin_repository.py +++ b/tests/test_admin_repository.py @@ -3,13 +3,14 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest + from database.models import Admin from database.repositories.admin_repository import AdminRepository class TestAdminRepository: """Тесты для AdminRepository""" - + @pytest.fixture def mock_db_connection(self): """Мок для DatabaseConnection""" @@ -18,137 +19,142 @@ class TestAdminRepository: mock_connection._execute_query_with_result = AsyncMock() mock_connection.logger = Mock() return mock_connection - + @pytest.fixture def admin_repository(self, mock_db_connection): """Экземпляр AdminRepository для тестов""" # Патчим наследование от DatabaseConnection - with patch.object(AdminRepository, '__init__', return_value=None): + with patch.object(AdminRepository, "__init__", return_value=None): repo = AdminRepository() repo._execute_query = mock_db_connection._execute_query - repo._execute_query_with_result = mock_db_connection._execute_query_with_result + repo._execute_query_with_result = ( + mock_db_connection._execute_query_with_result + ) repo.logger = mock_db_connection.logger return repo - + @pytest.fixture def sample_admin(self): """Тестовый администратор""" - return Admin( - user_id=12345, - role="admin" - ) - + return Admin(user_id=12345, role="admin") + @pytest.fixture def sample_admin_with_created_at(self): """Тестовый администратор с датой создания""" return Admin( - user_id=12345, - role="admin", - created_at="1705312200" # UNIX timestamp + user_id=12345, role="admin", created_at="1705312200" # UNIX timestamp ) - + @pytest.mark.asyncio async def test_create_tables(self, admin_repository): """Тест создания таблицы администраторов""" await admin_repository.create_tables() - + # Проверяем, что включены внешние ключи admin_repository._execute_query.assert_called() calls = admin_repository._execute_query.call_args_list - + # Первый вызов должен быть для включения внешних ключей assert calls[0][0][0] == "PRAGMA foreign_keys = ON" - + # Второй вызов должен быть для создания таблицы create_table_call = calls[1] assert "CREATE TABLE IF NOT EXISTS admins" in create_table_call[0][0] assert "user_id INTEGER NOT NULL PRIMARY KEY" in create_table_call[0][0] assert "role TEXT DEFAULT 'admin'" in create_table_call[0][0] - assert "created_at INTEGER DEFAULT (strftime('%s', 'now'))" in create_table_call[0][0] - assert "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" in create_table_call[0][0] - + assert ( + "created_at INTEGER DEFAULT (strftime('%s', 'now'))" + in create_table_call[0][0] + ) + assert ( + "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" + in create_table_call[0][0] + ) + # Проверяем логирование - admin_repository.logger.info.assert_called_once_with("Таблица администраторов создана") - + admin_repository.logger.info.assert_called_once_with( + "Таблица администраторов создана" + ) + @pytest.mark.asyncio async def test_add_admin(self, admin_repository, sample_admin): """Тест добавления администратора""" await admin_repository.add_admin(sample_admin) - + # Проверяем, что метод вызван с правильными параметрами admin_repository._execute_query.assert_called_once() call_args = admin_repository._execute_query.call_args - + assert call_args[0][0] == "INSERT INTO admins (user_id, role) VALUES (?, ?)" assert call_args[0][1] == (12345, "admin") - + # Проверяем логирование admin_repository.logger.info.assert_called_once_with( "Администратор добавлен: user_id=12345, role=admin" ) - + @pytest.mark.asyncio async def test_add_admin_with_custom_role(self, admin_repository): """Тест добавления администратора с кастомной ролью""" admin = Admin(user_id=67890, role="super_admin") await admin_repository.add_admin(admin) - + call_args = admin_repository._execute_query.call_args assert call_args[0][1] == (67890, "super_admin") - + admin_repository.logger.info.assert_called_once_with( "Администратор добавлен: user_id=67890, role=super_admin" ) - + @pytest.mark.asyncio async def test_remove_admin(self, admin_repository): """Тест удаления администратора""" user_id = 12345 await admin_repository.remove_admin(user_id) - + # Проверяем, что метод вызван с правильными параметрами admin_repository._execute_query.assert_called_once() call_args = admin_repository._execute_query.call_args - + assert call_args[0][0] == "DELETE FROM admins WHERE user_id = ?" assert call_args[0][1] == (user_id,) - + # Проверяем логирование admin_repository.logger.info.assert_called_once_with( "Администратор удален: user_id=12345" ) - + @pytest.mark.asyncio async def test_is_admin_true(self, admin_repository): """Тест проверки администратора - пользователь является администратором""" user_id = 12345 # Мокаем результат запроса - пользователь найден admin_repository._execute_query_with_result.return_value = [(1,)] - + result = await admin_repository.is_admin(user_id) - + # Проверяем, что метод вызван с правильными параметрами admin_repository._execute_query_with_result.assert_called_once() call_args = admin_repository._execute_query_with_result.call_args - + assert call_args[0][0] == "SELECT 1 FROM admins WHERE user_id = ?" assert call_args[0][1] == (user_id,) - + # Проверяем результат assert result is True - + @pytest.mark.asyncio async def test_is_admin_false(self, admin_repository): """Тест проверки администратора - пользователь не является администратором""" user_id = 12345 # Мокаем результат запроса - пользователь не найден admin_repository._execute_query_with_result.return_value = [] - + result = await admin_repository.is_admin(user_id) - + # Проверяем результат assert result is False - + @pytest.mark.asyncio async def test_get_admin_found(self, admin_repository): """Тест получения информации об администраторе - администратор найден""" @@ -157,138 +163,143 @@ class TestAdminRepository: admin_repository._execute_query_with_result.return_value = [ (12345, "admin", "1705312200") ] - + result = await admin_repository.get_admin(user_id) - + # Проверяем, что метод вызван с правильными параметрами admin_repository._execute_query_with_result.assert_called_once() call_args = admin_repository._execute_query_with_result.call_args - - assert call_args[0][0] == "SELECT user_id, role, created_at FROM admins WHERE user_id = ?" + + assert ( + call_args[0][0] + == "SELECT user_id, role, created_at FROM admins WHERE user_id = ?" + ) assert call_args[0][1] == (user_id,) - + # Проверяем результат assert result is not None assert result.user_id == 12345 assert result.role == "admin" assert result.created_at == "1705312200" - + @pytest.mark.asyncio async def test_get_admin_not_found(self, admin_repository): """Тест получения информации об администраторе - администратор не найден""" user_id = 12345 # Мокаем результат запроса - администратор не найден admin_repository._execute_query_with_result.return_value = [] - + result = await admin_repository.get_admin(user_id) - + # Проверяем результат assert result is None - + @pytest.mark.asyncio async def test_get_admin_without_created_at(self, admin_repository): """Тест получения информации об администраторе без даты создания""" user_id = 12345 # Мокаем результат запроса без created_at - admin_repository._execute_query_with_result.return_value = [ - (12345, "admin") - ] - + admin_repository._execute_query_with_result.return_value = [(12345, "admin")] + result = await admin_repository.get_admin(user_id) - + # Проверяем результат assert result is not None assert result.user_id == 12345 assert result.role == "admin" assert result.created_at is None - + @pytest.mark.asyncio async def test_add_admin_error_handling(self, admin_repository, sample_admin): """Тест обработки ошибок при добавлении администратора""" # Мокаем ошибку при выполнении запроса admin_repository._execute_query.side_effect = Exception("Database error") - + with pytest.raises(Exception, match="Database error"): await admin_repository.add_admin(sample_admin) - + @pytest.mark.asyncio async def test_remove_admin_error_handling(self, admin_repository): """Тест обработки ошибок при удалении администратора""" # Мокаем ошибку при выполнении запроса admin_repository._execute_query.side_effect = Exception("Database error") - + with pytest.raises(Exception, match="Database error"): await admin_repository.remove_admin(12345) - + @pytest.mark.asyncio async def test_is_admin_error_handling(self, admin_repository): """Тест обработки ошибок при проверке администратора""" # Мокаем ошибку при выполнении запроса - admin_repository._execute_query_with_result.side_effect = Exception("Database error") - + admin_repository._execute_query_with_result.side_effect = Exception( + "Database error" + ) + with pytest.raises(Exception, match="Database error"): await admin_repository.is_admin(12345) - + @pytest.mark.asyncio async def test_get_admin_error_handling(self, admin_repository): """Тест обработки ошибок при получении информации об администраторе""" # Мокаем ошибку при выполнении запроса - admin_repository._execute_query_with_result.side_effect = Exception("Database error") - + admin_repository._execute_query_with_result.side_effect = Exception( + "Database error" + ) + with pytest.raises(Exception, match="Database error"): await admin_repository.get_admin(12345) - + @pytest.mark.asyncio async def test_create_tables_error_handling(self, admin_repository): """Тест обработки ошибок при создании таблиц""" # Мокаем ошибку при выполнении первого запроса admin_repository._execute_query.side_effect = Exception("Database error") - + with pytest.raises(Exception, match="Database error"): await admin_repository.create_tables() - + @pytest.mark.asyncio async def test_admin_model_compatibility(self, admin_repository): """Тест совместимости с моделью Admin""" user_id = 12345 role = "moderator" - + # Создаем администратора с помощью модели admin = Admin(user_id=user_id, role=role) - + # Проверяем, что можем передать его в репозиторий await admin_repository.add_admin(admin) - + # Проверяем, что вызов был с правильными параметрами call_args = admin_repository._execute_query.call_args assert call_args[0][1] == (user_id, role) - + @pytest.mark.asyncio async def test_multiple_admin_operations(self, admin_repository): """Тест множественных операций с администраторами""" # Добавляем первого администратора admin1 = Admin(user_id=111, role="admin") await admin_repository.add_admin(admin1) - + # Добавляем второго администратора admin2 = Admin(user_id=222, role="moderator") await admin_repository.add_admin(admin2) - + # Проверяем, что оба добавлены assert admin_repository._execute_query.call_count == 2 - + # Проверяем, что первый администратор существует admin_repository._execute_query_with_result.return_value = [(1,)] result1 = await admin_repository.is_admin(111) assert result1 is True - + # Проверяем, что второй администратор существует result2 = await admin_repository.is_admin(222) assert result2 is True - + # Удаляем первого администратора await admin_repository.remove_admin(111) - + # Проверяем, что он больше не существует admin_repository._execute_query_with_result.return_value = [] result3 = await admin_repository.is_admin(111) diff --git a/tests/test_async_db.py b/tests/test_async_db.py index e3ca0a6..02b9b3f 100644 --- a/tests/test_async_db.py +++ b/tests/test_async_db.py @@ -1,12 +1,13 @@ from unittest.mock import AsyncMock, Mock, patch import pytest + from database.async_db import AsyncBotDB class TestAsyncBotDB: """Тесты для AsyncBotDB""" - + @pytest.fixture def mock_factory(self): """Мок для RepositoryFactory""" @@ -23,94 +24,130 @@ class TestAsyncBotDB: mock_factory.blacklist_history.add_record_on_ban = AsyncMock() mock_factory.blacklist_history.set_unban_date = AsyncMock(return_value=True) return mock_factory - + @pytest.fixture def async_bot_db(self, mock_factory): """Экземпляр AsyncBotDB для тестов""" - with patch('database.async_db.RepositoryFactory') as mock_factory_class: + with patch("database.async_db.RepositoryFactory") as mock_factory_class: mock_factory_class.return_value = mock_factory db = AsyncBotDB("test.db") return db - + @pytest.mark.asyncio async def test_delete_audio_moderate_record(self, async_bot_db, mock_factory): """Тест метода delete_audio_moderate_record""" message_id = 12345 - + await async_bot_db.delete_audio_moderate_record(message_id) - + # Проверяем, что метод вызван в репозитории - mock_factory.audio.delete_audio_moderate_record.assert_called_once_with(message_id) - + mock_factory.audio.delete_audio_moderate_record.assert_called_once_with( + message_id + ) + @pytest.mark.asyncio - async def test_delete_audio_moderate_record_with_different_message_id(self, async_bot_db, mock_factory): + async def test_delete_audio_moderate_record_with_different_message_id( + self, async_bot_db, mock_factory + ): """Тест метода delete_audio_moderate_record с разными message_id""" test_cases = [123, 456, 789, 99999] - + for message_id in test_cases: await async_bot_db.delete_audio_moderate_record(message_id) - mock_factory.audio.delete_audio_moderate_record.assert_called_with(message_id) - + mock_factory.audio.delete_audio_moderate_record.assert_called_with( + message_id + ) + # Проверяем, что метод вызван для каждого message_id - assert mock_factory.audio.delete_audio_moderate_record.call_count == len(test_cases) - + assert mock_factory.audio.delete_audio_moderate_record.call_count == len( + test_cases + ) + @pytest.mark.asyncio - async def test_delete_audio_moderate_record_exception_handling(self, async_bot_db, mock_factory): + async def test_delete_audio_moderate_record_exception_handling( + self, async_bot_db, mock_factory + ): """Тест обработки исключений в delete_audio_moderate_record""" message_id = 12345 - mock_factory.audio.delete_audio_moderate_record.side_effect = Exception("Database error") - + mock_factory.audio.delete_audio_moderate_record.side_effect = Exception( + "Database error" + ) + # Метод должен пробросить исключение with pytest.raises(Exception, match="Database error"): await async_bot_db.delete_audio_moderate_record(message_id) - + @pytest.mark.asyncio - async def test_delete_audio_moderate_record_integration_with_other_methods(self, async_bot_db, mock_factory): + async def test_delete_audio_moderate_record_integration_with_other_methods( + self, async_bot_db, mock_factory + ): """Тест интеграции delete_audio_moderate_record с другими методами""" message_id = 12345 user_id = 67890 - + # Мокаем другие методы - mock_factory.audio.get_user_id_by_message_id_for_voice_bot = AsyncMock(return_value=user_id) - mock_factory.audio.set_user_id_and_message_id_for_voice_bot = AsyncMock(return_value=True) - + mock_factory.audio.get_user_id_by_message_id_for_voice_bot = AsyncMock( + return_value=user_id + ) + mock_factory.audio.set_user_id_and_message_id_for_voice_bot = AsyncMock( + return_value=True + ) + # Тестируем последовательность операций await async_bot_db.get_user_id_by_message_id_for_voice_bot(message_id) await async_bot_db.set_user_id_and_message_id_for_voice_bot(message_id, user_id) await async_bot_db.delete_audio_moderate_record(message_id) - + # Проверяем, что все методы вызваны - mock_factory.audio.get_user_id_by_message_id_for_voice_bot.assert_called_once_with(message_id) - mock_factory.audio.set_user_id_and_message_id_for_voice_bot.assert_called_once_with(message_id, user_id) - mock_factory.audio.delete_audio_moderate_record.assert_called_once_with(message_id) - + mock_factory.audio.get_user_id_by_message_id_for_voice_bot.assert_called_once_with( + message_id + ) + mock_factory.audio.set_user_id_and_message_id_for_voice_bot.assert_called_once_with( + message_id, user_id + ) + mock_factory.audio.delete_audio_moderate_record.assert_called_once_with( + message_id + ) + @pytest.mark.asyncio - async def test_delete_audio_moderate_record_zero_message_id(self, async_bot_db, mock_factory): + async def test_delete_audio_moderate_record_zero_message_id( + self, async_bot_db, mock_factory + ): """Тест delete_audio_moderate_record с message_id = 0""" message_id = 0 - + await async_bot_db.delete_audio_moderate_record(message_id) - - mock_factory.audio.delete_audio_moderate_record.assert_called_once_with(message_id) - + + mock_factory.audio.delete_audio_moderate_record.assert_called_once_with( + message_id + ) + @pytest.mark.asyncio - async def test_delete_audio_moderate_record_negative_message_id(self, async_bot_db, mock_factory): + async def test_delete_audio_moderate_record_negative_message_id( + self, async_bot_db, mock_factory + ): """Тест delete_audio_moderate_record с отрицательным message_id""" message_id = -12345 - + await async_bot_db.delete_audio_moderate_record(message_id) - - mock_factory.audio.delete_audio_moderate_record.assert_called_once_with(message_id) - + + mock_factory.audio.delete_audio_moderate_record.assert_called_once_with( + message_id + ) + @pytest.mark.asyncio - async def test_delete_audio_moderate_record_large_message_id(self, async_bot_db, mock_factory): + async def test_delete_audio_moderate_record_large_message_id( + self, async_bot_db, mock_factory + ): """Тест delete_audio_moderate_record с большим message_id""" message_id = 999999999 - + await async_bot_db.delete_audio_moderate_record(message_id) - - mock_factory.audio.delete_audio_moderate_record.assert_called_once_with(message_id) - + + mock_factory.audio.delete_audio_moderate_record.assert_called_once_with( + message_id + ) + @pytest.mark.asyncio async def test_set_user_blacklist_calls_history(self, async_bot_db, mock_factory): """Тест что set_user_blacklist вызывает добавление в историю""" @@ -118,21 +155,21 @@ class TestAsyncBotDB: message_for_user = "Нарушение правил" date_to_unban = 1234567890 ban_author = 999 - + await async_bot_db.set_user_blacklist( user_id=user_id, user_name=None, message_for_user=message_for_user, date_to_unban=date_to_unban, - ban_author=ban_author + ban_author=ban_author, ) - + # Проверяем, что сначала добавлен в blacklist mock_factory.blacklist.add_user.assert_called_once() - + # Проверяем, что затем добавлена запись в историю mock_factory.blacklist_history.add_record_on_ban.assert_called_once() - + # Проверяем параметры записи в историю history_call = mock_factory.blacklist_history.add_record_on_ban.call_args[0][0] assert history_call.user_id == user_id @@ -140,77 +177,89 @@ class TestAsyncBotDB: assert history_call.date_ban is not None assert history_call.date_unban is None assert history_call.ban_author == ban_author - + @pytest.mark.asyncio - async def test_set_user_blacklist_history_error_does_not_fail(self, async_bot_db, mock_factory): + async def test_set_user_blacklist_history_error_does_not_fail( + self, async_bot_db, mock_factory + ): """Тест что ошибка записи в историю не ломает процесс бана""" user_id = 12345 - mock_factory.blacklist_history.add_record_on_ban.side_effect = Exception("History error") - + mock_factory.blacklist_history.add_record_on_ban.side_effect = Exception( + "History error" + ) + # Бан должен пройти успешно, несмотря на ошибку в истории await async_bot_db.set_user_blacklist( user_id=user_id, message_for_user="Тест", date_to_unban=None, - ban_author=None + ban_author=None, ) - + # Проверяем, что пользователь все равно добавлен в blacklist mock_factory.blacklist.add_user.assert_called_once() - + # Проверяем, что попытка записи в историю была mock_factory.blacklist_history.add_record_on_ban.assert_called_once() - + @pytest.mark.asyncio - async def test_delete_user_blacklist_calls_history(self, async_bot_db, mock_factory): + async def test_delete_user_blacklist_calls_history( + self, async_bot_db, mock_factory + ): """Тест что delete_user_blacklist вызывает обновление истории""" user_id = 12345 - + result = await async_bot_db.delete_user_blacklist(user_id) - + # Проверяем, что сначала обновлена история mock_factory.blacklist_history.set_unban_date.assert_called_once() history_call = mock_factory.blacklist_history.set_unban_date.call_args assert history_call[0][0] == user_id assert history_call[0][1] is not None # date_unban timestamp - + # Проверяем, что затем удален из blacklist mock_factory.blacklist.remove_user.assert_called_once_with(user_id) - + # Проверяем результат assert result is True - + @pytest.mark.asyncio - async def test_delete_user_blacklist_history_error_does_not_fail(self, async_bot_db, mock_factory): + async def test_delete_user_blacklist_history_error_does_not_fail( + self, async_bot_db, mock_factory + ): """Тест что ошибка обновления истории не ломает процесс разбана""" user_id = 12345 - mock_factory.blacklist_history.set_unban_date.side_effect = Exception("History error") - + mock_factory.blacklist_history.set_unban_date.side_effect = Exception( + "History error" + ) + # Разбан должен пройти успешно, несмотря на ошибку в истории result = await async_bot_db.delete_user_blacklist(user_id) - + # Проверяем, что попытка обновления истории была mock_factory.blacklist_history.set_unban_date.assert_called_once() - + # Проверяем, что пользователь все равно удален из blacklist mock_factory.blacklist.remove_user.assert_called_once_with(user_id) - + # Проверяем результат assert result is True - + @pytest.mark.asyncio - async def test_delete_user_blacklist_returns_false_on_blacklist_error(self, async_bot_db, mock_factory): + async def test_delete_user_blacklist_returns_false_on_blacklist_error( + self, async_bot_db, mock_factory + ): """Тест что delete_user_blacklist возвращает False при ошибке удаления из blacklist""" user_id = 12345 mock_factory.blacklist.remove_user.return_value = False - + result = await async_bot_db.delete_user_blacklist(user_id) - + # Проверяем, что история обновлена mock_factory.blacklist_history.set_unban_date.assert_called_once() - + # Проверяем, что удаление из blacklist было попытка mock_factory.blacklist.remove_user.assert_called_once_with(user_id) - + # Проверяем результат assert result is False diff --git a/tests/test_audio_file_service.py b/tests/test_audio_file_service.py index 5452f47..7d298ed 100644 --- a/tests/test_audio_file_service.py +++ b/tests/test_audio_file_service.py @@ -3,8 +3,8 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock, Mock, mock_open, patch import pytest -from helper_bot.handlers.voice.exceptions import (DatabaseError, - FileOperationError) + +from helper_bot.handlers.voice.exceptions import DatabaseError, FileOperationError from helper_bot.handlers.voice.services import AudioFileService @@ -17,16 +17,19 @@ def mock_bot_db(): mock_db.add_audio_record_simple = AsyncMock() return mock_db + @pytest.fixture def audio_service(mock_bot_db): """Экземпляр AudioFileService для тестов""" return AudioFileService(mock_bot_db) + @pytest.fixture def sample_datetime(): """Тестовая дата""" return datetime(2025, 1, 15, 14, 30, 0) + @pytest.fixture def mock_bot(): """Мок для бота""" @@ -35,6 +38,7 @@ def mock_bot(): bot.download_file = AsyncMock() return bot + @pytest.fixture def mock_message(): """Мок для сообщения""" @@ -43,6 +47,7 @@ def mock_message(): message.voice.file_id = "test_file_id" return message + @pytest.fixture def mock_file_info(): """Мок для информации о файле""" @@ -53,76 +58,92 @@ def mock_file_info(): class TestGenerateFileName: """Тесты для метода generate_file_name""" - + @pytest.mark.asyncio async def test_generate_file_name_first_record(self, audio_service, mock_bot_db): """Тест генерации имени файла для первой записи пользователя""" mock_bot_db.get_user_audio_records_count.return_value = 0 - + result = await audio_service.generate_file_name(12345) - + assert result == "message_from_12345_number_1" mock_bot_db.get_user_audio_records_count.assert_called_once_with(user_id=12345) - + @pytest.mark.asyncio - async def test_generate_file_name_existing_records(self, audio_service, mock_bot_db): + async def test_generate_file_name_existing_records( + self, audio_service, mock_bot_db + ): """Тест генерации имени файла для существующих записей""" mock_bot_db.get_user_audio_records_count.return_value = 3 - mock_bot_db.get_path_for_audio_record.return_value = "message_from_12345_number_3" - + mock_bot_db.get_path_for_audio_record.return_value = ( + "message_from_12345_number_3" + ) + result = await audio_service.generate_file_name(12345) - + assert result == "message_from_12345_number_4" mock_bot_db.get_user_audio_records_count.assert_called_once_with(user_id=12345) mock_bot_db.get_path_for_audio_record.assert_called_once_with(user_id=12345) - + @pytest.mark.asyncio async def test_generate_file_name_no_last_record(self, audio_service, mock_bot_db): """Тест генерации имени файла когда нет последней записи""" mock_bot_db.get_user_audio_records_count.return_value = 2 mock_bot_db.get_path_for_audio_record.return_value = None - + result = await audio_service.generate_file_name(12345) - + assert result == "message_from_12345_number_3" - + @pytest.mark.asyncio - async def test_generate_file_name_invalid_last_record_format(self, audio_service, mock_bot_db): + async def test_generate_file_name_invalid_last_record_format( + self, audio_service, mock_bot_db + ): """Тест генерации имени файла с некорректным форматом последней записи""" mock_bot_db.get_user_audio_records_count.return_value = 2 mock_bot_db.get_path_for_audio_record.return_value = "invalid_format" - + result = await audio_service.generate_file_name(12345) - + assert result == "message_from_12345_number_3" - + @pytest.mark.asyncio - async def test_generate_file_name_exception_handling(self, audio_service, mock_bot_db): + async def test_generate_file_name_exception_handling( + self, audio_service, mock_bot_db + ): """Тест обработки исключений при генерации имени файла""" - mock_bot_db.get_user_audio_records_count.side_effect = Exception("Database error") - + mock_bot_db.get_user_audio_records_count.side_effect = Exception( + "Database error" + ) + with pytest.raises(FileOperationError) as exc_info: await audio_service.generate_file_name(12345) - + assert "Не удалось сгенерировать имя файла" in str(exc_info.value) class TestSaveAudioFile: """Тесты для метода save_audio_file""" - + @pytest.mark.asyncio - async def test_save_audio_file_success(self, audio_service, mock_bot_db, sample_datetime): + async def test_save_audio_file_success( + self, audio_service, mock_bot_db, sample_datetime + ): """Тест успешного сохранения аудио файла""" file_name = "test_audio" user_id = 12345 file_id = "test_file_id" - + # Мокаем verify_file_exists чтобы он возвращал True - with patch.object(audio_service, 'verify_file_exists', return_value=True): - await audio_service.save_audio_file(file_name, user_id, sample_datetime, file_id) - - mock_bot_db.add_audio_record_simple.assert_called_once_with(file_name, user_id, sample_datetime) - + with patch.object(audio_service, "verify_file_exists", return_value=True): + await audio_service.save_audio_file( + file_name, user_id, sample_datetime, file_id + ) + + mock_bot_db.add_audio_record_simple.assert_called_once_with( + file_name, user_id, sample_datetime + ) + @pytest.mark.asyncio async def test_save_audio_file_with_string_date(self, audio_service, mock_bot_db): """Тест сохранения аудио файла со строковой датой""" @@ -130,149 +151,196 @@ class TestSaveAudioFile: user_id = 12345 date_string = "2025-01-15 14:30:00" file_id = "test_file_id" - + # Мокаем verify_file_exists чтобы он возвращал True - with patch.object(audio_service, 'verify_file_exists', return_value=True): - await audio_service.save_audio_file(file_name, user_id, date_string, file_id) - - mock_bot_db.add_audio_record_simple.assert_called_once_with(file_name, user_id, date_string) - + with patch.object(audio_service, "verify_file_exists", return_value=True): + await audio_service.save_audio_file( + file_name, user_id, date_string, file_id + ) + + mock_bot_db.add_audio_record_simple.assert_called_once_with( + file_name, user_id, date_string + ) + @pytest.mark.asyncio - async def test_save_audio_file_exception_handling(self, audio_service, mock_bot_db, sample_datetime): + async def test_save_audio_file_exception_handling( + self, audio_service, mock_bot_db, sample_datetime + ): """Тест обработки исключений при сохранении аудио файла""" mock_bot_db.add_audio_record_simple.side_effect = Exception("Database error") - + # Мокаем verify_file_exists чтобы он возвращал True - with patch.object(audio_service, 'verify_file_exists', return_value=True): + with patch.object(audio_service, "verify_file_exists", return_value=True): with pytest.raises(DatabaseError) as exc_info: - await audio_service.save_audio_file("test", 12345, sample_datetime, "file_id") - + await audio_service.save_audio_file( + "test", 12345, sample_datetime, "file_id" + ) + assert "Не удалось сохранить аудио файл в БД" in str(exc_info.value) class TestDownloadAndSaveAudio: """Тесты для метода download_and_save_audio""" - + @pytest.mark.asyncio - async def test_download_and_save_audio_success(self, audio_service, mock_bot, mock_message, mock_file_info): + async def test_download_and_save_audio_success( + self, audio_service, mock_bot, mock_message, mock_file_info + ): """Тест успешного скачивания и сохранения аудио""" mock_bot.get_file.return_value = mock_file_info - + # Мокаем скачанный файл mock_downloaded_file = Mock() mock_downloaded_file.tell.return_value = 0 mock_downloaded_file.seek = Mock() mock_downloaded_file.read.return_value = b"audio_data" - + # Настраиваем поведение tell() для получения размера файла def mock_tell(): return 0 if mock_downloaded_file.seek.call_count == 0 else 1024 + mock_downloaded_file.tell = Mock(side_effect=mock_tell) - + mock_bot.download_file.return_value = mock_downloaded_file - - with patch('builtins.open', mock_open()) as mock_file: - with patch('os.makedirs'): - with patch('os.path.exists', return_value=True): - with patch('os.path.getsize', return_value=1024): - await audio_service.download_and_save_audio(mock_bot, mock_message, "test_audio") - - mock_bot.get_file.assert_called_once_with(file_id="test_file_id") - mock_bot.download_file.assert_called_once_with(file_path="voice/test_file_id.ogg") + + with patch("builtins.open", mock_open()) as mock_file: + with patch("os.makedirs"): + with patch("os.path.exists", return_value=True): + with patch("os.path.getsize", return_value=1024): + await audio_service.download_and_save_audio( + mock_bot, mock_message, "test_audio" + ) + + mock_bot.get_file.assert_called_once_with( + file_id="test_file_id" + ) + mock_bot.download_file.assert_called_once_with( + file_path="voice/test_file_id.ogg" + ) mock_file.assert_called_once() - + @pytest.mark.asyncio async def test_download_and_save_audio_no_message(self, audio_service, mock_bot): - """Тест скачивания когда сообщение отсутствует""" - with pytest.raises(FileOperationError) as exc_info: - await audio_service.download_and_save_audio(mock_bot, None, "test_audio") - + """Тест скачивания когда сообщение отсутствует.""" + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", new_callable=AsyncMock + ): + with pytest.raises(FileOperationError) as exc_info: + await audio_service.download_and_save_audio( + mock_bot, None, "test_audio" + ) + assert "Сообщение или голосовое сообщение не найдено" in str(exc_info.value) - + @pytest.mark.asyncio async def test_download_and_save_audio_no_voice(self, audio_service, mock_bot): - """Тест скачивания когда у сообщения нет voice атрибута""" + """Тест скачивания когда у сообщения нет voice атрибута.""" message = Mock() message.voice = None - - with pytest.raises(FileOperationError) as exc_info: - await audio_service.download_and_save_audio(mock_bot, message, "test_audio") - + + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", new_callable=AsyncMock + ): + with pytest.raises(FileOperationError) as exc_info: + await audio_service.download_and_save_audio( + mock_bot, message, "test_audio" + ) + assert "Сообщение или голосовое сообщение не найдено" in str(exc_info.value) - + @pytest.mark.asyncio - async def test_download_and_save_audio_download_failed(self, audio_service, mock_bot, mock_message, mock_file_info): - """Тест скачивания когда загрузка не удалась""" + async def test_download_and_save_audio_download_failed( + self, audio_service, mock_bot, mock_message, mock_file_info + ): + """Тест скачивания когда загрузка не удалась.""" mock_bot.get_file.return_value = mock_file_info mock_bot.download_file.return_value = None - - with pytest.raises(FileOperationError) as exc_info: - await audio_service.download_and_save_audio(mock_bot, mock_message, "test_audio") - + + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", new_callable=AsyncMock + ): + with pytest.raises(FileOperationError) as exc_info: + await audio_service.download_and_save_audio( + mock_bot, mock_message, "test_audio" + ) + assert "Не удалось скачать файл" in str(exc_info.value) - + @pytest.mark.asyncio - async def test_download_and_save_audio_exception_handling(self, audio_service, mock_bot, mock_message): - """Тест обработки исключений при скачивании""" + async def test_download_and_save_audio_exception_handling( + self, audio_service, mock_bot, mock_message + ): + """Тест обработки исключений при скачивании.""" mock_bot.get_file.side_effect = Exception("Network error") - - with pytest.raises(FileOperationError) as exc_info: - await audio_service.download_and_save_audio(mock_bot, mock_message, "test_audio") - + + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", new_callable=AsyncMock + ): + with pytest.raises(FileOperationError) as exc_info: + await audio_service.download_and_save_audio( + mock_bot, mock_message, "test_audio" + ) + assert "Не удалось скачать и сохранить аудио" in str(exc_info.value) - - class TestAudioFileServiceIntegration: """Интеграционные тесты для AudioFileService""" - + @pytest.mark.asyncio async def test_full_audio_processing_workflow(self, mock_bot_db): """Тест полного рабочего процесса обработки аудио""" service = AudioFileService(mock_bot_db) - + # Настраиваем моки mock_bot_db.get_user_audio_records_count.return_value = 1 - mock_bot_db.get_path_for_audio_record.return_value = "message_from_12345_number_1" + mock_bot_db.get_path_for_audio_record.return_value = ( + "message_from_12345_number_1" + ) mock_bot_db.add_audio_record_simple = AsyncMock() - + # Тестируем генерацию имени файла file_name = await service.generate_file_name(12345) assert file_name == "message_from_12345_number_2" - + # Тестируем сохранение в БД test_date = datetime.now() - with patch.object(service, 'verify_file_exists', return_value=True): + with patch.object(service, "verify_file_exists", return_value=True): await service.save_audio_file(file_name, 12345, test_date, "test_file_id") - + # Проверяем вызовы mock_bot_db.get_user_audio_records_count.assert_called_once_with(user_id=12345) mock_bot_db.get_path_for_audio_record.assert_called_once_with(user_id=12345) - mock_bot_db.add_audio_record_simple.assert_called_once_with(file_name, 12345, test_date) - + mock_bot_db.add_audio_record_simple.assert_called_once_with( + file_name, 12345, test_date + ) + @pytest.mark.asyncio async def test_file_name_generation_sequence(self, mock_bot_db): """Тест последовательности генерации имен файлов""" service = AudioFileService(mock_bot_db) - + # Первая запись mock_bot_db.get_user_audio_records_count.return_value = 0 file_name_1 = await service.generate_file_name(12345) assert file_name_1 == "message_from_12345_number_1" - + # Вторая запись mock_bot_db.get_user_audio_records_count.return_value = 1 - mock_bot_db.get_path_for_audio_record.return_value = "message_from_12345_number_1" + mock_bot_db.get_path_for_audio_record.return_value = ( + "message_from_12345_number_1" + ) file_name_2 = await service.generate_file_name(12345) assert file_name_2 == "message_from_12345_number_2" - + # Третья запись mock_bot_db.get_user_audio_records_count.return_value = 2 - mock_bot_db.get_path_for_audio_record.return_value = "message_from_12345_number_2" + mock_bot_db.get_path_for_audio_record.return_value = ( + "message_from_12345_number_2" + ) file_name_3 = await service.generate_file_name(12345) assert file_name_3 == "message_from_12345_number_3" -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_audio_repository.py b/tests/test_audio_repository.py index 37fef5a..5ed86fe 100644 --- a/tests/test_audio_repository.py +++ b/tests/test_audio_repository.py @@ -1,15 +1,16 @@ import time -from datetime import datetime +from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest + from database.models import AudioListenRecord, AudioMessage, AudioModerate from database.repositories.audio_repository import AudioRepository class TestAudioRepository: """Тесты для AudioRepository""" - + @pytest.fixture def mock_db_connection(self): """Мок для DatabaseConnection""" @@ -18,18 +19,20 @@ class TestAudioRepository: mock_connection._execute_query_with_result = AsyncMock() mock_connection.logger = Mock() return mock_connection - + @pytest.fixture def audio_repository(self, mock_db_connection): """Экземпляр AudioRepository для тестов""" # Патчим наследование от DatabaseConnection - with patch.object(AudioRepository, '__init__', return_value=None): + with patch.object(AudioRepository, "__init__", return_value=None): repo = AudioRepository() repo._execute_query = mock_db_connection._execute_query - repo._execute_query_with_result = mock_db_connection._execute_query_with_result + repo._execute_query_with_result = ( + mock_db_connection._execute_query_with_result + ) repo.logger = mock_db_connection.logger return repo - + @pytest.fixture def sample_audio_message(self): """Тестовое аудио сообщение""" @@ -38,45 +41,49 @@ class TestAudioRepository: author_id=12345, date_added="2025-01-15 14:30:00", file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + @pytest.fixture def sample_datetime(self): """Тестовая дата""" return datetime(2025, 1, 15, 14, 30, 0) - + @pytest.fixture def sample_timestamp(self): """Тестовый UNIX timestamp""" return int(time.mktime(datetime(2025, 1, 15, 14, 30, 0).timetuple())) - + @pytest.mark.asyncio async def test_enable_foreign_keys(self, audio_repository): """Тест включения внешних ключей""" await audio_repository.enable_foreign_keys() - - audio_repository._execute_query.assert_called_once_with("PRAGMA foreign_keys = ON;") - + + audio_repository._execute_query.assert_called_once_with( + "PRAGMA foreign_keys = ON;" + ) + @pytest.mark.asyncio async def test_create_tables(self, audio_repository): """Тест создания таблиц""" await audio_repository.create_tables() - + # Проверяем, что все три таблицы созданы assert audio_repository._execute_query.call_count == 3 - + # Проверяем вызовы для каждой таблицы calls = audio_repository._execute_query.call_args_list assert any("audio_message_reference" in str(call) for call in calls) assert any("user_audio_listens" in str(call) for call in calls) assert any("audio_moderate" in str(call) for call in calls) - + @pytest.mark.asyncio - async def test_add_audio_record_with_string_date(self, audio_repository, sample_audio_message): + async def test_add_audio_record_with_string_date( + self, audio_repository, sample_audio_message + ): """Тест добавления аудио записи со строковой датой""" await audio_repository.add_audio_record(sample_audio_message) - + # Проверяем, что метод вызван с правильными параметрами audio_repository._execute_query.assert_called_once() call_args = audio_repository._execute_query.call_args @@ -88,7 +95,7 @@ class TestAudioRepository: assert call_args[0][1][0] == "test_audio_123.ogg" assert call_args[0][1][1] == 12345 assert isinstance(call_args[0][1][2], int) # timestamp - + @pytest.mark.asyncio async def test_add_audio_record_with_datetime_date(self, audio_repository): """Тест добавления аудио записи с datetime датой""" @@ -97,15 +104,15 @@ class TestAudioRepository: author_id=67890, date_added=datetime(2025, 1, 20, 10, 15, 0), file_id="test_file_id_2", - listen_count=0 + listen_count=0, ) - + await audio_repository.add_audio_record(audio_msg) - + # Проверяем, что date_added преобразован в timestamp call_args = audio_repository._execute_query.call_args assert isinstance(call_args[0][1][2], int) # timestamp - + @pytest.mark.asyncio async def test_add_audio_record_with_timestamp_date(self, audio_repository): """Тест добавления аудио записи с timestamp датой""" @@ -115,292 +122,321 @@ class TestAudioRepository: author_id=11111, date_added=timestamp, file_id="test_file_id_3", - listen_count=0 + listen_count=0, ) - + await audio_repository.add_audio_record(audio_msg) - + # Проверяем, что date_added остался timestamp call_args = audio_repository._execute_query.call_args assert call_args[0][1][2] == timestamp - + @pytest.mark.asyncio async def test_add_audio_record_simple_with_string_date(self, audio_repository): """Тест упрощенного добавления аудио записи со строковой датой""" - await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, "2025-01-15 14:30:00") - + await audio_repository.add_audio_record_simple( + "test_audio.ogg", 12345, "2025-01-15 14:30:00" + ) + # Проверяем, что метод вызван audio_repository._execute_query.assert_called_once() call_args = audio_repository._execute_query.call_args assert call_args[0][1][0] == "test_audio.ogg" # file_name assert call_args[0][1][1] == 12345 # user_id assert isinstance(call_args[0][1][2], int) # timestamp - + @pytest.mark.asyncio - async def test_add_audio_record_simple_with_datetime_date(self, audio_repository, sample_datetime): + async def test_add_audio_record_simple_with_datetime_date( + self, audio_repository, sample_datetime + ): """Тест упрощенного добавления аудио записи с datetime датой""" - await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, sample_datetime) - + await audio_repository.add_audio_record_simple( + "test_audio.ogg", 12345, sample_datetime + ) + # Проверяем, что date_added преобразован в timestamp call_args = audio_repository._execute_query.call_args assert isinstance(call_args[0][1][2], int) # timestamp - + @pytest.mark.asyncio async def test_get_last_date_audio(self, audio_repository): """Тест получения даты последнего аудио""" expected_timestamp = 1642248600 # 2022-01-17 10:30:00 - audio_repository._execute_query_with_result.return_value = [(expected_timestamp,)] - + audio_repository._execute_query_with_result.return_value = [ + (expected_timestamp,) + ] + result = await audio_repository.get_last_date_audio() - + assert result == expected_timestamp audio_repository._execute_query_with_result.assert_called_once_with( "SELECT date_added FROM audio_message_reference ORDER BY date_added DESC LIMIT 1" ) - + @pytest.mark.asyncio async def test_get_last_date_audio_no_records(self, audio_repository): """Тест получения даты последнего аудио когда записей нет""" audio_repository._execute_query_with_result.return_value = [] - + result = await audio_repository.get_last_date_audio() - + assert result is None - + @pytest.mark.asyncio async def test_get_user_audio_records_count(self, audio_repository): """Тест получения количества аудио записей пользователя""" audio_repository._execute_query_with_result.return_value = [(5,)] - + result = await audio_repository.get_user_audio_records_count(12345) - + assert result == 5 audio_repository._execute_query_with_result.assert_called_once_with( "SELECT COUNT(*) FROM audio_message_reference WHERE author_id = ?", (12345,) ) - + @pytest.mark.asyncio async def test_get_path_for_audio_record(self, audio_repository): """Тест получения пути к аудио записи пользователя""" audio_repository._execute_query_with_result.return_value = [("test_audio.ogg",)] - + result = await audio_repository.get_path_for_audio_record(12345) - + assert result == "test_audio.ogg" audio_repository._execute_query_with_result.assert_called_once_with( """ SELECT file_name FROM audio_message_reference WHERE author_id = ? ORDER BY date_added DESC LIMIT 1 - """, (12345,) + """, + (12345,), ) - + @pytest.mark.asyncio async def test_get_path_for_audio_record_no_records(self, audio_repository): """Тест получения пути к аудио записи когда записей нет""" audio_repository._execute_query_with_result.return_value = [] - + result = await audio_repository.get_path_for_audio_record(12345) - + assert result is None - + @pytest.mark.asyncio async def test_check_listen_audio(self, audio_repository): """Тест проверки непрослушанных аудио""" # Мокаем результаты запросов audio_repository._execute_query_with_result.side_effect = [ [("audio1.ogg",), ("audio2.ogg",)], # прослушанные - [("audio1.ogg",), ("audio2.ogg",), ("audio3.ogg",)] # все аудио + [("audio1.ogg",), ("audio2.ogg",), ("audio3.ogg",)], # все аудио ] - + result = await audio_repository.check_listen_audio(12345) - + # Должно вернуться только непрослушанные (audio3.ogg) assert result == ["audio3.ogg"] assert audio_repository._execute_query_with_result.call_count == 2 - + @pytest.mark.asyncio async def test_mark_listened_audio(self, audio_repository): """Тест отметки аудио как прослушанного""" await audio_repository.mark_listened_audio("test_audio.ogg", 12345) - + audio_repository._execute_query.assert_called_once_with( "INSERT OR IGNORE INTO user_audio_listens (file_name, user_id) VALUES (?, ?)", - ("test_audio.ogg", 12345) + ("test_audio.ogg", 12345), ) - + @pytest.mark.asyncio async def test_get_user_id_by_file_name(self, audio_repository): """Тест получения user_id по имени файла""" audio_repository._execute_query_with_result.return_value = [(12345,)] - + result = await audio_repository.get_user_id_by_file_name("test_audio.ogg") - + assert result == 12345 audio_repository._execute_query_with_result.assert_called_once_with( - "SELECT author_id FROM audio_message_reference WHERE file_name = ?", ("test_audio.ogg",) + "SELECT author_id FROM audio_message_reference WHERE file_name = ?", + ("test_audio.ogg",), ) - + @pytest.mark.asyncio async def test_get_user_id_by_file_name_not_found(self, audio_repository): """Тест получения user_id по имени файла когда файл не найден""" audio_repository._execute_query_with_result.return_value = [] - + result = await audio_repository.get_user_id_by_file_name("nonexistent.ogg") - + assert result is None - + @pytest.mark.asyncio async def test_get_date_by_file_name(self, audio_repository): - """Тест получения даты по имени файла""" - timestamp = 1642404600 # 2022-01-17 10:30:00 + """Тест получения даты по имени файла (UTC, без зависимости от локали).""" + timestamp = 1642404600 # 2022-01-17 10:30:00 UTC audio_repository._execute_query_with_result.return_value = [(timestamp,)] - + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - # Должна вернуться читаемая дата - assert result == "17.01.2022 10:30" - audio_repository._execute_query_with_result.assert_called_once_with( - "SELECT date_added FROM audio_message_reference WHERE file_name = ?", ("test_audio.ogg",) + + expected = datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" ) - + assert result == expected + audio_repository._execute_query_with_result.assert_called_once_with( + "SELECT date_added FROM audio_message_reference WHERE file_name = ?", + ("test_audio.ogg",), + ) + @pytest.mark.asyncio async def test_get_date_by_file_name_not_found(self, audio_repository): """Тест получения даты по имени файла когда файл не найден""" audio_repository._execute_query_with_result.return_value = [] - + result = await audio_repository.get_date_by_file_name("nonexistent.ogg") - + assert result is None - + @pytest.mark.asyncio async def test_refresh_listen_audio(self, audio_repository): """Тест очистки записей прослушивания пользователя""" await audio_repository.refresh_listen_audio(12345) - + audio_repository._execute_query.assert_called_once_with( "DELETE FROM user_audio_listens WHERE user_id = ?", (12345,) ) - + @pytest.mark.asyncio async def test_delete_listen_count_for_user(self, audio_repository): """Тест удаления данных о прослушанных аудио пользователя""" await audio_repository.delete_listen_count_for_user(12345) - + audio_repository._execute_query.assert_called_once_with( "DELETE FROM user_audio_listens WHERE user_id = ?", (12345,) ) - + @pytest.mark.asyncio - async def test_set_user_id_and_message_id_for_voice_bot_success(self, audio_repository): + async def test_set_user_id_and_message_id_for_voice_bot_success( + self, audio_repository + ): """Тест успешной установки связи для voice bot""" - result = await audio_repository.set_user_id_and_message_id_for_voice_bot(123, 456) - + result = await audio_repository.set_user_id_and_message_id_for_voice_bot( + 123, 456 + ) + assert result is True audio_repository._execute_query.assert_called_once_with( "INSERT OR IGNORE INTO audio_moderate (user_id, message_id) VALUES (?, ?)", - (456, 123) + (456, 123), ) - + @pytest.mark.asyncio - async def test_set_user_id_and_message_id_for_voice_bot_exception(self, audio_repository): + async def test_set_user_id_and_message_id_for_voice_bot_exception( + self, audio_repository + ): """Тест установки связи для voice bot при ошибке""" audio_repository._execute_query.side_effect = Exception("Database error") - - result = await audio_repository.set_user_id_and_message_id_for_voice_bot(123, 456) - + + result = await audio_repository.set_user_id_and_message_id_for_voice_bot( + 123, 456 + ) + assert result is False - + @pytest.mark.asyncio async def test_get_user_id_by_message_id_for_voice_bot(self, audio_repository): """Тест получения user_id по message_id для voice bot""" audio_repository._execute_query_with_result.return_value = [(456,)] - + result = await audio_repository.get_user_id_by_message_id_for_voice_bot(123) - + assert result == 456 audio_repository._execute_query_with_result.assert_called_once_with( "SELECT user_id FROM audio_moderate WHERE message_id = ?", (123,) ) - + @pytest.mark.asyncio - async def test_get_user_id_by_message_id_for_voice_bot_not_found(self, audio_repository): + async def test_get_user_id_by_message_id_for_voice_bot_not_found( + self, audio_repository + ): """Тест получения user_id по message_id когда связь не найдена""" audio_repository._execute_query_with_result.return_value = [] - + result = await audio_repository.get_user_id_by_message_id_for_voice_bot(123) - + assert result is None - + @pytest.mark.asyncio async def test_delete_audio_moderate_record(self, audio_repository): """Тест удаления записи из таблицы audio_moderate""" message_id = 12345 - + await audio_repository.delete_audio_moderate_record(message_id) - + audio_repository._execute_query.assert_called_once_with( "DELETE FROM audio_moderate WHERE message_id = ?", (message_id,) ) audio_repository.logger.info.assert_called_once_with( f"Удалена запись из audio_moderate для message_id {message_id}" ) - + @pytest.mark.asyncio - async def test_add_audio_record_logging(self, audio_repository, sample_audio_message): + async def test_add_audio_record_logging( + self, audio_repository, sample_audio_message + ): """Тест логирования при добавлении аудио записи""" await audio_repository.add_audio_record(sample_audio_message) - + # Проверяем, что лог записан audio_repository.logger.info.assert_called_once() log_message = audio_repository.logger.info.call_args[0][0] assert "Аудио добавлено" in log_message assert "test_audio_123.ogg" in log_message assert "12345" in log_message - + @pytest.mark.asyncio async def test_add_audio_record_simple_logging(self, audio_repository): """Тест логирования при упрощенном добавлении аудио записи""" - await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, "2025-01-15 14:30:00") - + await audio_repository.add_audio_record_simple( + "test_audio.ogg", 12345, "2025-01-15 14:30:00" + ) + # Проверяем, что лог записан audio_repository.logger.info.assert_called_once() log_message = audio_repository.logger.info.call_args[0][0] assert "Аудио добавлено" in log_message assert "test_audio.ogg" in log_message assert "12345" in log_message - + @pytest.mark.asyncio async def test_get_date_by_file_name_logging(self, audio_repository): - """Тест логирования при получении даты по имени файла""" - timestamp = 1642404600 # 2022-01-17 10:30:00 + """Тест логирования при получении даты по имени файла (UTC).""" + timestamp = 1642404600 # 2022-01-17 10:30:00 UTC audio_repository._execute_query_with_result.return_value = [(timestamp,)] - + await audio_repository.get_date_by_file_name("test_audio.ogg") - - # Проверяем, что лог записан + + expected = datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) audio_repository.logger.info.assert_called_once() log_message = audio_repository.logger.info.call_args[0][0] assert "Получена дата" in log_message - assert "17.01.2022 10:30" in log_message + assert expected in log_message assert "test_audio.ogg" in log_message class TestAudioRepositoryIntegration: """Интеграционные тесты для AudioRepository""" - + @pytest.fixture def real_audio_repository(self): """Реальный экземпляр AudioRepository для интеграционных тестов""" # Здесь можно создать реальное подключение к тестовой БД # Но для простоты используем мок return Mock() - + @pytest.mark.asyncio async def test_full_audio_workflow(self, real_audio_repository): """Тест полного рабочего процесса с аудио""" # Этот тест можно расширить для реальной БД assert True # Placeholder для будущих интеграционных тестов - + @pytest.mark.asyncio async def test_foreign_keys_enabled(self, real_audio_repository): """Тест что внешние ключи включены""" diff --git a/tests/test_audio_repository_schema.py b/tests/test_audio_repository_schema.py index b7428ea..ed57604 100644 --- a/tests/test_audio_repository_schema.py +++ b/tests/test_audio_repository_schema.py @@ -1,14 +1,15 @@ import time -from datetime import datetime +from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest + from database.repositories.audio_repository import AudioRepository class TestAudioRepositoryNewSchema: """Тесты для AudioRepository с новой схемой БД""" - + @pytest.fixture def mock_db_connection(self): """Мок для DatabaseConnection""" @@ -17,327 +18,367 @@ class TestAudioRepositoryNewSchema: mock_connection._execute_query_with_result = AsyncMock() mock_connection.logger = Mock() return mock_connection - + @pytest.fixture def audio_repository(self, mock_db_connection): """Экземпляр AudioRepository для тестов""" - with patch.object(AudioRepository, '__init__', return_value=None): + with patch.object(AudioRepository, "__init__", return_value=None): repo = AudioRepository() repo._execute_query = mock_db_connection._execute_query - repo._execute_query_with_result = mock_db_connection._execute_query_with_result + repo._execute_query_with_result = ( + mock_db_connection._execute_query_with_result + ) repo.logger = mock_db_connection.logger return repo - + @pytest.mark.asyncio async def test_create_tables_new_schema(self, audio_repository): """Тест создания таблиц с новой схемой БД""" await audio_repository.create_tables() - + # Проверяем, что все три таблицы созданы assert audio_repository._execute_query.call_count == 3 - + # Получаем все вызовы calls = audio_repository._execute_query.call_args_list - + # Проверяем таблицу audio_message_reference - audio_table_call = next(call for call in calls if "audio_message_reference" in str(call)) + audio_table_call = next( + call for call in calls if "audio_message_reference" in str(call) + ) assert "id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT" in str(audio_table_call) assert "file_name TEXT NOT NULL UNIQUE" in str(audio_table_call) assert "author_id INTEGER NOT NULL" in str(audio_table_call) assert "date_added INTEGER NOT NULL" in str(audio_table_call) - assert "FOREIGN KEY (author_id) REFERENCES our_users (user_id) ON DELETE CASCADE" in str(audio_table_call) - + assert ( + "FOREIGN KEY (author_id) REFERENCES our_users (user_id) ON DELETE CASCADE" + in str(audio_table_call) + ) + # Проверяем таблицу user_audio_listens - listens_table_call = next(call for call in calls if "user_audio_listens" in str(call)) + listens_table_call = next( + call for call in calls if "user_audio_listens" in str(call) + ) assert "file_name TEXT NOT NULL" in str(listens_table_call) assert "user_id INTEGER NOT NULL" in str(listens_table_call) assert "PRIMARY KEY (file_name, user_id)" in str(listens_table_call) - assert "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" in str(listens_table_call) - + assert ( + "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" + in str(listens_table_call) + ) + # Проверяем таблицу audio_moderate - moderate_table_call = next(call for call in calls if "audio_moderate" in str(call)) + moderate_table_call = next( + call for call in calls if "audio_moderate" in str(call) + ) assert "user_id INTEGER NOT NULL" in str(moderate_table_call) assert "message_id INTEGER" in str(moderate_table_call) assert "PRIMARY KEY (user_id, message_id)" in str(moderate_table_call) - assert "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" in str(moderate_table_call) - + assert ( + "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" + in str(moderate_table_call) + ) + @pytest.mark.asyncio async def test_add_audio_record_string_date_conversion(self, audio_repository): """Тест преобразования строковой даты в UNIX timestamp""" from database.models import AudioMessage - + audio_msg = AudioMessage( file_name="test_audio.ogg", author_id=12345, date_added="2025-01-15 14:30:00", file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + await audio_repository.add_audio_record(audio_msg) - + # Проверяем, что метод вызван call_args = audio_repository._execute_query.call_args params = call_args[0][1] - + # Проверяем параметры assert params[0] == "test_audio.ogg" assert params[1] == 12345 assert isinstance(params[2], int) # timestamp - + # Проверяем, что timestamp соответствует дате expected_timestamp = int(datetime(2025, 1, 15, 14, 30, 0).timestamp()) assert params[2] == expected_timestamp - + @pytest.mark.asyncio async def test_add_audio_record_datetime_conversion(self, audio_repository): """Тест преобразования datetime в UNIX timestamp""" from database.models import AudioMessage - + test_datetime = datetime(2025, 1, 20, 10, 15, 30) audio_msg = AudioMessage( file_name="test_audio.ogg", author_id=12345, date_added=test_datetime, file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + await audio_repository.add_audio_record(audio_msg) - + # Проверяем параметры call_args = audio_repository._execute_query.call_args params = call_args[0][1] - + expected_timestamp = int(test_datetime.timestamp()) assert params[2] == expected_timestamp - + @pytest.mark.asyncio async def test_add_audio_record_timestamp_no_conversion(self, audio_repository): """Тест что timestamp остается timestamp без преобразования""" from database.models import AudioMessage - + test_timestamp = int(time.time()) audio_msg = AudioMessage( file_name="test_audio.ogg", author_id=12345, date_added=test_timestamp, file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + await audio_repository.add_audio_record(audio_msg) - + # Проверяем параметры call_args = audio_repository._execute_query.call_args params = call_args[0][1] - + assert params[2] == test_timestamp - + @pytest.mark.asyncio async def test_add_audio_record_simple_string_date(self, audio_repository): """Тест упрощенного добавления со строковой датой""" - await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, "2025-01-15 14:30:00") - + await audio_repository.add_audio_record_simple( + "test_audio.ogg", 12345, "2025-01-15 14:30:00" + ) + # Проверяем параметры call_args = audio_repository._execute_query.call_args params = call_args[0][1] - + assert params[0] == "test_audio.ogg" assert params[1] == 12345 assert isinstance(params[2], int) # timestamp - + # Проверяем timestamp expected_timestamp = int(datetime(2025, 1, 15, 14, 30, 0).timestamp()) assert params[2] == expected_timestamp - + @pytest.mark.asyncio async def test_add_audio_record_simple_datetime(self, audio_repository): """Тест упрощенного добавления с datetime""" test_datetime = datetime(2025, 1, 25, 16, 45, 0) - await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, test_datetime) - + await audio_repository.add_audio_record_simple( + "test_audio.ogg", 12345, test_datetime + ) + # Проверяем параметры call_args = audio_repository._execute_query.call_args params = call_args[0][1] - + expected_timestamp = int(test_datetime.timestamp()) assert params[2] == expected_timestamp - + @pytest.mark.asyncio async def test_get_date_by_file_name_timestamp_conversion(self, audio_repository): - """Тест преобразования UNIX timestamp в читаемую дату""" - test_timestamp = 1642248600 # 2022-01-17 10:30:00 + """Тест преобразования UNIX timestamp в читаемую дату (UTC).""" + test_timestamp = 1642248600 # 2022-01-15 12:10:00 UTC audio_repository._execute_query_with_result.return_value = [(test_timestamp,)] - + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - # Должна вернуться читаемая дата в формате dd.mm.yyyy HH:MM - assert result == "15.01.2022 15:10" + + expected = datetime.fromtimestamp(test_timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) + assert result == expected assert isinstance(result, str) - + @pytest.mark.asyncio async def test_get_date_by_file_name_different_timestamp(self, audio_repository): - """Тест преобразования другого timestamp в читаемую дату""" - test_timestamp = 1705312800 # 2024-01-16 12:00:00 + """Тест преобразования другого timestamp в читаемую дату (UTC).""" + test_timestamp = 1705312800 # 2024-01-16 12:00:00 UTC audio_repository._execute_query_with_result.return_value = [(test_timestamp,)] - + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - assert result == "15.01.2024 13:00" - + + expected = datetime.fromtimestamp(test_timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) + assert result == expected + @pytest.mark.asyncio async def test_get_date_by_file_name_midnight(self, audio_repository): - """Тест преобразования timestamp для полуночи""" - test_timestamp = 1705190400 # 2024-01-15 00:00:00 + """Тест преобразования timestamp для полуночи (UTC).""" + test_timestamp = 1705190400 # 2024-01-14 00:00:00 UTC audio_repository._execute_query_with_result.return_value = [(test_timestamp,)] - + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - assert result == "14.01.2024 03:00" - + + expected = datetime.fromtimestamp(test_timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) + assert result == expected + @pytest.mark.asyncio async def test_get_date_by_file_name_year_end(self, audio_repository): - """Тест преобразования timestamp для конца года""" - test_timestamp = 1704067200 # 2023-12-31 23:59:59 + """Тест преобразования timestamp для конца года (UTC).""" + test_timestamp = 1704067200 # 2023-12-31 00:00:00 UTC audio_repository._execute_query_with_result.return_value = [(test_timestamp,)] - + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - assert result == "01.01.2024 03:00" - + + expected = datetime.fromtimestamp(test_timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) + assert result == expected + @pytest.mark.asyncio async def test_foreign_keys_enabled_called(self, audio_repository): """Тест что метод enable_foreign_keys вызывается""" await audio_repository.enable_foreign_keys() - - audio_repository._execute_query.assert_called_once_with("PRAGMA foreign_keys = ON;") + + audio_repository._execute_query.assert_called_once_with( + "PRAGMA foreign_keys = ON;" + ) audio_repository.logger.info.assert_not_called() # Этот метод не логирует - + @pytest.mark.asyncio async def test_create_tables_logging(self, audio_repository): """Тест логирования при создании таблиц""" await audio_repository.create_tables() - + # Проверяем, что лог записан - audio_repository.logger.info.assert_called_once_with("Таблицы для аудио созданы") - + audio_repository.logger.info.assert_called_once_with( + "Таблицы для аудио созданы" + ) + @pytest.mark.asyncio async def test_add_audio_record_logging_format(self, audio_repository): """Тест формата лога при добавлении аудио записи""" from database.models import AudioMessage - + audio_msg = AudioMessage( file_name="test_audio.ogg", author_id=12345, date_added="2025-01-15 14:30:00", file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + await audio_repository.add_audio_record(audio_msg) - + # Проверяем формат лога log_call = audio_repository.logger.info.call_args log_message = log_call[0][0] - + assert "Аудио добавлено:" in log_message assert "file_name=test_audio.ogg" in log_message assert "author_id=12345" in log_message - + @pytest.mark.asyncio async def test_add_audio_record_simple_logging_format(self, audio_repository): """Тест формата лога при упрощенном добавлении""" - await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, "2025-01-15 14:30:00") - + await audio_repository.add_audio_record_simple( + "test_audio.ogg", 12345, "2025-01-15 14:30:00" + ) + # Проверяем формат лога log_call = audio_repository.logger.info.call_args log_message = log_call[0][0] - + assert "Аудио добавлено:" in log_message assert "file_name=test_audio.ogg" in log_message assert "user_id=12345" in log_message - + @pytest.mark.asyncio async def test_get_date_by_file_name_logging_format(self, audio_repository): - """Тест формата лога при получении даты""" - test_timestamp = 1642248600 # 2022-01-17 10:30:00 + """Тест формата лога при получении даты (UTC).""" + test_timestamp = 1642248600 # 2022-01-15 12:10:00 UTC audio_repository._execute_query_with_result.return_value = [(test_timestamp,)] - + await audio_repository.get_date_by_file_name("test_audio.ogg") - - # Проверяем формат лога + + expected = datetime.fromtimestamp(test_timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) log_call = audio_repository.logger.info.call_args log_message = log_call[0][0] - + assert "Получена дата" in log_message - assert "15.01.2022 15:10" in log_message + assert expected in log_message assert "test_audio.ogg" in log_message class TestAudioRepositoryEdgeCases: """Тесты граничных случаев для AudioRepository""" - + @pytest.fixture def audio_repository(self): """Экземпляр AudioRepository для тестов""" - with patch.object(AudioRepository, '__init__', return_value=None): + with patch.object(AudioRepository, "__init__", return_value=None): repo = AudioRepository() repo._execute_query = AsyncMock() repo._execute_query_with_result = AsyncMock() repo.logger = Mock() return repo - + @pytest.mark.asyncio async def test_add_audio_record_empty_string_date(self, audio_repository): """Тест добавления с пустой строковой датой""" from database.models import AudioMessage - + audio_msg = AudioMessage( file_name="test_audio.ogg", author_id=12345, date_added="", file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + # Должно вызвать ValueError при парсинге пустой строки with pytest.raises(ValueError): await audio_repository.add_audio_record(audio_msg) - + @pytest.mark.asyncio async def test_add_audio_record_invalid_string_date(self, audio_repository): """Тест добавления с некорректной строковой датой""" from database.models import AudioMessage - + audio_msg = AudioMessage( file_name="test_audio.ogg", author_id=12345, date_added="invalid_date", file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + # Должно вызвать ValueError при парсинге некорректной даты with pytest.raises(ValueError): await audio_repository.add_audio_record(audio_msg) - + @pytest.mark.asyncio async def test_add_audio_record_none_date(self, audio_repository): """Тест добавления с None датой""" from database.models import AudioMessage - + audio_msg = AudioMessage( file_name="test_audio.ogg", author_id=12345, date_added=None, file_id="test_file_id", - listen_count=0 + listen_count=0, ) - + # Метод обрабатывает None как timestamp без преобразования await audio_repository.add_audio_record(audio_msg) - + # Проверяем, что метод был вызван с None call_args = audio_repository._execute_query.call_args params = call_args[0][1] @@ -349,49 +390,61 @@ class TestAudioRepositoryEdgeCases: # Должно вызвать ValueError при парсинге пустой строки with pytest.raises(ValueError): await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, "") - + @pytest.mark.asyncio async def test_add_audio_record_simple_invalid_string_date(self, audio_repository): """Тест упрощенного добавления с некорректной строковой датой""" # Должно вызвать ValueError при парсинге некорректной даты with pytest.raises(ValueError): - await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, "invalid_date") - + await audio_repository.add_audio_record_simple( + "test_audio.ogg", 12345, "invalid_date" + ) + @pytest.mark.asyncio async def test_add_audio_record_simple_none_date(self, audio_repository): """Тест упрощенного добавления с None датой""" # Метод обрабатывает None как timestamp без преобразования await audio_repository.add_audio_record_simple("test_audio.ogg", 12345, None) - + # Проверяем, что метод был вызван с None call_args = audio_repository._execute_query.call_args params = call_args[0][1] assert params[2] is None - + @pytest.mark.asyncio async def test_get_date_by_file_name_zero_timestamp(self, audio_repository): - """Тест получения даты для timestamp = 0 (1970-01-01)""" + """Тест получения даты для timestamp = 0 (1970-01-01 UTC).""" audio_repository._execute_query_with_result.return_value = [(0,)] - + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - assert result == "01.01.1970 03:00" - + + expected = datetime.fromtimestamp(0, tz=timezone.utc).strftime("%d.%m.%Y %H:%M") + assert result == expected + @pytest.mark.asyncio async def test_get_date_by_file_name_negative_timestamp(self, audio_repository): - """Тест получения даты для отрицательного timestamp""" - audio_repository._execute_query_with_result.return_value = [(-3600,)] # 1969-12-31 23:00:00 - + """Тест получения даты для отрицательного timestamp (UTC).""" + ts = -3600 # 1969-12-31 23:00:00 UTC + audio_repository._execute_query_with_result.return_value = [(ts,)] + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - assert result == "01.01.1970 02:00" - + + expected = datetime.fromtimestamp(ts, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) + assert result == expected + @pytest.mark.asyncio async def test_get_date_by_file_name_future_timestamp(self, audio_repository): - """Тест получения даты для будущего timestamp""" - future_timestamp = int(datetime(2030, 12, 31, 23, 59, 59).timestamp()) + """Тест получения даты для будущего timestamp (UTC, без зависимости от локали).""" + future_timestamp = int( + datetime(2030, 12, 31, 23, 59, 59, tzinfo=timezone.utc).timestamp() + ) audio_repository._execute_query_with_result.return_value = [(future_timestamp,)] - + result = await audio_repository.get_date_by_file_name("test_audio.ogg") - - assert result == "31.12.2030 23:59" + + expected = datetime.fromtimestamp(future_timestamp, tz=timezone.utc).strftime( + "%d.%m.%Y %H:%M" + ) + assert result == expected diff --git a/tests/test_auto_unban_integration.py b/tests/test_auto_unban_integration.py index 8078d0d..1bb82b6 100644 --- a/tests/test_auto_unban_integration.py +++ b/tests/test_auto_unban_integration.py @@ -4,33 +4,34 @@ from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, Mock, patch import pytest + from helper_bot.utils.auto_unban_scheduler import AutoUnbanScheduler class TestAutoUnbanIntegration: """Интеграционные тесты для автоматического разбана""" - + @pytest.fixture def test_db_path(self): """Путь к тестовой базе данных""" - return 'database/test_auto_unban.db' - + return "database/test_auto_unban.db" + @pytest.fixture def setup_test_db(self, test_db_path): """Создает тестовую базу данных с таблицами blacklist, our_users и blacklist_history""" # Удаляем старую тестовую базу если она существует if os.path.exists(test_db_path): os.remove(test_db_path) - + # Создаем новую базу данных conn = sqlite3.connect(test_db_path) cursor = conn.cursor() - + # Включаем поддержку внешних ключей cursor.execute("PRAGMA foreign_keys = ON") - + # Создаем таблицу our_users (нужна для внешних ключей) - cursor.execute(''' + cursor.execute(""" CREATE TABLE IF NOT EXISTS our_users ( user_id INTEGER NOT NULL PRIMARY KEY, first_name TEXT, @@ -44,10 +45,10 @@ class TestAutoUnbanIntegration: date_changed INTEGER NOT NULL, voice_bot_welcome_received BOOLEAN DEFAULT 0 ) - ''') - + """) + # Создаем таблицу blacklist - cursor.execute(''' + cursor.execute(""" CREATE TABLE IF NOT EXISTS blacklist ( user_id INTEGER NOT NULL PRIMARY KEY, message_for_user TEXT, @@ -56,10 +57,10 @@ class TestAutoUnbanIntegration: ban_author INTEGER, FOREIGN KEY (user_id) REFERENCES our_users(user_id) ON DELETE CASCADE ) - ''') - + """) + # Создаем таблицу blacklist_history - cursor.execute(''' + cursor.execute(""" CREATE TABLE IF NOT EXISTS blacklist_history ( id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, @@ -72,273 +73,419 @@ class TestAutoUnbanIntegration: FOREIGN KEY (user_id) REFERENCES our_users(user_id) ON DELETE CASCADE, FOREIGN KEY (ban_author) REFERENCES our_users(user_id) ON DELETE SET NULL ) - ''') - + """) + # Создаем индексы для blacklist_history - cursor.execute(''' + cursor.execute(""" CREATE INDEX IF NOT EXISTS idx_blacklist_history_user_id ON blacklist_history(user_id) - ''') - cursor.execute(''' + """) + cursor.execute(""" CREATE INDEX IF NOT EXISTS idx_blacklist_history_date_ban ON blacklist_history(date_ban) - ''') - cursor.execute(''' + """) + cursor.execute(""" CREATE INDEX IF NOT EXISTS idx_blacklist_history_date_unban ON blacklist_history(date_unban) - ''') - + """) + # Добавляем тестовых пользователей в our_users current_time = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) users_data = [ - (123, "Test", "Test User 1", "test_user1", 0, "ru", 0, "😊", current_time, current_time, 0), - (456, "Test", "Test User 2", "test_user2", 0, "ru", 0, "😊", current_time, current_time, 0), - (789, "Test", "Test User 3", "test_user3", 0, "ru", 0, "😊", current_time, current_time, 0), - (999, "Test", "Test User 4", "test_user4", 0, "ru", 0, "😊", current_time, current_time, 0), + ( + 123, + "Test", + "Test User 1", + "test_user1", + 0, + "ru", + 0, + "😊", + current_time, + current_time, + 0, + ), + ( + 456, + "Test", + "Test User 2", + "test_user2", + 0, + "ru", + 0, + "😊", + current_time, + current_time, + 0, + ), + ( + 789, + "Test", + "Test User 3", + "test_user3", + 0, + "ru", + 0, + "😊", + current_time, + current_time, + 0, + ), + ( + 999, + "Test", + "Test User 4", + "test_user4", + 0, + "ru", + 0, + "😊", + current_time, + current_time, + 0, + ), ] cursor.executemany( """INSERT INTO our_users (user_id, first_name, full_name, username, is_bot, language_code, has_stickers, emoji, date_added, date_changed, voice_bot_welcome_received) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - users_data + users_data, ) - + # Добавляем тестовые данные в blacklist today_timestamp = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) - tomorrow_timestamp = int((datetime.now(timezone(timedelta(hours=3))) + timedelta(days=1)).timestamp()) - + tomorrow_timestamp = int( + (datetime.now(timezone(timedelta(hours=3))) + timedelta(days=1)).timestamp() + ) + blacklist_data = [ - (123, "Test ban 1", today_timestamp, current_time, None), # Разблокируется сегодня - (456, "Test ban 2", today_timestamp, current_time, None), # Разблокируется сегодня - (789, "Test ban 3", tomorrow_timestamp, current_time, None), # Разблокируется завтра - (999, "Test ban 4", None, current_time, None), # Навсегда заблокирован + ( + 123, + "Test ban 1", + today_timestamp, + current_time, + None, + ), # Разблокируется сегодня + ( + 456, + "Test ban 2", + today_timestamp, + current_time, + None, + ), # Разблокируется сегодня + ( + 789, + "Test ban 3", + tomorrow_timestamp, + current_time, + None, + ), # Разблокируется завтра + (999, "Test ban 4", None, current_time, None), # Навсегда заблокирован ] - + cursor.executemany( "INSERT INTO blacklist (user_id, message_for_user, date_to_unban, created_at, ban_author) VALUES (?, ?, ?, ?, ?)", - blacklist_data + blacklist_data, ) - + # Добавляем тестовые данные в blacklist_history # Для пользователей 123 и 456 (которые будут разблокированы) создаем записи с date_unban = NULL - yesterday_timestamp = int((datetime.now(timezone(timedelta(hours=3))) - timedelta(days=1)).timestamp()) - + yesterday_timestamp = int( + (datetime.now(timezone(timedelta(hours=3))) - timedelta(days=1)).timestamp() + ) + history_data = [ - (123, "Test ban 1", yesterday_timestamp, None, None, yesterday_timestamp, yesterday_timestamp), # Будет разблокирован - (456, "Test ban 2", yesterday_timestamp, None, None, yesterday_timestamp, yesterday_timestamp), # Будет разблокирован - (789, "Test ban 3", yesterday_timestamp, None, None, yesterday_timestamp, yesterday_timestamp), # Не будет разблокирован сегодня - (999, "Test ban 4", yesterday_timestamp, None, None, yesterday_timestamp, yesterday_timestamp), # Навсегда заблокирован + ( + 123, + "Test ban 1", + yesterday_timestamp, + None, + None, + yesterday_timestamp, + yesterday_timestamp, + ), # Будет разблокирован + ( + 456, + "Test ban 2", + yesterday_timestamp, + None, + None, + yesterday_timestamp, + yesterday_timestamp, + ), # Будет разблокирован + ( + 789, + "Test ban 3", + yesterday_timestamp, + None, + None, + yesterday_timestamp, + yesterday_timestamp, + ), # Не будет разблокирован сегодня + ( + 999, + "Test ban 4", + yesterday_timestamp, + None, + None, + yesterday_timestamp, + yesterday_timestamp, + ), # Навсегда заблокирован ] - + cursor.executemany( """INSERT INTO blacklist_history (user_id, message_for_user, date_ban, date_unban, ban_author, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)""", - history_data + history_data, ) - + conn.commit() conn.close() - + yield test_db_path - + # Очистка после тестов if os.path.exists(test_db_path): os.remove(test_db_path) - + @pytest.fixture def mock_bdf(self, test_db_path): """Создает мок фабрики зависимостей с тестовой базой""" mock_factory = Mock() mock_factory.settings = { - 'Telegram': { - 'group_for_logs': '-1001234567890', - 'important_logs': '-1001234567891' + "Telegram": { + "group_for_logs": "-1001234567890", + "important_logs": "-1001234567891", } } - + # Создаем реальный экземпляр базы данных с тестовым файлом import os from database.async_db import AsyncBotDB + mock_factory.database = AsyncBotDB(test_db_path) - + return mock_factory - + @pytest.fixture def mock_bot(self): """Создает мок бота""" mock_bot = Mock() mock_bot.send_message = AsyncMock() return mock_bot - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_with_real_db(self, mock_get_instance, setup_test_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_with_real_db( + self, mock_get_instance, setup_test_db, mock_bdf, mock_bot + ): """Тест автоматического разбана с реальной базой данных""" # Настройка моков mock_get_instance.return_value = mock_bdf - + # Создаем планировщик scheduler = AutoUnbanScheduler() scheduler.bot_db = mock_bdf.database scheduler.set_bot(mock_bot) - + # Проверяем начальное состояние базы conn = sqlite3.connect(setup_test_db) cursor = conn.cursor() cursor.execute("SELECT COUNT(*) FROM blacklist") initial_count = cursor.fetchone()[0] assert initial_count == 4 - + # Проверяем начальное состояние истории: должно быть 2 записи с date_unban IS NULL для user_id 123 и 456 - cursor.execute("SELECT COUNT(*) FROM blacklist_history WHERE user_id IN (123, 456) AND date_unban IS NULL") + cursor.execute( + "SELECT COUNT(*) FROM blacklist_history WHERE user_id IN (123, 456) AND date_unban IS NULL" + ) initial_open_history = cursor.fetchone()[0] assert initial_open_history == 2 - + # Запоминаем время до разбана для проверки updated_at - before_unban_timestamp = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) - + before_unban_timestamp = int( + datetime.now(timezone(timedelta(hours=3))).timestamp() + ) + # Выполняем автоматический разбан await scheduler.auto_unban_users() - + # Запоминаем время после разбана для проверки updated_at - after_unban_timestamp = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) - + after_unban_timestamp = int( + datetime.now(timezone(timedelta(hours=3))).timestamp() + ) + # Проверяем, что пользователи с сегодняшней датой разблокированы current_timestamp = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) - cursor.execute("SELECT COUNT(*) FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban <= ?", - (current_timestamp,)) + cursor.execute( + "SELECT COUNT(*) FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban <= ?", + (current_timestamp,), + ) today_count = cursor.fetchone()[0] assert today_count == 0 - + # Проверяем, что пользователи с завтрашней датой остались - cursor.execute("SELECT COUNT(*) FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban > ?", - (current_timestamp,)) + cursor.execute( + "SELECT COUNT(*) FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban > ?", + (current_timestamp,), + ) tomorrow_count = cursor.fetchone()[0] assert tomorrow_count == 1 - + # Проверяем, что навсегда заблокированные пользователи остались cursor.execute("SELECT COUNT(*) FROM blacklist WHERE date_to_unban IS NULL") permanent_count = cursor.fetchone()[0] assert permanent_count == 1 - + # Проверяем общее количество записей cursor.execute("SELECT COUNT(*) FROM blacklist") final_count = cursor.fetchone()[0] assert final_count == 2 # Остались только завтрашние и навсегда заблокированные - + # Проверяем историю банов: для user_id 123 и 456 должны быть установлены date_unban - cursor.execute("SELECT user_id, date_unban, updated_at FROM blacklist_history WHERE user_id IN (123, 456) ORDER BY user_id") + cursor.execute( + "SELECT user_id, date_unban, updated_at FROM blacklist_history WHERE user_id IN (123, 456) ORDER BY user_id" + ) history_records = cursor.fetchall() - + assert len(history_records) == 2 - + for user_id, date_unban, updated_at in history_records: # Проверяем, что date_unban установлен (не NULL) - assert date_unban is not None, f"date_unban должен быть установлен для user_id={user_id}" - assert isinstance(date_unban, int), f"date_unban должен быть integer для user_id={user_id}" - + assert ( + date_unban is not None + ), f"date_unban должен быть установлен для user_id={user_id}" + assert isinstance( + date_unban, int + ), f"date_unban должен быть integer для user_id={user_id}" + # Проверяем, что date_unban находится в разумных пределах (между before и after) - assert before_unban_timestamp <= date_unban <= after_unban_timestamp, \ - f"date_unban для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}, получен {date_unban}" - + assert ( + before_unban_timestamp <= date_unban <= after_unban_timestamp + ), f"date_unban для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}, получен {date_unban}" + # Проверяем, что updated_at обновлен - assert updated_at is not None, f"updated_at должен быть установлен для user_id={user_id}" - assert isinstance(updated_at, int), f"updated_at должен быть integer для user_id={user_id}" - assert before_unban_timestamp <= updated_at <= after_unban_timestamp, \ - f"updated_at для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}, получен {updated_at}" - + assert ( + updated_at is not None + ), f"updated_at должен быть установлен для user_id={user_id}" + assert isinstance( + updated_at, int + ), f"updated_at должен быть integer для user_id={user_id}" + assert ( + before_unban_timestamp <= updated_at <= after_unban_timestamp + ), f"updated_at для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}, получен {updated_at}" + # Проверяем, что для user_id 789 и 999 записи в истории остались без изменений (date_unban все еще NULL) - cursor.execute("SELECT COUNT(*) FROM blacklist_history WHERE user_id IN (789, 999) AND date_unban IS NULL") + cursor.execute( + "SELECT COUNT(*) FROM blacklist_history WHERE user_id IN (789, 999) AND date_unban IS NULL" + ) unchanged_history = cursor.fetchone()[0] - assert unchanged_history == 2, "Записи для user_id 789 и 999 должны остаться с date_unban = NULL" - + assert ( + unchanged_history == 2 + ), "Записи для user_id 789 и 999 должны остаться с date_unban = NULL" + conn.close() - + # Проверяем, что отчет был отправлен mock_bot.send_message.assert_called_once() - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_no_users_today(self, mock_get_instance, setup_test_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_no_users_today( + self, mock_get_instance, setup_test_db, mock_bdf, mock_bot + ): """Тест разбана когда нет пользователей для разблокировки сегодня""" # Настройка моков mock_get_instance.return_value = mock_bdf - + # Удаляем пользователей с сегодняшней датой conn = sqlite3.connect(setup_test_db) cursor = conn.cursor() current_timestamp = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) - cursor.execute("DELETE FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban <= ?", (current_timestamp,)) - + cursor.execute( + "DELETE FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban <= ?", + (current_timestamp,), + ) + # Проверяем начальное состояние истории: все записи должны иметь date_unban = NULL - cursor.execute("SELECT COUNT(*) FROM blacklist_history WHERE date_unban IS NULL") + cursor.execute( + "SELECT COUNT(*) FROM blacklist_history WHERE date_unban IS NULL" + ) initial_open_history = cursor.fetchone()[0] assert initial_open_history == 4 # Все 4 записи должны быть открытыми - + conn.commit() conn.close() - + # Создаем планировщик scheduler = AutoUnbanScheduler() scheduler.bot_db = mock_bdf.database scheduler.set_bot(mock_bot) - + # Выполняем автоматический разбан await scheduler.auto_unban_users() - + # Проверяем, что история не изменилась (все записи все еще с date_unban = NULL) conn = sqlite3.connect(setup_test_db) cursor = conn.cursor() - cursor.execute("SELECT COUNT(*) FROM blacklist_history WHERE date_unban IS NULL") + cursor.execute( + "SELECT COUNT(*) FROM blacklist_history WHERE date_unban IS NULL" + ) final_open_history = cursor.fetchone()[0] - assert final_open_history == 4, "История не должна изменяться, если нет пользователей для разблокировки" + assert ( + final_open_history == 4 + ), "История не должна изменяться, если нет пользователей для разблокировки" conn.close() - + # Проверяем, что отчет не был отправлен (нет пользователей для разблокировки) mock_bot.send_message.assert_not_called() - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_database_error(self, mock_get_instance, setup_test_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_database_error( + self, mock_get_instance, setup_test_db, mock_bdf, mock_bot + ): """Тест обработки ошибок базы данных""" # Настройка моков mock_get_instance.return_value = mock_bdf - + # Создаем планировщик scheduler = AutoUnbanScheduler() scheduler.bot_db = mock_bdf.database scheduler.set_bot(mock_bot) - + # Удаляем таблицу чтобы вызвать ошибку conn = sqlite3.connect(setup_test_db) cursor = conn.cursor() cursor.execute("DROP TABLE blacklist") conn.commit() conn.close() - + # Выполняем автоматический разбан await scheduler.auto_unban_users() - + # Проверяем, что отчет об ошибке был отправлен mock_bot.send_message.assert_called_once() call_args = mock_bot.send_message.call_args - assert call_args[1]['chat_id'] == '-1001234567891' # important_logs - assert "Ошибка автоматического разбана" in call_args[1]['text'] - + assert call_args[1]["chat_id"] == "-1001234567891" # important_logs + assert "Ошибка автоматического разбана" in call_args[1]["text"] + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_updates_history(self, mock_get_instance, setup_test_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_updates_history( + self, mock_get_instance, setup_test_db, mock_bdf, mock_bot + ): """Тест что автоматический разбан обновляет историю банов""" # Настройка моков mock_get_instance.return_value = mock_bdf - + # Создаем планировщик scheduler = AutoUnbanScheduler() scheduler.bot_db = mock_bdf.database scheduler.set_bot(mock_bot) - + conn = sqlite3.connect(setup_test_db) cursor = conn.cursor() - + # Проверяем начальное состояние: для user_id 123 и 456 должны быть записи с date_unban = NULL cursor.execute(""" SELECT id, user_id, date_ban, date_unban, updated_at @@ -347,26 +494,32 @@ class TestAutoUnbanIntegration: ORDER BY user_id """) initial_records = cursor.fetchall() - assert len(initial_records) == 2, "Должно быть 2 открытые записи для user_id 123 и 456" - + assert ( + len(initial_records) == 2 + ), "Должно быть 2 открытые записи для user_id 123 и 456" + # Запоминаем ID записей и их начальные значения updated_at record_ids = {row[0]: (row[1], row[4]) for row in initial_records} - + # Запоминаем время до разбана - before_unban_timestamp = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) - + before_unban_timestamp = int( + datetime.now(timezone(timedelta(hours=3))).timestamp() + ) + conn.close() - + # Выполняем автоматический разбан await scheduler.auto_unban_users() - + # Запоминаем время после разбана - after_unban_timestamp = int(datetime.now(timezone(timedelta(hours=3))).timestamp()) - + after_unban_timestamp = int( + datetime.now(timezone(timedelta(hours=3))).timestamp() + ) + # Проверяем, что записи обновлены conn = sqlite3.connect(setup_test_db) cursor = conn.cursor() - + cursor.execute(""" SELECT id, user_id, date_ban, date_unban, updated_at FROM blacklist_history @@ -374,32 +527,45 @@ class TestAutoUnbanIntegration: ORDER BY user_id """) updated_records = cursor.fetchall() - + assert len(updated_records) == 2, "Должно быть 2 записи для user_id 123 и 456" - + for record_id, user_id, date_ban, date_unban, updated_at in updated_records: # Проверяем, что это одна из наших записей - assert record_id in record_ids, f"Запись с id={record_id} должна быть в исходных записях" - + assert ( + record_id in record_ids + ), f"Запись с id={record_id} должна быть в исходных записях" + # Проверяем, что date_unban установлен - assert date_unban is not None, f"date_unban должен быть установлен для user_id={user_id}" - assert isinstance(date_unban, int), f"date_unban должен быть integer для user_id={user_id}" - + assert ( + date_unban is not None + ), f"date_unban должен быть установлен для user_id={user_id}" + assert isinstance( + date_unban, int + ), f"date_unban должен быть integer для user_id={user_id}" + # Проверяем, что date_unban находится в разумных пределах - assert before_unban_timestamp <= date_unban <= after_unban_timestamp, \ - f"date_unban для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}" - + assert ( + before_unban_timestamp <= date_unban <= after_unban_timestamp + ), f"date_unban для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}" + # Проверяем, что updated_at обновлен (должен быть больше начального значения) - assert updated_at is not None, f"updated_at должен быть установлен для user_id={user_id}" - assert isinstance(updated_at, int), f"updated_at должен быть integer для user_id={user_id}" - assert before_unban_timestamp <= updated_at <= after_unban_timestamp, \ - f"updated_at для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}" - + assert ( + updated_at is not None + ), f"updated_at должен быть установлен для user_id={user_id}" + assert isinstance( + updated_at, int + ), f"updated_at должен быть integer для user_id={user_id}" + assert ( + before_unban_timestamp <= updated_at <= after_unban_timestamp + ), f"updated_at для user_id={user_id} должен быть между {before_unban_timestamp} и {after_unban_timestamp}" + # Проверяем, что updated_at действительно обновлен (больше начального значения) initial_updated_at = record_ids[record_id][1] - assert updated_at >= initial_updated_at, \ - f"updated_at для user_id={user_id} должен быть больше или равен начальному значению" - + assert ( + updated_at >= initial_updated_at + ), f"updated_at для user_id={user_id} должен быть больше или равен начальному значению" + # Проверяем, что обновлена только последняя запись для каждого пользователя # (если бы было несколько записей, обновилась бы только последняя) cursor.execute(""" @@ -407,29 +573,35 @@ class TestAutoUnbanIntegration: WHERE user_id IN (123, 456) AND date_unban IS NOT NULL """) closed_records = cursor.fetchone()[0] - assert closed_records == 2, "Должно быть закрыто 2 записи (по одной для каждого пользователя)" - + assert ( + closed_records == 2 + ), "Должно быть закрыто 2 записи (по одной для каждого пользователя)" + cursor.execute(""" SELECT COUNT(*) FROM blacklist_history WHERE user_id IN (123, 456) AND date_unban IS NULL """) open_records = cursor.fetchone()[0] - assert open_records == 0, "Не должно быть открытых записей для user_id 123 и 456" - + assert ( + open_records == 0 + ), "Не должно быть открытых записей для user_id 123 и 456" + conn.close() - + def test_date_format_consistency(self, setup_test_db, mock_bdf): """Тест консистентности формата дат""" scheduler = AutoUnbanScheduler() scheduler.bot_db = mock_bdf.database - + # Проверяем, что дата в базе соответствует ожидаемому формату (timestamp) conn = sqlite3.connect(setup_test_db) cursor = conn.cursor() - cursor.execute("SELECT date_to_unban FROM blacklist WHERE date_to_unban IS NOT NULL LIMIT 1") + cursor.execute( + "SELECT date_to_unban FROM blacklist WHERE date_to_unban IS NOT NULL LIMIT 1" + ) result = cursor.fetchone() conn.close() - + if result and result[0]: timestamp = result[0] # Проверяем, что это валидный timestamp (целое число) @@ -442,38 +614,39 @@ class TestAutoUnbanIntegration: class TestSchedulerLifecycle: """Тесты жизненного цикла планировщика""" - + def test_scheduler_start_stop(self): """Тест запуска и остановки планировщика""" scheduler = AutoUnbanScheduler() - + # Запускаем планировщик scheduler.start_scheduler() assert scheduler.scheduler.running - + # Останавливаем планировщик (должно пройти без ошибок) scheduler.stop_scheduler() # APScheduler может не сразу остановиться, но это нормально - + def test_scheduler_job_creation(self): """Тест создания задачи в планировщике""" scheduler = AutoUnbanScheduler() - - with patch.object(scheduler.scheduler, 'add_job') as mock_add_job: + + with patch.object(scheduler.scheduler, "add_job") as mock_add_job: scheduler.start_scheduler() - + # Проверяем, что задача была создана с правильными параметрами mock_add_job.assert_called_once() call_args = mock_add_job.call_args - + # Проверяем функцию assert call_args[0][0] == scheduler.auto_unban_users - + # Проверяем триггер (должен быть CronTrigger) from apscheduler.triggers.cron import CronTrigger + assert isinstance(call_args[0][1], CronTrigger) - + # Проверяем ID и имя задачи - assert call_args[1]['id'] == 'auto_unban_users' - assert call_args[1]['name'] == 'Автоматический разбан пользователей' - assert call_args[1]['replace_existing'] is True + assert call_args[1]["id"] == "auto_unban_users" + assert call_args[1]["name"] == "Автоматический разбан пользователей" + assert call_args[1]["replace_existing"] is True diff --git a/tests/test_auto_unban_scheduler.py b/tests/test_auto_unban_scheduler.py index 294cfb2..7f64bef 100644 --- a/tests/test_auto_unban_scheduler.py +++ b/tests/test_auto_unban_scheduler.py @@ -3,231 +3,248 @@ from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, Mock, patch import pytest -from helper_bot.utils.auto_unban_scheduler import (AutoUnbanScheduler, - get_auto_unban_scheduler) + +from helper_bot.utils.auto_unban_scheduler import ( + AutoUnbanScheduler, + get_auto_unban_scheduler, +) class TestAutoUnbanScheduler: """Тесты для класса AutoUnbanScheduler""" - + @pytest.fixture def scheduler(self): """Создает экземпляр планировщика для тестов""" return AutoUnbanScheduler() - + @pytest.fixture def mock_bot_db(self): """Создает мок базы данных""" mock_db = Mock() - mock_db.get_users_for_unblock_today = AsyncMock(return_value={ - 123: "test_user1", - 456: "test_user2" - }) + mock_db.get_users_for_unblock_today = AsyncMock( + return_value={123: "test_user1", 456: "test_user2"} + ) mock_db.delete_user_blacklist = AsyncMock(return_value=True) return mock_db - + @pytest.fixture def mock_bdf(self): """Создает мок фабрики зависимостей""" mock_factory = Mock() mock_factory.settings = { - 'Telegram': { - 'group_for_logs': '-1001234567890', - 'important_logs': '-1001234567891' + "Telegram": { + "group_for_logs": "-1001234567890", + "important_logs": "-1001234567891", } } return mock_factory - + @pytest.fixture def mock_bot(self): """Создает мок бота""" mock_bot = Mock() mock_bot.send_message = AsyncMock() return mock_bot - + def test_scheduler_initialization(self, scheduler): """Тест инициализации планировщика""" assert scheduler.bot_db is not None assert scheduler.scheduler is not None assert scheduler.bot is None - + def test_set_bot(self, scheduler, mock_bot): """Тест установки бота""" scheduler.set_bot(mock_bot) assert scheduler.bot == mock_bot - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_users_success(self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_users_success( + self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot + ): """Тест успешного выполнения автоматического разбана""" # Настройка моков mock_get_instance.return_value = mock_bdf scheduler.bot_db = mock_bot_db scheduler.set_bot(mock_bot) - + # Выполнение теста await scheduler.auto_unban_users() - + # Проверки mock_bot_db.get_users_for_unblock_today.assert_called_once() assert mock_bot_db.delete_user_blacklist.call_count == 2 mock_bot.send_message.assert_called_once() - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_users_no_users(self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_users_no_users( + self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot + ): """Тест разбана когда нет пользователей для разблокировки""" # Настройка моков mock_get_instance.return_value = mock_bdf mock_bot_db.get_users_for_unblock_today = AsyncMock(return_value={}) scheduler.bot_db = mock_bot_db scheduler.set_bot(mock_bot) - + # Выполнение теста await scheduler.auto_unban_users() - + # Проверки mock_bot_db.get_users_for_unblock_today.assert_called_once() mock_bot_db.delete_user_blacklist.assert_not_called() mock_bot.send_message.assert_not_called() - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_users_partial_failure(self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_users_partial_failure( + self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot + ): """Тест разбана с частичными ошибками""" # Настройка моков mock_get_instance.return_value = mock_bdf - mock_bot_db.get_users_for_unblock_today = AsyncMock(return_value={ - 123: "test_user1", - 456: "test_user2" - }) + mock_bot_db.get_users_for_unblock_today = AsyncMock( + return_value={123: "test_user1", 456: "test_user2"} + ) # Первый вызов успешен, второй - ошибка mock_bot_db.delete_user_blacklist = AsyncMock(side_effect=[True, False]) scheduler.bot_db = mock_bot_db scheduler.set_bot(mock_bot) - + # Выполнение теста await scheduler.auto_unban_users() - + # Проверки assert mock_bot_db.delete_user_blacklist.call_count == 2 mock_bot.send_message.assert_called_once() - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_auto_unban_users_exception(self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_auto_unban_users_exception( + self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot + ): """Тест разбана с исключением""" # Настройка моков mock_get_instance.return_value = mock_bdf - mock_bot_db.get_users_for_unblock_today = AsyncMock(side_effect=Exception("Database error")) + mock_bot_db.get_users_for_unblock_today = AsyncMock( + side_effect=Exception("Database error") + ) scheduler.bot_db = mock_bot_db scheduler.set_bot(mock_bot) - + # Выполнение теста await scheduler.auto_unban_users() - + # Проверки mock_bot.send_message.assert_called_once() # Проверяем, что сообщение об ошибке было отправлено call_args = mock_bot.send_message.call_args - assert "Ошибка автоматического разбана" in call_args[1]['text'] - + assert "Ошибка автоматического разбана" in call_args[1]["text"] + def test_generate_report(self, scheduler): """Тест генерации отчета""" users = {123: "test_user1", 456: "test_user2"} failed_users = ["456 (test_user2)"] - + report = scheduler._generate_report(1, 1, failed_users, users) - + assert "Отчет об автоматическом разбане" in report assert "Успешно разблокировано: 1" in report assert "Ошибок: 1" in report assert "ID: 123" in report assert "456 (test_user2)" in report - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") async def test_send_report(self, mock_get_instance, scheduler, mock_bdf, mock_bot): """Тест отправки отчета""" mock_get_instance.return_value = mock_bdf scheduler.set_bot(mock_bot) - + report = "Test report" await scheduler._send_report(report) - + # Проверяем, что send_message был вызван mock_bot.send_message.assert_called_once() - + # Проверяем аргументы вызова call_args = mock_bot.send_message.call_args - assert call_args[1]['text'] == report - assert call_args[1]['parse_mode'] == 'HTML' - + assert call_args[1]["text"] == report + assert call_args[1]["parse_mode"] == "HTML" + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_send_error_report(self, mock_get_instance, scheduler, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_send_error_report( + self, mock_get_instance, scheduler, mock_bdf, mock_bot + ): """Тест отправки отчета об ошибке""" mock_get_instance.return_value = mock_bdf scheduler.set_bot(mock_bot) - + error_msg = "Test error" await scheduler._send_error_report(error_msg) - + # Проверяем, что send_message был вызван mock_bot.send_message.assert_called_once() - + # Проверяем аргументы вызова call_args = mock_bot.send_message.call_args - assert "Ошибка автоматического разбана" in call_args[1]['text'] - assert error_msg in call_args[1]['text'] - assert call_args[1]['parse_mode'] == 'HTML' - + assert "Ошибка автоматического разбана" in call_args[1]["text"] + assert error_msg in call_args[1]["text"] + assert call_args[1]["parse_mode"] == "HTML" + def test_start_scheduler(self, scheduler): """Тест запуска планировщика""" - with patch.object(scheduler.scheduler, 'add_job') as mock_add_job, \ - patch.object(scheduler.scheduler, 'start') as mock_start: - + with ( + patch.object(scheduler.scheduler, "add_job") as mock_add_job, + patch.object(scheduler.scheduler, "start") as mock_start, + ): + scheduler.start_scheduler() - + mock_add_job.assert_called_once() mock_start.assert_called_once() - + def test_stop_scheduler(self, scheduler): """Тест остановки планировщика""" # Сначала запускаем планировщик scheduler.start_scheduler() - + # Проверяем, что планировщик запущен assert scheduler.scheduler.running - + # Теперь останавливаем (должно пройти без ошибок) scheduler.stop_scheduler() - + # Проверяем, что метод выполнился без исключений # APScheduler может не сразу остановиться, но это нормально - + @pytest.mark.asyncio - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') - async def test_run_manual_unban(self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot): + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") + async def test_run_manual_unban( + self, mock_get_instance, scheduler, mock_bot_db, mock_bdf, mock_bot + ): """Тест ручного запуска разбана""" mock_get_instance.return_value = mock_bdf mock_bot_db.get_users_for_unblock_today.return_value = {} scheduler.bot_db = mock_bot_db scheduler.set_bot(mock_bot) - + await scheduler.run_manual_unban() - + mock_bot_db.get_users_for_unblock_today.assert_called_once() class TestGetAutoUnbanScheduler: """Тесты для функции get_auto_unban_scheduler""" - + def test_get_auto_unban_scheduler(self): """Тест получения глобального экземпляра планировщика""" scheduler = get_auto_unban_scheduler() assert isinstance(scheduler, AutoUnbanScheduler) - + # Проверяем, что возвращается один и тот же экземпляр scheduler2 = get_auto_unban_scheduler() assert scheduler is scheduler2 @@ -235,17 +252,17 @@ class TestGetAutoUnbanScheduler: class TestDateHandling: """Тесты для обработки дат""" - + def test_moscow_timezone(self): """Тест работы с московским временем""" scheduler = AutoUnbanScheduler() - + # Проверяем, что дата формируется в правильном формате moscow_tz = timezone(timedelta(hours=3)) today = datetime.now(moscow_tz).strftime("%Y-%m-%d") - + assert len(today) == 10 # YYYY-MM-DD - assert today.count('-') == 2 + assert today.count("-") == 2 assert today[:4].isdigit() # Год assert today[5:7].isdigit() # Месяц assert today[8:10].isdigit() # День @@ -254,35 +271,37 @@ class TestDateHandling: @pytest.mark.asyncio class TestAsyncOperations: """Тесты асинхронных операций""" - - @patch('helper_bot.utils.auto_unban_scheduler.get_global_instance') + + @patch("helper_bot.utils.auto_unban_scheduler.get_global_instance") async def test_async_auto_unban_flow(self, mock_get_instance): """Тест полного асинхронного потока разбана""" # Создаем моки mock_bdf = Mock() mock_bdf.settings = { - 'Telegram': { - 'group_for_logs': '-1001234567890', - 'important_logs': '-1001234567891' + "Telegram": { + "group_for_logs": "-1001234567890", + "important_logs": "-1001234567891", } } mock_get_instance.return_value = mock_bdf - + mock_bot_db = Mock() - mock_bot_db.get_users_for_unblock_today = AsyncMock(return_value={123: "test_user"}) + mock_bot_db.get_users_for_unblock_today = AsyncMock( + return_value={123: "test_user"} + ) mock_bot_db.delete_user_blacklist = AsyncMock(return_value=True) - + mock_bot = Mock() mock_bot.send_message = AsyncMock() - + # Создаем планировщик scheduler = AutoUnbanScheduler() scheduler.bot_db = mock_bot_db scheduler.set_bot(mock_bot) - + # Выполняем разбан await scheduler.auto_unban_users() - + # Проверяем результаты mock_bot_db.get_users_for_unblock_today.assert_called_once() mock_bot_db.delete_user_blacklist.assert_called_once_with(123) diff --git a/tests/test_blacklist_history_repository.py b/tests/test_blacklist_history_repository.py index 9ca7cba..828a222 100644 --- a/tests/test_blacklist_history_repository.py +++ b/tests/test_blacklist_history_repository.py @@ -3,14 +3,16 @@ from datetime import datetime from unittest.mock import AsyncMock, Mock, patch import pytest + from database.models import BlacklistHistoryRecord -from database.repositories.blacklist_history_repository import \ - BlacklistHistoryRepository +from database.repositories.blacklist_history_repository import ( + BlacklistHistoryRepository, +) class TestBlacklistHistoryRepository: """Тесты для BlacklistHistoryRepository""" - + @pytest.fixture def mock_db_connection(self): """Мок для DatabaseConnection""" @@ -19,18 +21,20 @@ class TestBlacklistHistoryRepository: mock_connection._execute_query_with_result = AsyncMock() mock_connection.logger = Mock() return mock_connection - + @pytest.fixture def blacklist_history_repository(self, mock_db_connection): """Экземпляр BlacklistHistoryRepository для тестов""" # Патчим наследование от DatabaseConnection - with patch.object(BlacklistHistoryRepository, '__init__', return_value=None): + with patch.object(BlacklistHistoryRepository, "__init__", return_value=None): repo = BlacklistHistoryRepository() repo._execute_query = mock_db_connection._execute_query - repo._execute_query_with_result = mock_db_connection._execute_query_with_result + repo._execute_query_with_result = ( + mock_db_connection._execute_query_with_result + ) repo.logger = mock_db_connection.logger return repo - + @pytest.fixture def sample_history_record(self): """Тестовая запись истории бана""" @@ -44,7 +48,7 @@ class TestBlacklistHistoryRepository: created_at=current_time, updated_at=current_time, ) - + @pytest.fixture def sample_history_record_with_unban(self): """Тестовая запись истории бана с датой разбана""" @@ -58,56 +62,75 @@ class TestBlacklistHistoryRepository: created_at=current_time - 86400, updated_at=current_time, ) - + @pytest.mark.asyncio async def test_create_tables(self, blacklist_history_repository): """Тест создания таблицы истории банов/разбанов""" await blacklist_history_repository.create_tables() - + # Проверяем, что метод вызван (4 раза: таблица + 3 индекса) assert blacklist_history_repository._execute_query.call_count == 4 calls = blacklist_history_repository._execute_query.call_args_list - + # Проверяем, что создается таблица с правильной структурой create_table_call = calls[0] assert "CREATE TABLE IF NOT EXISTS blacklist_history" in create_table_call[0][0] - assert "id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT" in create_table_call[0][0] + assert ( + "id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT" in create_table_call[0][0] + ) assert "user_id INTEGER NOT NULL" in create_table_call[0][0] assert "message_for_user TEXT" in create_table_call[0][0] assert "date_ban INTEGER NOT NULL" in create_table_call[0][0] assert "date_unban INTEGER" in create_table_call[0][0] assert "ban_author INTEGER" in create_table_call[0][0] - assert "created_at INTEGER DEFAULT (strftime('%s', 'now'))" in create_table_call[0][0] - assert "updated_at INTEGER DEFAULT (strftime('%s', 'now'))" in create_table_call[0][0] - assert "FOREIGN KEY (user_id) REFERENCES our_users(user_id) ON DELETE CASCADE" in create_table_call[0][0] - assert "FOREIGN KEY (ban_author) REFERENCES our_users(user_id) ON DELETE SET NULL" in create_table_call[0][0] - + assert ( + "created_at INTEGER DEFAULT (strftime('%s', 'now'))" + in create_table_call[0][0] + ) + assert ( + "updated_at INTEGER DEFAULT (strftime('%s', 'now'))" + in create_table_call[0][0] + ) + assert ( + "FOREIGN KEY (user_id) REFERENCES our_users(user_id) ON DELETE CASCADE" + in create_table_call[0][0] + ) + assert ( + "FOREIGN KEY (ban_author) REFERENCES our_users(user_id) ON DELETE SET NULL" + in create_table_call[0][0] + ) + # Проверяем создание индексов index_calls = calls[1:4] index_names = [call[0][0] for call in index_calls] assert any("idx_blacklist_history_user_id" in idx for idx in index_names) assert any("idx_blacklist_history_date_ban" in idx for idx in index_names) assert any("idx_blacklist_history_date_unban" in idx for idx in index_names) - + # Проверяем логирование blacklist_history_repository.logger.info.assert_called_once_with( "Таблица истории банов/разбанов создана" ) - + @pytest.mark.asyncio - async def test_add_record_on_ban(self, blacklist_history_repository, sample_history_record): + async def test_add_record_on_ban( + self, blacklist_history_repository, sample_history_record + ): """Тест добавления записи о бане в историю""" await blacklist_history_repository.add_record_on_ban(sample_history_record) - + # Проверяем, что метод вызван с правильными параметрами blacklist_history_repository._execute_query.assert_called_once() call_args = blacklist_history_repository._execute_query.call_args - + # Проверяем SQL запрос - sql_query = call_args[0][0].replace('\n', ' ').replace(' ', ' ').strip() + sql_query = call_args[0][0].replace("\n", " ").replace(" ", " ").strip() assert "INSERT INTO blacklist_history" in sql_query - assert "user_id, message_for_user, date_ban, date_unban, ban_author, created_at, updated_at" in sql_query - + assert ( + "user_id, message_for_user, date_ban, date_unban, ban_author, created_at, updated_at" + in sql_query + ) + # Проверяем параметры params = call_args[0][1] assert params[0] == 12345 # user_id @@ -117,13 +140,13 @@ class TestBlacklistHistoryRepository: assert params[4] == 999 # ban_author assert params[5] == sample_history_record.created_at # created_at assert params[6] == sample_history_record.updated_at # updated_at - + # Проверяем логирование blacklist_history_repository.logger.info.assert_called_once() log_call = blacklist_history_repository.logger.info.call_args[0][0] assert "Запись о бане добавлена в историю" in log_call assert "user_id=12345" in log_call - + @pytest.mark.asyncio async def test_add_record_on_ban_with_defaults(self, blacklist_history_repository): """Тест добавления записи о бане с дефолтными значениями created_at и updated_at""" @@ -136,122 +159,130 @@ class TestBlacklistHistoryRepository: created_at=None, # Будет установлено автоматически updated_at=None, # Будет установлено автоматически ) - + await blacklist_history_repository.add_record_on_ban(record) - + # Проверяем, что метод вызван blacklist_history_repository._execute_query.assert_called_once() call_args = blacklist_history_repository._execute_query.call_args - + # Проверяем, что created_at и updated_at установлены (не None) params = call_args[0][1] assert params[5] is not None # created_at assert params[6] is not None # updated_at assert isinstance(params[5], int) assert isinstance(params[6], int) - + @pytest.mark.asyncio async def test_set_unban_date_success(self, blacklist_history_repository): """Тест успешного обновления даты разбана""" user_id = 12345 date_unban = int(time.time()) - + # Мокируем результат проверки - находим открытую запись - blacklist_history_repository._execute_query_with_result.return_value = [(100,)] # id записи - + blacklist_history_repository._execute_query_with_result.return_value = [ + (100,) + ] # id записи + result = await blacklist_history_repository.set_unban_date(user_id, date_unban) - + # Проверяем, что сначала проверяется наличие записи assert blacklist_history_repository._execute_query_with_result.call_count == 1 check_call = blacklist_history_repository._execute_query_with_result.call_args assert "SELECT id FROM blacklist_history" in check_call[0][0] assert check_call[0][1] == (user_id,) - + # Проверяем, что затем обновляется запись assert blacklist_history_repository._execute_query.call_count == 1 update_call = blacklist_history_repository._execute_query.call_args - update_query = update_call[0][0].replace('\n', ' ').replace(' ', ' ').strip() + update_query = ( + update_call[0][0].replace("\n", " ").replace(" ", " ").strip() + ) assert "UPDATE blacklist_history" in update_query assert "SET date_unban = ?" in update_query assert "updated_at = ?" in update_query - + # Проверяем параметры обновления update_params = update_call[0][1] assert update_params[0] == date_unban assert update_params[1] is not None # updated_at (текущее время) assert isinstance(update_params[1], int) assert update_params[2] == 100 # id записи - + # Проверяем результат assert result is True - + # Проверяем логирование blacklist_history_repository.logger.info.assert_called_once() log_call = blacklist_history_repository.logger.info.call_args[0][0] assert "Дата разбана обновлена в истории" in log_call assert f"user_id={user_id}" in log_call - + @pytest.mark.asyncio async def test_set_unban_date_no_open_record(self, blacklist_history_repository): """Тест обновления даты разбана когда нет открытой записи""" user_id = 12345 date_unban = int(time.time()) - + # Мокируем результат проверки - нет открытых записей blacklist_history_repository._execute_query_with_result.return_value = [] - + result = await blacklist_history_repository.set_unban_date(user_id, date_unban) - + # Проверяем, что проверка была выполнена assert blacklist_history_repository._execute_query_with_result.call_count == 1 - + # Проверяем, что UPDATE не был вызван (нет записей для обновления) blacklist_history_repository._execute_query.assert_not_called() - + # Проверяем результат assert result is False - + # Проверяем логирование предупреждения blacklist_history_repository.logger.warning.assert_called_once() log_call = blacklist_history_repository.logger.warning.call_args[0][0] assert "Не найдена открытая запись в истории для обновления" in log_call assert f"user_id={user_id}" in log_call - + @pytest.mark.asyncio async def test_set_unban_date_exception(self, blacklist_history_repository): """Тест обработки исключения при обновлении даты разбана""" user_id = 12345 date_unban = int(time.time()) - + # Мокируем исключение при проверке - blacklist_history_repository._execute_query_with_result.side_effect = Exception("Database error") - + blacklist_history_repository._execute_query_with_result.side_effect = Exception( + "Database error" + ) + result = await blacklist_history_repository.set_unban_date(user_id, date_unban) - + # Проверяем, что метод вернул False при ошибке assert result is False - + # Проверяем логирование ошибки blacklist_history_repository.logger.error.assert_called_once() log_call = blacklist_history_repository.logger.error.call_args[0][0] assert "Ошибка обновления даты разбана в истории" in log_call assert f"user_id={user_id}" in log_call - + @pytest.mark.asyncio async def test_set_unban_date_update_exception(self, blacklist_history_repository): """Тест обработки исключения при обновлении записи""" user_id = 12345 date_unban = int(time.time()) - + # Мокируем успешную проверку, но ошибку при обновлении blacklist_history_repository._execute_query_with_result.return_value = [(100,)] - blacklist_history_repository._execute_query.side_effect = Exception("Update error") - + blacklist_history_repository._execute_query.side_effect = Exception( + "Update error" + ) + result = await blacklist_history_repository.set_unban_date(user_id, date_unban) - + # Проверяем, что метод вернул False при ошибке assert result is False - + # Проверяем логирование ошибки blacklist_history_repository.logger.error.assert_called_once() log_call = blacklist_history_repository.logger.error.call_args[0][0] diff --git a/tests/test_blacklist_repository.py b/tests/test_blacklist_repository.py index f1cbf88..97caf4f 100644 --- a/tests/test_blacklist_repository.py +++ b/tests/test_blacklist_repository.py @@ -3,13 +3,14 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest + from database.models import BlacklistUser from database.repositories.blacklist_repository import BlacklistRepository class TestBlacklistRepository: """Тесты для BlacklistRepository""" - + @pytest.fixture def mock_db_connection(self): """Мок для DatabaseConnection""" @@ -18,18 +19,20 @@ class TestBlacklistRepository: mock_connection._execute_query_with_result = AsyncMock() mock_connection.logger = Mock() return mock_connection - + @pytest.fixture def blacklist_repository(self, mock_db_connection): """Экземпляр BlacklistRepository для тестов""" # Патчим наследование от DatabaseConnection - with patch.object(BlacklistRepository, '__init__', return_value=None): + with patch.object(BlacklistRepository, "__init__", return_value=None): repo = BlacklistRepository() repo._execute_query = mock_db_connection._execute_query - repo._execute_query_with_result = mock_db_connection._execute_query_with_result + repo._execute_query_with_result = ( + mock_db_connection._execute_query_with_result + ) repo.logger = mock_db_connection.logger return repo - + @pytest.fixture def sample_blacklist_user(self): """Тестовый пользователь в черном списке""" @@ -40,7 +43,7 @@ class TestBlacklistRepository: created_at=int(time.time()), ban_author=999, ) - + @pytest.fixture def sample_blacklist_user_permanent(self): """Тестовый пользователь с постоянным баном""" @@ -51,144 +54,171 @@ class TestBlacklistRepository: created_at=int(time.time()), ban_author=None, ) - + @pytest.mark.asyncio async def test_create_tables(self, blacklist_repository): """Тест создания таблицы черного списка""" await blacklist_repository.create_tables() - + # Проверяем, что метод вызван blacklist_repository._execute_query.assert_called() calls = blacklist_repository._execute_query.call_args_list - + # Проверяем, что создается таблица с правильной структурой create_table_call = calls[0] assert "CREATE TABLE IF NOT EXISTS blacklist" in create_table_call[0][0] assert "user_id INTEGER NOT NULL PRIMARY KEY" in create_table_call[0][0] assert "message_for_user TEXT" in create_table_call[0][0] assert "date_to_unban INTEGER" in create_table_call[0][0] - assert "created_at INTEGER DEFAULT (strftime('%s', 'now'))" in create_table_call[0][0] - assert "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" in create_table_call[0][0] - + assert ( + "created_at INTEGER DEFAULT (strftime('%s', 'now'))" + in create_table_call[0][0] + ) + assert ( + "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" + in create_table_call[0][0] + ) + # Проверяем логирование - blacklist_repository.logger.info.assert_called_once_with("Таблица черного списка создана") - + blacklist_repository.logger.info.assert_called_once_with( + "Таблица черного списка создана" + ) + @pytest.mark.asyncio async def test_add_user(self, blacklist_repository, sample_blacklist_user): """Тест добавления пользователя в черный список""" await blacklist_repository.add_user(sample_blacklist_user) - + # Проверяем, что метод вызван с правильными параметрами blacklist_repository._execute_query.assert_called_once() call_args = blacklist_repository._execute_query.call_args - + # Проверяем SQL запрос (учитываем форматирование) - sql_query = call_args[0][0].replace('\n', ' ').replace(' ', ' ').replace(' ', ' ').strip() + sql_query = ( + call_args[0][0] + .replace("\n", " ") + .replace(" ", " ") + .replace(" ", " ") + .strip() + ) expected_sql = "INSERT INTO blacklist (user_id, message_for_user, date_to_unban, ban_author) VALUES (?, ?, ?, ?)" assert sql_query == expected_sql - + # Проверяем параметры - assert call_args[0][1] == (12345, "Нарушение правил", sample_blacklist_user.date_to_unban, 999) - + assert call_args[0][1] == ( + 12345, + "Нарушение правил", + sample_blacklist_user.date_to_unban, + 999, + ) + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Пользователь добавлен в черный список: user_id=12345" ) - + @pytest.mark.asyncio - async def test_add_user_permanent_ban(self, blacklist_repository, sample_blacklist_user_permanent): + async def test_add_user_permanent_ban( + self, blacklist_repository, sample_blacklist_user_permanent + ): """Тест добавления пользователя с постоянным баном""" await blacklist_repository.add_user(sample_blacklist_user_permanent) - + call_args = blacklist_repository._execute_query.call_args assert call_args[0][1] == (67890, "Постоянный бан", None, None) - + blacklist_repository.logger.info.assert_called_once_with( "Пользователь добавлен в черный список: user_id=67890" ) - + @pytest.mark.asyncio async def test_remove_user_success(self, blacklist_repository): """Тест успешного удаления пользователя из черного списка""" await blacklist_repository.remove_user(12345) - + # Проверяем, что метод вызван с правильными параметрами blacklist_repository._execute_query.assert_called_once() call_args = blacklist_repository._execute_query.call_args - + assert call_args[0][0] == "DELETE FROM blacklist WHERE user_id = ?" assert call_args[0][1] == (12345,) - + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Пользователь с идентификатором 12345 успешно удален из черного списка." ) - + @pytest.mark.asyncio async def test_remove_user_failure(self, blacklist_repository): """Тест неудачного удаления пользователя из черного списка""" # Симулируем ошибку при удалении blacklist_repository._execute_query.side_effect = Exception("Database error") - + result = await blacklist_repository.remove_user(12345) - + # Проверяем, что возвращается False при ошибке assert result is False - + # Проверяем логирование ошибки blacklist_repository.logger.error.assert_called_once() error_log = blacklist_repository.logger.error.call_args[0][0] assert "Ошибка удаления пользователя с идентификатором 12345" in error_log assert "Database error" in error_log - + @pytest.mark.asyncio async def test_user_exists_true(self, blacklist_repository): """Тест проверки существования пользователя (пользователь существует)""" # Симулируем результат запроса - пользователь найден blacklist_repository._execute_query_with_result.return_value = [(1,)] - + result = await blacklist_repository.user_exists(12345) - + # Проверяем, что возвращается True assert result is True - + # Проверяем, что метод вызван с правильными параметрами blacklist_repository._execute_query_with_result.assert_called_once() call_args = blacklist_repository._execute_query_with_result.call_args - + assert call_args[0][0] == "SELECT 1 FROM blacklist WHERE user_id = ?" assert call_args[0][1] == (12345,) - + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Существует ли пользователь: user_id=12345 Итог: [(1,)]" ) - + @pytest.mark.asyncio async def test_user_exists_false(self, blacklist_repository): """Тест проверки существования пользователя (пользователь не существует)""" # Симулируем результат запроса - пользователь не найден blacklist_repository._execute_query_with_result.return_value = [] - + result = await blacklist_repository.user_exists(12345) - + # Проверяем, что возвращается False assert result is False - + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Существует ли пользователь: user_id=12345 Итог: []" ) - + @pytest.mark.asyncio async def test_get_user_success(self, blacklist_repository): """Тест успешного получения пользователя по ID""" # Симулируем результат запроса - mock_row = (12345, "Нарушение правил", int(time.time()) + 86400, int(time.time()), 111) + mock_row = ( + 12345, + "Нарушение правил", + int(time.time()) + 86400, + int(time.time()), + 111, + ) blacklist_repository._execute_query_with_result.return_value = [mock_row] - + result = await blacklist_repository.get_user(12345) - + # Проверяем, что возвращается правильный объект assert result is not None assert result.user_id == 12345 @@ -196,37 +226,40 @@ class TestBlacklistRepository: assert result.date_to_unban == mock_row[2] assert result.created_at == mock_row[3] assert result.ban_author == mock_row[4] - + # Проверяем, что метод вызван с правильными параметрами blacklist_repository._execute_query_with_result.assert_called_once() call_args = blacklist_repository._execute_query_with_result.call_args - - assert "SELECT user_id, message_for_user, date_to_unban, created_at, ban_author" in call_args[0][0] + + assert ( + "SELECT user_id, message_for_user, date_to_unban, created_at, ban_author" + in call_args[0][0] + ) assert call_args[0][1] == (12345,) - + @pytest.mark.asyncio async def test_get_user_not_found(self, blacklist_repository): """Тест получения пользователя по ID (пользователь не найден)""" # Симулируем результат запроса - пользователь не найден blacklist_repository._execute_query_with_result.return_value = [] - + result = await blacklist_repository.get_user(12345) - + # Проверяем, что возвращается None assert result is None - + @pytest.mark.asyncio async def test_get_all_users_with_limits(self, blacklist_repository): """Тест получения пользователей с лимитами""" # Симулируем результат запроса mock_rows = [ (12345, "Нарушение правил", int(time.time()) + 86400, int(time.time())), - (67890, "Постоянный бан", None, int(time.time()) - 86400) + (67890, "Постоянный бан", None, int(time.time()) - 86400), ] blacklist_repository._execute_query_with_result.return_value = mock_rows - + result = await blacklist_repository.get_all_users(offset=0, limit=10) - + # Проверяем, что возвращается правильный список assert len(result) == 2 assert result[0].user_id == 12345 @@ -234,188 +267,211 @@ class TestBlacklistRepository: assert result[1].user_id == 67890 assert result[1].message_for_user == "Постоянный бан" assert result[1].date_to_unban is None - + # Проверяем, что метод вызван с правильными параметрами blacklist_repository._execute_query_with_result.assert_called_once() call_args = blacklist_repository._execute_query_with_result.call_args - + # Нормализуем SQL запрос (убираем лишние пробелы и переносы строк) - actual_query = ' '.join(call_args[0][0].split()) + actual_query = " ".join(call_args[0][0].split()) expected_query = "SELECT user_id, message_for_user, date_to_unban, created_at, ban_author FROM blacklist LIMIT ?, ?" assert actual_query == expected_query assert call_args[0][1] == (0, 10) - + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Получен список пользователей в черном списке (offset=0, limit=10): 2" ) - + @pytest.mark.asyncio async def test_get_all_users_no_limit(self, blacklist_repository): """Тест получения всех пользователей без лимитов""" # Симулируем результат запроса (теперь включает ban_author) mock_rows = [ - (12345, "Нарушение правил", int(time.time()) + 86400, int(time.time()), 999), - (67890, "Постоянный бан", None, int(time.time()) - 86400, None) + ( + 12345, + "Нарушение правил", + int(time.time()) + 86400, + int(time.time()), + 999, + ), + (67890, "Постоянный бан", None, int(time.time()) - 86400, None), ] blacklist_repository._execute_query_with_result.return_value = mock_rows - + result = await blacklist_repository.get_all_users_no_limit() - + # Проверяем, что возвращается правильный список assert len(result) == 2 - + # Проверяем, что метод вызван без лимитов blacklist_repository._execute_query_with_result.assert_called_once() call_args = blacklist_repository._execute_query_with_result.call_args - + # Нормализуем SQL запрос (убираем лишние пробелы и переносы строк) - actual_query = ' '.join(call_args[0][0].split()) + actual_query = " ".join(call_args[0][0].split()) expected_query = "SELECT user_id, message_for_user, date_to_unban, created_at, ban_author FROM blacklist" assert actual_query == expected_query # Проверяем, что параметры пустые (без лимитов) assert len(call_args[0]) == 1 # Только SQL запрос, без параметров - + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Получен список всех пользователей в черном списке: 2" ) - + @pytest.mark.asyncio async def test_get_users_for_unblock_today(self, blacklist_repository): """Тест получения пользователей для разблокировки сегодня""" current_timestamp = int(time.time()) - + # Симулируем результат запроса - пользователи с истекшим сроком mock_rows = [(12345,), (67890,)] blacklist_repository._execute_query_with_result.return_value = mock_rows - - result = await blacklist_repository.get_users_for_unblock_today(current_timestamp) - + + result = await blacklist_repository.get_users_for_unblock_today( + current_timestamp + ) + # Проверяем, что возвращается правильный словарь assert len(result) == 2 assert 12345 in result assert 67890 in result assert result[12345] == 12345 assert result[67890] == 67890 - + # Проверяем, что метод вызван с правильными параметрами blacklist_repository._execute_query_with_result.assert_called_once() call_args = blacklist_repository._execute_query_with_result.call_args - - assert call_args[0][0] == "SELECT user_id FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban <= ?" + + assert ( + call_args[0][0] + == "SELECT user_id FROM blacklist WHERE date_to_unban IS NOT NULL AND date_to_unban <= ?" + ) assert call_args[0][1] == (current_timestamp,) - + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Получен список пользователей для разблокировки: {12345: 12345, 67890: 67890}" ) - + @pytest.mark.asyncio async def test_get_users_for_unblock_today_empty(self, blacklist_repository): """Тест получения пользователей для разблокировки (пустой результат)""" current_timestamp = int(time.time()) - + # Симулируем пустой результат запроса blacklist_repository._execute_query_with_result.return_value = [] - - result = await blacklist_repository.get_users_for_unblock_today(current_timestamp) - + + result = await blacklist_repository.get_users_for_unblock_today( + current_timestamp + ) + # Проверяем, что возвращается пустой словарь assert result == {} - + # Проверяем логирование blacklist_repository.logger.info.assert_called_once_with( "Получен список пользователей для разблокировки: {}" ) - + @pytest.mark.asyncio async def test_get_count(self, blacklist_repository): """Тест получения количества пользователей в черном списке""" # Симулируем результат запроса blacklist_repository._execute_query_with_result.return_value = [(5,)] - + result = await blacklist_repository.get_count() - + # Проверяем, что возвращается правильное количество assert result == 5 - + # Проверяем, что метод вызван с правильными параметрами blacklist_repository._execute_query_with_result.assert_called_once() call_args = blacklist_repository._execute_query_with_result.call_args - + assert call_args[0][0] == "SELECT COUNT(*) FROM blacklist" # Проверяем, что параметры пустые assert len(call_args[0]) == 1 # Только SQL запрос, без параметров - + @pytest.mark.asyncio async def test_get_count_zero(self, blacklist_repository): """Тест получения количества пользователей (0 пользователей)""" # Симулируем пустой результат запроса blacklist_repository._execute_query_with_result.return_value = [] - + result = await blacklist_repository.get_count() - + # Проверяем, что возвращается 0 assert result == 0 - + @pytest.mark.asyncio async def test_get_count_none_result(self, blacklist_repository): """Тест получения количества пользователей (None результат)""" # Симулируем None результат запроса blacklist_repository._execute_query_with_result.return_value = None - + result = await blacklist_repository.get_count() - + # Проверяем, что возвращается 0 assert result == 0 - + @pytest.mark.asyncio async def test_error_handling_in_get_user(self, blacklist_repository): """Тест обработки ошибок при получении пользователя""" # Симулируем ошибку базы данных - blacklist_repository._execute_query_with_result.side_effect = Exception("Database connection failed") - + blacklist_repository._execute_query_with_result.side_effect = Exception( + "Database connection failed" + ) + # Проверяем, что исключение пробрасывается with pytest.raises(Exception) as exc_info: await blacklist_repository.get_user(12345) - + assert "Database connection failed" in str(exc_info.value) - + @pytest.mark.asyncio async def test_error_handling_in_get_all_users(self, blacklist_repository): """Тест обработки ошибок при получении всех пользователей""" # Симулируем ошибку базы данных - blacklist_repository._execute_query_with_result.side_effect = Exception("Database connection failed") - + blacklist_repository._execute_query_with_result.side_effect = Exception( + "Database connection failed" + ) + # Проверяем, что исключение пробрасывается with pytest.raises(Exception) as exc_info: await blacklist_repository.get_all_users() - + assert "Database connection failed" in str(exc_info.value) - + @pytest.mark.asyncio async def test_error_handling_in_get_count(self, blacklist_repository): """Тест обработки ошибок при получении количества""" # Симулируем ошибку базы данных - blacklist_repository._execute_query_with_result.side_effect = Exception("Database connection failed") - + blacklist_repository._execute_query_with_result.side_effect = Exception( + "Database connection failed" + ) + # Проверяем, что исключение пробрасывается with pytest.raises(Exception) as exc_info: await blacklist_repository.get_count() - + assert "Database connection failed" in str(exc_info.value) - + @pytest.mark.asyncio - async def test_error_handling_in_get_users_for_unblock_today(self, blacklist_repository): + async def test_error_handling_in_get_users_for_unblock_today( + self, blacklist_repository + ): """Тест обработки ошибок при получении пользователей для разблокировки""" # Симулируем ошибку базы данных - blacklist_repository._execute_query_with_result.side_effect = Exception("Database connection failed") - + blacklist_repository._execute_query_with_result.side_effect = Exception( + "Database connection failed" + ) + # Проверяем, что исключение пробрасывается with pytest.raises(Exception) as exc_info: await blacklist_repository.get_users_for_unblock_today(int(time.time())) - + assert "Database connection failed" in str(exc_info.value) # TODO: 20-й тест - test_integration_workflow @@ -426,7 +482,7 @@ class TestBlacklistRepository: # 4. Получение общего количества пользователей # 5. Удаление пользователя из черного списка # 6. Проверка, что пользователь больше не существует - # + # # Проблема: тест падает из-за сложности мокирования возвращаемых значений # при создании объектов BlacklistUser из результатов запросов к БД. # Требует более сложной настройки моков для корректной работы. diff --git a/tests/test_callback_handlers.py b/tests/test_callback_handlers.py index a26c108..b011e82 100644 --- a/tests/test_callback_handlers.py +++ b/tests/test_callback_handlers.py @@ -3,13 +3,16 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest + from helper_bot.handlers.callback.callback_handlers import ( change_page, delete_voice_message, process_ban_user, process_unlock_user, return_to_main_menu, + save_voice_message, +, ) from helper_bot.handlers.voice.constants import CALLBACK_DELETE, CALLBACK_SAVE @@ -27,6 +30,7 @@ def mock_call(): call.answer = AsyncMock() return call + @pytest.fixture def mock_bot_db(): """Мок для базы данных""" @@ -35,20 +39,20 @@ def mock_bot_db(): mock_db.delete_audio_moderate_record = AsyncMock() return mock_db + @pytest.fixture def mock_settings(): """Мок для настроек""" - return { - 'Telegram': { - 'group_for_posts': 'test_group_id' - } - } + return {"Telegram": {"group_for_posts": "test_group_id"}} + @pytest.fixture def mock_audio_service(): """Мок для AudioFileService""" mock_service = Mock() - mock_service.generate_file_name = AsyncMock(return_value="message_from_67890_number_1") + mock_service.generate_file_name = AsyncMock( + return_value="message_from_67890_number_1" + ) mock_service.save_audio_file = AsyncMock() mock_service.download_and_save_audio = AsyncMock() return mock_service @@ -56,143 +60,205 @@ def mock_audio_service(): class TestSaveVoiceMessage: """Тесты для функции save_voice_message""" - + @pytest.mark.asyncio - async def test_save_voice_message_success(self, mock_call, mock_bot_db, mock_settings, mock_audio_service): + async def test_save_voice_message_success( + self, mock_call, mock_bot_db, mock_settings, mock_audio_service + ): """Тест успешного сохранения голосового сообщения""" - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService') as mock_service_class: + with patch( + "helper_bot.handlers.callback.callback_handlers.AudioFileService" + ) as mock_service_class: mock_service_class.return_value = mock_audio_service - - await save_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + + await save_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем, что все методы вызваны - mock_bot_db.get_user_id_by_message_id_for_voice_bot.assert_called_once_with(12345) + mock_bot_db.get_user_id_by_message_id_for_voice_bot.assert_called_once_with( + 12345 + ) mock_audio_service.generate_file_name.assert_called_once_with(67890) mock_audio_service.save_audio_file.assert_called_once() mock_audio_service.download_and_save_audio.assert_called_once_with( mock_call.bot, mock_call.message, "message_from_67890_number_1" ) - + # Проверяем удаление сообщения из чата mock_call.bot.delete_message.assert_called_once_with( - chat_id='test_group_id', - message_id=12345 + chat_id="test_group_id", message_id=12345 ) - + # Проверяем удаление записи из audio_moderate mock_bot_db.delete_audio_moderate_record.assert_called_once_with(12345) - + # Проверяем ответ пользователю - mock_call.answer.assert_called_once_with(text='Сохранено!', cache_time=3) - + mock_call.answer.assert_called_once_with(text="Сохранено!", cache_time=3) + @pytest.mark.asyncio - async def test_save_voice_message_with_correct_parameters(self, mock_call, mock_bot_db, mock_settings, mock_audio_service): + async def test_save_voice_message_with_correct_parameters( + self, mock_call, mock_bot_db, mock_settings, mock_audio_service + ): """Тест сохранения с правильными параметрами""" - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService') as mock_service_class: + with patch( + "helper_bot.handlers.callback.callback_handlers.AudioFileService" + ) as mock_service_class: mock_service_class.return_value = mock_audio_service - - await save_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + + await save_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем параметры save_audio_file save_call_args = mock_audio_service.save_audio_file.call_args assert save_call_args[0][0] == "message_from_67890_number_1" # file_name assert save_call_args[0][1] == 67890 # user_id assert isinstance(save_call_args[0][2], datetime) # date_added assert save_call_args[0][3] == "test_file_id_123" # file_id - + @pytest.mark.asyncio - async def test_save_voice_message_exception_handling(self, mock_call, mock_bot_db, mock_settings): + async def test_save_voice_message_exception_handling( + self, mock_call, mock_bot_db, mock_settings + ): """Тест обработки исключений при сохранении""" - mock_bot_db.get_user_id_by_message_id_for_voice_bot.side_effect = Exception("Database error") - + mock_bot_db.get_user_id_by_message_id_for_voice_bot.side_effect = Exception( + "Database error" + ) + await save_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + # Проверяем, что при ошибке отправляется соответствующий ответ - mock_call.answer.assert_called_once_with(text='Ошибка при сохранении!', cache_time=3) - + mock_call.answer.assert_called_once_with( + text="Ошибка при сохранении!", cache_time=3 + ) + @pytest.mark.asyncio - async def test_save_voice_message_audio_service_exception(self, mock_call, mock_bot_db, mock_settings, mock_audio_service): + async def test_save_voice_message_audio_service_exception( + self, mock_call, mock_bot_db, mock_settings, mock_audio_service + ): """Тест обработки исключений в AudioFileService""" mock_audio_service.save_audio_file.side_effect = Exception("Save error") - - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService') as mock_service_class: + + with patch( + "helper_bot.handlers.callback.callback_handlers.AudioFileService" + ) as mock_service_class: mock_service_class.return_value = mock_audio_service - - await save_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + + await save_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем, что при ошибке отправляется соответствующий ответ - mock_call.answer.assert_called_once_with(text='Ошибка при сохранении!', cache_time=3) - + mock_call.answer.assert_called_once_with( + text="Ошибка при сохранении!", cache_time=3 + ) + @pytest.mark.asyncio - async def test_save_voice_message_download_exception(self, mock_call, mock_bot_db, mock_settings, mock_audio_service): + async def test_save_voice_message_download_exception( + self, mock_call, mock_bot_db, mock_settings, mock_audio_service + ): """Тест обработки исключений при скачивании файла""" - mock_audio_service.download_and_save_audio.side_effect = Exception("Download error") - - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService') as mock_service_class: + mock_audio_service.download_and_save_audio.side_effect = Exception( + "Download error" + ) + + with patch( + "helper_bot.handlers.callback.callback_handlers.AudioFileService" + ) as mock_service_class: mock_service_class.return_value = mock_audio_service - - await save_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + + await save_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем, что при ошибке отправляется соответствующий ответ - mock_call.answer.assert_called_once_with(text='Ошибка при сохранении!', cache_time=3) + mock_call.answer.assert_called_once_with( + text="Ошибка при сохранении!", cache_time=3 + ) class TestDeleteVoiceMessage: """Тесты для функции delete_voice_message""" - + @pytest.mark.asyncio - async def test_delete_voice_message_success(self, mock_call, mock_bot_db, mock_settings): + async def test_delete_voice_message_success( + self, mock_call, mock_bot_db, mock_settings + ): """Тест успешного удаления голосового сообщения""" - await delete_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + await delete_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем удаление сообщения из чата mock_call.bot.delete_message.assert_called_once_with( - chat_id='test_group_id', - message_id=12345 + chat_id="test_group_id", message_id=12345 ) - + # Проверяем удаление записи из audio_moderate mock_bot_db.delete_audio_moderate_record.assert_called_once_with(12345) - + # Проверяем ответ пользователю - mock_call.answer.assert_called_once_with(text='Удалено!', cache_time=3) - + mock_call.answer.assert_called_once_with(text="Удалено!", cache_time=3) + @pytest.mark.asyncio - async def test_delete_voice_message_exception_handling(self, mock_call, mock_bot_db, mock_settings): + async def test_delete_voice_message_exception_handling( + self, mock_call, mock_bot_db, mock_settings + ): """Тест обработки исключений при удалении""" mock_call.bot.delete_message.side_effect = Exception("Delete error") - - await delete_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + + await delete_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем, что при ошибке отправляется соответствующий ответ - mock_call.answer.assert_called_once_with(text='Ошибка при удалении!', cache_time=3) - + mock_call.answer.assert_called_once_with( + text="Ошибка при удалении!", cache_time=3 + ) + @pytest.mark.asyncio - async def test_delete_voice_message_database_exception(self, mock_call, mock_bot_db, mock_settings): + async def test_delete_voice_message_database_exception( + self, mock_call, mock_bot_db, mock_settings + ): """Тест обработки исключений в базе данных при удалении""" - mock_bot_db.delete_audio_moderate_record.side_effect = Exception("Database error") - - await delete_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + mock_bot_db.delete_audio_moderate_record.side_effect = Exception( + "Database error" + ) + + await delete_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем, что при ошибке отправляется соответствующий ответ - mock_call.answer.assert_called_once_with(text='Ошибка при удалении!', cache_time=3) + mock_call.answer.assert_called_once_with( + text="Ошибка при удалении!", cache_time=3 + ) class TestCallbackHandlersIntegration: """Интеграционные тесты для callback handlers""" - + @pytest.mark.asyncio - async def test_save_voice_message_full_workflow(self, mock_call, mock_bot_db, mock_settings): + async def test_save_voice_message_full_workflow( + self, mock_call, mock_bot_db, mock_settings + ): """Тест полного рабочего процесса сохранения""" - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService') as mock_service_class: + with patch( + "helper_bot.handlers.callback.callback_handlers.AudioFileService" + ) as mock_service_class: mock_service = Mock() - mock_service.generate_file_name = AsyncMock(return_value="message_from_67890_number_1") + mock_service.generate_file_name = AsyncMock( + return_value="message_from_67890_number_1" + ) mock_service.save_audio_file = AsyncMock() mock_service.download_and_save_audio = AsyncMock() mock_service_class.return_value = mock_service - - await save_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + + await save_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем последовательность вызовов assert mock_bot_db.get_user_id_by_message_id_for_voice_bot.called assert mock_service.generate_file_name.called @@ -201,46 +267,62 @@ class TestCallbackHandlersIntegration: assert mock_call.bot.delete_message.called assert mock_bot_db.delete_audio_moderate_record.called assert mock_call.answer.called - + @pytest.mark.asyncio - async def test_delete_voice_message_full_workflow(self, mock_call, mock_bot_db, mock_settings): + async def test_delete_voice_message_full_workflow( + self, mock_call, mock_bot_db, mock_settings + ): """Тест полного рабочего процесса удаления""" - await delete_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + await delete_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Проверяем последовательность вызовов assert mock_call.bot.delete_message.called assert mock_bot_db.delete_audio_moderate_record.called assert mock_call.answer.called - + @pytest.mark.asyncio - async def test_audio_moderate_cleanup_consistency(self, mock_call, mock_bot_db, mock_settings): + async def test_audio_moderate_cleanup_consistency( + self, mock_call, mock_bot_db, mock_settings + ): """Тест консистентности очистки audio_moderate""" - # Тестируем, что в обоих случаях (сохранение и удаление) + # Тестируем, что в обоих случаях (сохранение и удаление) # вызывается delete_audio_moderate_record - + # Создаем отдельные моки для каждого теста mock_bot_db_save = Mock() - mock_bot_db_save.get_user_id_by_message_id_for_voice_bot = AsyncMock(return_value=67890) + mock_bot_db_save.get_user_id_by_message_id_for_voice_bot = AsyncMock( + return_value=67890 + ) mock_bot_db_save.delete_audio_moderate_record = AsyncMock() - + mock_bot_db_delete = Mock() mock_bot_db_delete.delete_audio_moderate_record = AsyncMock() - + # Тест для сохранения - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService') as mock_service_class: + with patch( + "helper_bot.handlers.callback.callback_handlers.AudioFileService" + ) as mock_service_class: mock_service = Mock() - mock_service.generate_file_name = AsyncMock(return_value="message_from_67890_number_1") + mock_service.generate_file_name = AsyncMock( + return_value="message_from_67890_number_1" + ) mock_service.save_audio_file = AsyncMock() mock_service.download_and_save_audio = AsyncMock() mock_service_class.return_value = mock_service - - await save_voice_message(mock_call, bot_db=mock_bot_db_save, settings=mock_settings) + + await save_voice_message( + mock_call, bot_db=mock_bot_db_save, settings=mock_settings + ) save_calls = mock_bot_db_save.delete_audio_moderate_record.call_count - + # Тест для удаления - await delete_voice_message(mock_call, bot_db=mock_bot_db_delete, settings=mock_settings) + await delete_voice_message( + mock_call, bot_db=mock_bot_db_delete, settings=mock_settings + ) delete_calls = mock_bot_db_delete.delete_audio_moderate_record.call_count - + # Проверяем, что в обоих случаях вызывается очистка assert save_calls == 1 assert delete_calls == 1 @@ -248,9 +330,11 @@ class TestCallbackHandlersIntegration: class TestCallbackHandlersEdgeCases: """Тесты граничных случаев для callback handlers""" - + @pytest.mark.asyncio - async def test_save_voice_message_no_voice_attribute(self, mock_bot_db, mock_settings): + async def test_save_voice_message_no_voice_attribute( + self, mock_bot_db, mock_settings + ): """Тест сохранения когда у сообщения нет voice атрибута""" call = Mock() call.message = Mock() @@ -259,26 +343,36 @@ class TestCallbackHandlersEdgeCases: call.bot = Mock() call.bot.delete_message = AsyncMock() call.answer = AsyncMock() - - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService'): + + with patch("helper_bot.handlers.callback.callback_handlers.AudioFileService"): await save_voice_message(call, bot_db=mock_bot_db, settings=mock_settings) - + # Должна быть ошибка - call.answer.assert_called_once_with(text='Ошибка при сохранении!', cache_time=3) - + call.answer.assert_called_once_with( + text="Ошибка при сохранении!", cache_time=3 + ) + @pytest.mark.asyncio - async def test_save_voice_message_user_not_found(self, mock_call, mock_bot_db, mock_settings): + async def test_save_voice_message_user_not_found( + self, mock_call, mock_bot_db, mock_settings + ): """Тест сохранения когда пользователь не найден""" mock_bot_db.get_user_id_by_message_id_for_voice_bot.return_value = None - - with patch('helper_bot.handlers.callback.callback_handlers.AudioFileService'): - await save_voice_message(mock_call, bot_db=mock_bot_db, settings=mock_settings) - + + with patch("helper_bot.handlers.callback.callback_handlers.AudioFileService"): + await save_voice_message( + mock_call, bot_db=mock_bot_db, settings=mock_settings + ) + # Должна быть ошибка - mock_call.answer.assert_called_once_with(text='Ошибка при сохранении!', cache_time=3) - + mock_call.answer.assert_called_once_with( + text="Ошибка при сохранении!", cache_time=3 + ) + @pytest.mark.asyncio - async def test_delete_voice_message_with_different_message_id(self, mock_bot_db, mock_settings): + async def test_delete_voice_message_with_different_message_id( + self, mock_bot_db, mock_settings + ): """Тест удаления с другим message_id""" call = Mock() call.message = Mock() @@ -286,13 +380,12 @@ class TestCallbackHandlersEdgeCases: call.bot = Mock() call.bot.delete_message = AsyncMock() call.answer = AsyncMock() - + await delete_voice_message(call, bot_db=mock_bot_db, settings=mock_settings) - + # Проверяем, что используется правильный message_id call.bot.delete_message.assert_called_once_with( - chat_id='test_group_id', - message_id=99999 + chat_id="test_group_id", message_id=99999 ) mock_bot_db.delete_audio_moderate_record.assert_called_once_with(99999) @@ -536,5 +629,5 @@ class TestProcessUnlockUser: ) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_improved_media_processing.py b/tests/test_improved_media_processing.py index d39fac0..05d78e7 100644 --- a/tests/test_improved_media_processing.py +++ b/tests/test_improved_media_processing.py @@ -8,67 +8,81 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from aiogram import types + from helper_bot.utils.helper_func import ( - add_in_db_media, add_in_db_media_mediagroup, download_file, - send_media_group_message_to_private_chat) + add_in_db_media, + add_in_db_media_mediagroup, + download_file, + send_media_group_message_to_private_chat, +) class TestDownloadFile: """Тесты для функции download_file""" - + @pytest.mark.asyncio async def test_download_file_success_photo(self): """Тест успешного скачивания фото""" # Создаем временную директорию with tempfile.TemporaryDirectory() as temp_dir: - with patch('helper_bot.utils.helper_func.os.makedirs'), \ - patch('helper_bot.utils.helper_func.os.path.exists', return_value=True), \ - patch('helper_bot.utils.helper_func.os.path.getsize', return_value=1024), \ - patch('helper_bot.utils.helper_func.os.path.basename', return_value='photo.jpg'), \ - patch('helper_bot.utils.helper_func.os.path.splitext', return_value=('photo', '.jpg')): - + with ( + patch("helper_bot.utils.helper_func.os.makedirs"), + patch("helper_bot.utils.helper_func.os.path.exists", return_value=True), + patch( + "helper_bot.utils.helper_func.os.path.getsize", return_value=1024 + ), + patch( + "helper_bot.utils.helper_func.os.path.basename", + return_value="photo.jpg", + ), + patch( + "helper_bot.utils.helper_func.os.path.splitext", + return_value=("photo", ".jpg"), + ), + ): + # Мокаем сообщение и бота mock_message = Mock() mock_message.bot = Mock() mock_file = Mock() - mock_file.file_path = 'photos/photo.jpg' + mock_file.file_path = "photos/photo.jpg" mock_message.bot.get_file = AsyncMock(return_value=mock_file) mock_message.bot.download_file = AsyncMock() - + # Вызываем функцию - result = await download_file(mock_message, 'test_file_id', 'photo') - + result = await download_file(mock_message, "test_file_id", "photo") + # Проверяем результат assert result is not None - assert 'files/photos/test_file_id.jpg' in result - mock_message.bot.get_file.assert_called_once_with('test_file_id') + assert "files/photos/test_file_id.jpg" in result + mock_message.bot.get_file.assert_called_once_with("test_file_id") mock_message.bot.download_file.assert_called_once() - + @pytest.mark.asyncio async def test_download_file_invalid_parameters(self): """Тест с неверными параметрами""" - result = await download_file(None, 'test_file_id', 'photo') + result = await download_file(None, "test_file_id", "photo") assert result is None - + mock_message = Mock() mock_message.bot = None - result = await download_file(mock_message, 'test_file_id', 'photo') + result = await download_file(mock_message, "test_file_id", "photo") assert result is None - + @pytest.mark.asyncio async def test_download_file_error(self): """Тест обработки ошибки при скачивании""" mock_message = Mock() mock_message.bot = Mock() mock_message.bot.get_file = AsyncMock(side_effect=Exception("Network error")) - - result = await download_file(mock_message, 'test_file_id', 'photo') + + result = await download_file(mock_message, "test_file_id", "photo") assert result is None class TestAddInDbMedia: """Тесты для функции add_in_db_media""" - + @pytest.mark.asyncio async def test_add_in_db_media_success_photo(self): """Тест успешного добавления фото в БД""" @@ -76,65 +90,75 @@ class TestAddInDbMedia: mock_message = Mock() mock_message.message_id = 123 mock_message.photo = [Mock()] - mock_message.photo[-1].file_id = 'photo_123' + mock_message.photo[-1].file_id = "photo_123" mock_message.video = None mock_message.voice = None mock_message.audio = None mock_message.video_note = None - + # Мокаем БД mock_db = AsyncMock() mock_db.add_post_content = AsyncMock(return_value=True) - - with patch('helper_bot.utils.helper_func.download_file', return_value='files/photos/photo_123.jpg'): + + with patch( + "helper_bot.utils.helper_func.download_file", + return_value="files/photos/photo_123.jpg", + ): result = await add_in_db_media(mock_message, mock_db) - + assert result is True - mock_db.add_post_content.assert_called_once_with(123, 123, 'files/photos/photo_123.jpg', 'photo') - + mock_db.add_post_content.assert_called_once_with( + 123, 123, "files/photos/photo_123.jpg", "photo" + ) + @pytest.mark.asyncio async def test_add_in_db_media_download_fails(self): """Тест когда скачивание файла не удается""" mock_message = Mock() mock_message.message_id = 123 mock_message.photo = [Mock()] - mock_message.photo[-1].file_id = 'photo_123' + mock_message.photo[-1].file_id = "photo_123" mock_message.video = None mock_message.voice = None mock_message.audio = None mock_message.video_note = None - + mock_db = AsyncMock() - - with patch('helper_bot.utils.helper_func.download_file', return_value=None): + + with patch("helper_bot.utils.helper_func.download_file", return_value=None): result = await add_in_db_media(mock_message, mock_db) - + assert result is False mock_db.add_post_content.assert_not_called() - + @pytest.mark.asyncio async def test_add_in_db_media_db_fails(self): """Тест когда добавление в БД не удается""" mock_message = Mock() mock_message.message_id = 123 mock_message.photo = [Mock()] - mock_message.photo[-1].file_id = 'photo_123' + mock_message.photo[-1].file_id = "photo_123" mock_message.video = None mock_message.voice = None mock_message.audio = None mock_message.video_note = None - + mock_db = AsyncMock() mock_db.add_post_content = AsyncMock(return_value=False) - - with patch('helper_bot.utils.helper_func.download_file', return_value='files/photos/photo_123.jpg'), \ - patch('helper_bot.utils.helper_func.os.remove'): - + + with ( + patch( + "helper_bot.utils.helper_func.download_file", + return_value="files/photos/photo_123.jpg", + ), + patch("helper_bot.utils.helper_func.os.remove"), + ): + result = await add_in_db_media(mock_message, mock_db) - + assert result is False mock_db.add_post_content.assert_called_once() - + @pytest.mark.asyncio async def test_add_in_db_media_unsupported_content(self): """Тест с неподдерживаемым типом контента""" @@ -145,18 +169,18 @@ class TestAddInDbMedia: mock_message.voice = None mock_message.audio = None mock_message.video_note = None - + mock_db = AsyncMock() - + result = await add_in_db_media(mock_message, mock_db) - + assert result is False mock_db.add_post_content.assert_not_called() class TestAddInDbMediaMediagroup: """Тесты для функции add_in_db_media_mediagroup""" - + @pytest.mark.asyncio async def test_add_in_db_media_mediagroup_success(self): """Тест успешного добавления медиагруппы в БД""" @@ -164,43 +188,47 @@ class TestAddInDbMediaMediagroup: mock_message1 = Mock() mock_message1.message_id = 1 mock_message1.photo = [Mock()] - mock_message1.photo[-1].file_id = 'photo_1' + mock_message1.photo[-1].file_id = "photo_1" mock_message1.video = None mock_message1.voice = None mock_message1.audio = None mock_message1.video_note = None - + mock_message2 = Mock() mock_message2.message_id = 2 mock_message2.photo = None mock_message2.video = Mock() - mock_message2.video.file_id = 'video_1' + mock_message2.video.file_id = "video_1" mock_message2.voice = None mock_message2.audio = None mock_message2.video_note = None - + sent_messages = [mock_message1, mock_message2] - + # Мокаем БД mock_db = AsyncMock() mock_db.add_post_content = AsyncMock(return_value=True) - - with patch('helper_bot.utils.helper_func.download_file', return_value='files/test.jpg'): - result = await add_in_db_media_mediagroup(sent_messages, mock_db, main_post_id=100) - + + with patch( + "helper_bot.utils.helper_func.download_file", return_value="files/test.jpg" + ): + result = await add_in_db_media_mediagroup( + sent_messages, mock_db, main_post_id=100 + ) + assert result is True assert mock_db.add_post_content.call_count == 2 - + @pytest.mark.asyncio async def test_add_in_db_media_mediagroup_empty_list(self): """Тест с пустым списком сообщений""" mock_db = AsyncMock() - + result = await add_in_db_media_mediagroup([], mock_db) - + assert result is False mock_db.add_post_content.assert_not_called() - + @pytest.mark.asyncio async def test_add_in_db_media_mediagroup_partial_failure(self): """Тест когда часть сообщений обрабатывается успешно""" @@ -208,12 +236,12 @@ class TestAddInDbMediaMediagroup: mock_message1 = Mock() mock_message1.message_id = 1 mock_message1.photo = [Mock()] - mock_message1.photo[-1].file_id = 'photo_1' + mock_message1.photo[-1].file_id = "photo_1" mock_message1.video = None mock_message1.voice = None mock_message1.audio = None mock_message1.video_note = None - + mock_message2 = Mock() mock_message2.message_id = 2 mock_message2.photo = None @@ -221,16 +249,18 @@ class TestAddInDbMediaMediagroup: mock_message2.voice = None mock_message2.audio = None mock_message2.video_note = None # Неподдерживаемый тип - + sent_messages = [mock_message1, mock_message2] - + # Мокаем БД mock_db = AsyncMock() mock_db.add_post_content = AsyncMock(return_value=True) - - with patch('helper_bot.utils.helper_func.download_file', return_value='files/test.jpg'): + + with patch( + "helper_bot.utils.helper_func.download_file", return_value="files/test.jpg" + ): result = await add_in_db_media_mediagroup(sent_messages, mock_db) - + # Должен вернуть False, так как есть ошибки (второе сообщение не поддерживается) assert result is False assert mock_db.add_post_content.call_count == 1 @@ -238,7 +268,7 @@ class TestAddInDbMediaMediagroup: class TestSendMediaGroupMessageToPrivateChat: """Тесты для функции send_media_group_message_to_private_chat""" - + @pytest.mark.asyncio async def test_send_media_group_message_success(self): """Тест успешной отправки медиагруппы""" @@ -246,25 +276,29 @@ class TestSendMediaGroupMessageToPrivateChat: mock_message = Mock() mock_message.from_user.id = 123 mock_message.bot = Mock() - + # Мокаем отправленное сообщение mock_sent_message = Mock() mock_sent_message.message_id = 456 mock_sent_message.caption = "Test caption" mock_message.bot.send_media_group = AsyncMock(return_value=[mock_sent_message]) - + # Мокаем БД mock_db = AsyncMock() - - with patch('helper_bot.utils.helper_func.add_in_db_media_mediagroup', return_value=True): - with patch('asyncio.create_task'): # Мокаем create_task, чтобы фоновая задача не выполнялась + + with patch( + "helper_bot.utils.helper_func.add_in_db_media_mediagroup", return_value=True + ): + with patch( + "asyncio.create_task" + ): # Мокаем create_task, чтобы фоновая задача не выполнялась result = await send_media_group_message_to_private_chat( 100, mock_message, [], mock_db, main_post_id=789 ) - + assert result == [456] # Функция возвращает список message_id mock_message.bot.send_media_group.assert_called_once() - + @pytest.mark.asyncio async def test_send_media_group_message_media_processing_fails(self): """Тест когда обработка медиа не удается""" @@ -272,22 +306,27 @@ class TestSendMediaGroupMessageToPrivateChat: mock_message = Mock() mock_message.from_user.id = 123 mock_message.bot = Mock() - + # Мокаем отправленное сообщение mock_sent_message = Mock() mock_sent_message.message_id = 456 mock_sent_message.caption = "Test caption" mock_message.bot.send_media_group = AsyncMock(return_value=[mock_sent_message]) - + # Мокаем БД mock_db = AsyncMock() - - with patch('helper_bot.utils.helper_func.add_in_db_media_mediagroup', return_value=False): - with patch('asyncio.create_task'): # Мокаем create_task, чтобы фоновая задача не выполнялась + + with patch( + "helper_bot.utils.helper_func.add_in_db_media_mediagroup", + return_value=False, + ): + with patch( + "asyncio.create_task" + ): # Мокаем create_task, чтобы фоновая задача не выполнялась result = await send_media_group_message_to_private_chat( 100, mock_message, [], mock_db, main_post_id=789 ) - + assert result == [456] # Функция возвращает список message_id mock_message.bot.send_media_group.assert_called_once() diff --git a/tests/test_keyboards_and_filters.py b/tests/test_keyboards_and_filters.py index f6e25d2..48de3c1 100644 --- a/tests/test_keyboards_and_filters.py +++ b/tests/test_keyboards_and_filters.py @@ -1,475 +1,489 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from aiogram.types import (InlineKeyboardButton, InlineKeyboardMarkup, - KeyboardButton, ReplyKeyboardMarkup) +from aiogram.types import ( + InlineKeyboardButton, + InlineKeyboardMarkup, + KeyboardButton, + ReplyKeyboardMarkup, +) + from database.async_db import AsyncBotDB from helper_bot.filters.main import ChatTypeFilter -from helper_bot.keyboards.keyboards import (create_keyboard_with_pagination, - get_reply_keyboard, - get_reply_keyboard_admin, - get_reply_keyboard_for_post, - get_reply_keyboard_leave_chat) +from helper_bot.keyboards.keyboards import ( + create_keyboard_with_pagination, + get_reply_keyboard, + get_reply_keyboard_admin, + get_reply_keyboard_for_post, + get_reply_keyboard_leave_chat, +) class TestKeyboards: """Тесты для клавиатур""" - + @pytest.fixture def mock_db(self): """Создает мок базы данных""" db = Mock(spec=AsyncBotDB) - db.get_user_info = Mock(return_value={ - 'stickers': True, - 'admin': False - }) + db.get_user_info = Mock(return_value={"stickers": True, "admin": False}) return db - + @pytest.mark.asyncio async def test_get_reply_keyboard_basic(self, mock_db): """Тест базовой клавиатуры""" user_id = 123456 - + keyboard = await get_reply_keyboard(mock_db, user_id) - + # Проверяем, что возвращается клавиатура assert isinstance(keyboard, ReplyKeyboardMarkup) assert keyboard.keyboard is not None assert len(keyboard.keyboard) > 0 - + # Проверяем, что каждая кнопка в отдельной строке for row in keyboard.keyboard: assert len(row) == 1 # Каждая строка содержит только одну кнопку - + # Проверяем наличие основных кнопок all_buttons = [] for row in keyboard.keyboard: for button in row: all_buttons.append(button.text) - + # Проверяем наличие основных кнопок - assert '📢Предложить свой пост' in all_buttons - assert '👋🏼Сказать пока!' in all_buttons - assert '📩Связаться с админами' in all_buttons - + assert "📢Предложить свой пост" in all_buttons + assert "👋🏼Сказать пока!" in all_buttons + assert "📩Связаться с админами" in all_buttons + @pytest.mark.asyncio async def test_get_reply_keyboard_with_stickers(self, mock_db): """Тест клавиатуры со стикерами""" user_id = 123456 # Мокаем метод get_stickers_info mock_db.get_stickers_info = AsyncMock(return_value=False) - + keyboard = await get_reply_keyboard(mock_db, user_id) - + all_buttons = [] for row in keyboard.keyboard: for button in row: all_buttons.append(button.text) - + # Проверяем наличие кнопки стикеров - assert '🤪Хочу стикеры' in all_buttons - + assert "🤪Хочу стикеры" in all_buttons + @pytest.mark.asyncio async def test_get_reply_keyboard_without_stickers(self, mock_db): """Тест клавиатуры без стикеров""" user_id = 123456 # Мокаем метод get_stickers_info mock_db.get_stickers_info = AsyncMock(return_value=True) - + keyboard = await get_reply_keyboard(mock_db, user_id) - + all_buttons = [] for row in keyboard.keyboard: for button in row: all_buttons.append(button.text) - + # Проверяем отсутствие кнопки стикеров - assert '🤪Хочу стикеры' not in all_buttons - + assert "🤪Хочу стикеры" not in all_buttons + @pytest.mark.asyncio async def test_get_reply_keyboard_admin(self, mock_db): """Тест клавиатуры для админа""" user_id = 123456 # Мокаем метод get_stickers_info mock_db.get_stickers_info = AsyncMock(return_value=False) - + keyboard = await get_reply_keyboard(mock_db, user_id) - + all_buttons = [] for row in keyboard.keyboard: for button in row: all_buttons.append(button.text) - + # Проверяем наличие основных кнопок - assert '📢Предложить свой пост' in all_buttons - assert '👋🏼Сказать пока!' in all_buttons - assert '📩Связаться с админами' in all_buttons - + assert "📢Предложить свой пост" in all_buttons + assert "👋🏼Сказать пока!" in all_buttons + assert "📩Связаться с админами" in all_buttons + def test_get_reply_keyboard_admin_keyboard(self): """Тест админской клавиатуры""" keyboard = get_reply_keyboard_admin() - + assert isinstance(keyboard, ReplyKeyboardMarkup) assert keyboard.keyboard is not None assert len(keyboard.keyboard) == 3 # Три строки - + # Проверяем первую строку (3 кнопки) first_row = keyboard.keyboard[0] assert len(first_row) == 3 assert first_row[0].text == "Бан (Список)" assert first_row[1].text == "Бан по нику" assert first_row[2].text == "Бан по ID" - + # Проверяем вторую строку (2 кнопки) second_row = keyboard.keyboard[1] assert len(second_row) == 2 assert second_row[0].text == "Разбан (список)" assert second_row[1].text == "📊 ML Статистика" - + # Проверяем третью строку (1 кнопка) third_row = keyboard.keyboard[2] assert len(third_row) == 1 assert third_row[0].text == "Вернуться в бота" - + def test_get_reply_keyboard_for_post(self): """Тест клавиатуры для постов""" keyboard = get_reply_keyboard_for_post() - + assert isinstance(keyboard, InlineKeyboardMarkup) assert keyboard.inline_keyboard is not None assert len(keyboard.inline_keyboard) > 0 - + all_buttons = [] for row in keyboard.inline_keyboard: for button in row: all_buttons.append(button.text) - + # Проверяем наличие кнопок для постов - assert 'Опубликовать' in all_buttons - assert 'Отклонить' in all_buttons - + assert "Опубликовать" in all_buttons + assert "Отклонить" in all_buttons + def test_get_reply_keyboard_leave_chat(self): """Тест клавиатуры для выхода из чата""" keyboard = get_reply_keyboard_leave_chat() - + assert isinstance(keyboard, ReplyKeyboardMarkup) assert keyboard.keyboard is not None assert len(keyboard.keyboard) > 0 - + all_buttons = [] for row in keyboard.keyboard: for button in row: all_buttons.append(button.text) - + # Проверяем наличие кнопки выхода - assert 'Выйти из чата' in all_buttons - + assert "Выйти из чата" in all_buttons + def test_keyboard_resize(self): """Тест настройки resize клавиатуры""" keyboard = get_reply_keyboard_for_post() - + # Проверяем, что клавиатура настроена правильно # InlineKeyboardMarkup не имеет resize_keyboard assert isinstance(keyboard, InlineKeyboardMarkup) - + def test_keyboard_one_time(self): """Тест настройки one_time клавиатуры""" keyboard = get_reply_keyboard_leave_chat() - + # Проверяем, что клавиатура настроена правильно - assert hasattr(keyboard, 'one_time_keyboard') + assert hasattr(keyboard, "one_time_keyboard") assert keyboard.one_time_keyboard is True class TestChatTypeFilter: """Тесты для фильтра типа чата""" - + @pytest.fixture def mock_message(self): """Создает мок сообщения""" message = Mock() message.chat = Mock() return message - + @pytest.mark.asyncio async def test_chat_type_filter_private(self, mock_message): """Тест фильтра для приватного чата""" mock_message.chat.type = "private" - + filter_private = ChatTypeFilter(chat_type=["private"]) filter_group = ChatTypeFilter(chat_type=["group"]) filter_supergroup = ChatTypeFilter(chat_type=["supergroup"]) - + # Проверяем, что фильтр работает правильно assert await filter_private(mock_message) is True assert await filter_group(mock_message) is False assert await filter_supergroup(mock_message) is False - + @pytest.mark.asyncio async def test_chat_type_filter_group(self, mock_message): """Тест фильтра для группового чата""" mock_message.chat.type = "group" - + filter_private = ChatTypeFilter(chat_type=["private"]) filter_group = ChatTypeFilter(chat_type=["group"]) filter_supergroup = ChatTypeFilter(chat_type=["supergroup"]) - + # Проверяем, что фильтр работает правильно assert await filter_private(mock_message) is False assert await filter_group(mock_message) is True assert await filter_supergroup(mock_message) is False - + @pytest.mark.asyncio async def test_chat_type_filter_supergroup(self, mock_message): """Тест фильтра для супергруппы""" mock_message.chat.type = "supergroup" - + filter_private = ChatTypeFilter(chat_type=["private"]) filter_group = ChatTypeFilter(chat_type=["group"]) filter_supergroup = ChatTypeFilter(chat_type=["supergroup"]) - + # Проверяем, что фильтр работает правильно assert await filter_private(mock_message) is False assert await filter_group(mock_message) is False assert await filter_supergroup(mock_message) is True - + @pytest.mark.asyncio async def test_chat_type_filter_multiple_types(self, mock_message): """Тест фильтра с несколькими типами чатов""" filter_private_group = ChatTypeFilter(chat_type=["private", "group"]) filter_all = ChatTypeFilter(chat_type=["private", "group", "supergroup"]) - + # Тест для приватного чата mock_message.chat.type = "private" assert await filter_private_group(mock_message) is True assert await filter_all(mock_message) is True - + # Тест для группового чата mock_message.chat.type = "group" assert await filter_private_group(mock_message) is True assert await filter_all(mock_message) is True - + # Тест для супергруппы mock_message.chat.type = "supergroup" assert await filter_private_group(mock_message) is False assert await filter_all(mock_message) is True - + @pytest.mark.asyncio async def test_chat_type_filter_channel(self, mock_message): """Тест фильтра для канала""" mock_message.chat.type = "channel" - + filter_channel = ChatTypeFilter(chat_type=["channel"]) filter_private = ChatTypeFilter(chat_type=["private"]) - + assert await filter_channel(mock_message) is True assert await filter_private(mock_message) is False - + @pytest.mark.asyncio async def test_chat_type_filter_empty_list(self, mock_message): """Тест фильтра с пустым списком типов""" mock_message.chat.type = "private" - + filter_empty = ChatTypeFilter(chat_type=[]) - + # Фильтр с пустым списком должен возвращать False assert await filter_empty(mock_message) is False - + @pytest.mark.asyncio async def test_chat_type_filter_invalid_type(self, mock_message): """Тест фильтра с несуществующим типом чата""" mock_message.chat.type = "invalid_type" - + filter_private = ChatTypeFilter(chat_type=["private"]) filter_invalid = ChatTypeFilter(chat_type=["invalid_type"]) - + assert await filter_private(mock_message) is False assert await filter_invalid(mock_message) is True class TestKeyboardIntegration: """Интеграционные тесты клавиатур""" - + @pytest.mark.asyncio async def test_keyboard_structure_consistency(self): """Тест консистентности структуры клавиатур""" # Мокаем базу данных mock_db = Mock(spec=AsyncBotDB) mock_db.get_stickers_info = AsyncMock(return_value=False) - + # Тестируем все типы клавиатур keyboard1 = await get_reply_keyboard(mock_db, 123456) keyboard2 = get_reply_keyboard_for_post() keyboard3 = get_reply_keyboard_leave_chat() - + # Проверяем первую клавиатуру (ReplyKeyboardMarkup) assert isinstance(keyboard1, ReplyKeyboardMarkup) - assert hasattr(keyboard1, 'keyboard') + assert hasattr(keyboard1, "keyboard") assert isinstance(keyboard1.keyboard, list) - + # Проверяем вторую клавиатуру (InlineKeyboardMarkup) assert isinstance(keyboard2, InlineKeyboardMarkup) - assert hasattr(keyboard2, 'inline_keyboard') + assert hasattr(keyboard2, "inline_keyboard") assert isinstance(keyboard2.inline_keyboard, list) - + # Проверяем третью клавиатуру (ReplyKeyboardMarkup) assert isinstance(keyboard3, ReplyKeyboardMarkup) - assert hasattr(keyboard3, 'keyboard') + assert hasattr(keyboard3, "keyboard") assert isinstance(keyboard3.keyboard, list) - + @pytest.mark.asyncio async def test_keyboard_button_texts(self): """Тест текстов кнопок клавиатур""" # Тестируем основные кнопки db = Mock(spec=AsyncBotDB) db.get_stickers_info = AsyncMock(return_value=False) - + main_keyboard = await get_reply_keyboard(db, 123456) post_keyboard = get_reply_keyboard_for_post() leave_keyboard = get_reply_keyboard_leave_chat() - + # Собираем все тексты кнопок main_buttons = [] for row in main_keyboard.keyboard: for button in row: main_buttons.append(button.text) - + post_buttons = [] for row in post_keyboard.inline_keyboard: for button in row: post_buttons.append(button.text) - + leave_buttons = [] for row in leave_keyboard.keyboard: for button in row: leave_buttons.append(button.text) - + # Проверяем наличие основных кнопок - assert '📢Предложить свой пост' in main_buttons - assert '👋🏼Сказать пока!' in main_buttons - assert '📩Связаться с админами' in main_buttons - assert '🤪Хочу стикеры' in main_buttons - + assert "📢Предложить свой пост" in main_buttons + assert "👋🏼Сказать пока!" in main_buttons + assert "📩Связаться с админами" in main_buttons + assert "🤪Хочу стикеры" in main_buttons + # Проверяем кнопки для постов - assert 'Опубликовать' in post_buttons - assert 'Отклонить' in post_buttons - + assert "Опубликовать" in post_buttons + assert "Отклонить" in post_buttons + # Проверяем кнопку выхода - assert 'Выйти из чата' in leave_buttons + assert "Выйти из чата" in leave_buttons class TestPagination: """Тесты для функции create_keyboard_with_pagination""" - + def test_pagination_empty_list(self): """Тест с пустым списком элементов""" - keyboard = create_keyboard_with_pagination(1, 0, [], 'test') + keyboard = create_keyboard_with_pagination(1, 0, [], "test") assert keyboard is not None # Проверяем, что есть только кнопка "Назад" assert len(keyboard.inline_keyboard) == 1 assert keyboard.inline_keyboard[0][0].text == "🏠 Назад" - + def test_pagination_single_page(self): """Тест с одной страницей""" items = [("User1", 1), ("User2", 2), ("User3", 3)] - keyboard = create_keyboard_with_pagination(1, 3, items, 'test') - + keyboard = create_keyboard_with_pagination(1, 3, items, "test") + # Проверяем количество кнопок (3 пользователя + кнопка "Назад") - assert len(keyboard.inline_keyboard) == 2 # 1 ряд с пользователями + 1 ряд с "Назад" + assert ( + len(keyboard.inline_keyboard) == 2 + ) # 1 ряд с пользователями + 1 ряд с "Назад" assert len(keyboard.inline_keyboard[0]) == 3 # 3 пользователя в первом ряду assert keyboard.inline_keyboard[1][0].text == "🏠 Назад" - + # Проверяем, что нет кнопок навигации assert len(keyboard.inline_keyboard[0]) == 3 # только пользователи - + def test_pagination_multiple_pages(self): """Тест с несколькими страницами""" items = [("User" + str(i), i) for i in range(1, 15)] # 14 пользователей - keyboard = create_keyboard_with_pagination(1, 14, items, 'test') - + keyboard = create_keyboard_with_pagination(1, 14, items, "test") + # На первой странице должно быть 9 пользователей (3 ряда по 3) + кнопка "Следующая" + "Назад" - assert len(keyboard.inline_keyboard) == 5 # 3 ряда пользователей + навигация + назад + assert ( + len(keyboard.inline_keyboard) == 5 + ) # 3 ряда пользователей + навигация + назад assert len(keyboard.inline_keyboard[0]) == 3 # первый ряд: 3 пользователя assert len(keyboard.inline_keyboard[1]) == 3 # второй ряд: 3 пользователя assert len(keyboard.inline_keyboard[2]) == 3 # третий ряд: 3 пользователя assert keyboard.inline_keyboard[3][0].text == "➡️ Следующая" # кнопка навигации assert keyboard.inline_keyboard[4][0].text == "🏠 Назад" # кнопка назад - + def test_pagination_second_page(self): """Тест второй страницы""" items = [("User" + str(i), i) for i in range(1, 15)] # 14 пользователей - keyboard = create_keyboard_with_pagination(2, 14, items, 'test') - + keyboard = create_keyboard_with_pagination(2, 14, items, "test") + # На второй странице должно быть 5 пользователей (2 ряда: 3+2) + кнопки "Предыдущая" и "Назад" - assert len(keyboard.inline_keyboard) == 4 # 2 ряда пользователей + навигация + назад + assert ( + len(keyboard.inline_keyboard) == 4 + ) # 2 ряда пользователей + навигация + назад assert len(keyboard.inline_keyboard[0]) == 3 # первый ряд: 3 пользователя assert len(keyboard.inline_keyboard[1]) == 2 # второй ряд: 2 пользователя assert keyboard.inline_keyboard[2][0].text == "⬅️ Предыдущая" assert keyboard.inline_keyboard[3][0].text == "🏠 Назад" - + def test_pagination_middle_page(self): """Тест средней страницы""" items = [("User" + str(i), i) for i in range(1, 25)] # 24 пользователя - keyboard = create_keyboard_with_pagination(2, 24, items, 'test') - + keyboard = create_keyboard_with_pagination(2, 24, items, "test") + # На второй странице должно быть 9 пользователей (3 ряда по 3) + кнопки "Предыдущая" и "Следующая" - assert len(keyboard.inline_keyboard) == 5 # 3 ряда пользователей + навигация + назад + assert ( + len(keyboard.inline_keyboard) == 5 + ) # 3 ряда пользователей + навигация + назад assert len(keyboard.inline_keyboard[0]) == 3 # первый ряд: 3 пользователя assert len(keyboard.inline_keyboard[1]) == 3 # второй ряд: 3 пользователя assert len(keyboard.inline_keyboard[2]) == 3 # третий ряд: 3 пользователя assert keyboard.inline_keyboard[3][0].text == "⬅️ Предыдущая" assert keyboard.inline_keyboard[3][1].text == "➡️ Следующая" - + def test_pagination_invalid_page_number(self): """Тест с некорректным номером страницы""" items = [("User" + str(i), i) for i in range(1, 10)] # 9 пользователей - keyboard = create_keyboard_with_pagination(0, 9, items, 'test') # страница 0 - + keyboard = create_keyboard_with_pagination(0, 9, items, "test") # страница 0 + # Должна вернуться первая страница assert len(keyboard.inline_keyboard) == 4 # 3 ряда пользователей + назад assert len(keyboard.inline_keyboard[0]) == 3 # первый ряд: 3 пользователя assert len(keyboard.inline_keyboard[1]) == 3 # второй ряд: 3 пользователя assert len(keyboard.inline_keyboard[2]) == 3 # третий ряд: 3 пользователя - + def test_pagination_page_out_of_range(self): """Тест с номером страницы больше максимального""" items = [("User" + str(i), i) for i in range(1, 10)] # 9 пользователей - keyboard = create_keyboard_with_pagination(5, 9, items, 'test') # страница 5 при 1 странице - + keyboard = create_keyboard_with_pagination( + 5, 9, items, "test" + ) # страница 5 при 1 странице + # Должна вернуться первая страница assert len(keyboard.inline_keyboard) == 4 # 3 ряда пользователей + назад assert len(keyboard.inline_keyboard[0]) == 3 # первый ряд: 3 пользователя assert len(keyboard.inline_keyboard[1]) == 3 # второй ряд: 3 пользователя assert len(keyboard.inline_keyboard[2]) == 3 # третий ряд: 3 пользователя - + def test_pagination_callback_data_format(self): """Тест формата callback_data""" items = [("User1", 123), ("User2", 456)] - keyboard = create_keyboard_with_pagination(1, 2, items, 'ban') - + keyboard = create_keyboard_with_pagination(1, 2, items, "ban") + # Проверяем формат callback_data для пользователей assert keyboard.inline_keyboard[0][0].callback_data == "ban_123" assert keyboard.inline_keyboard[0][1].callback_data == "ban_456" - + # Проверяем формат callback_data для кнопки "Назад" assert keyboard.inline_keyboard[1][0].callback_data == "return" - + def test_pagination_navigation_callback_data(self): """Тест callback_data для кнопок навигации""" items = [("User" + str(i), i) for i in range(1, 15)] # 14 пользователей - keyboard = create_keyboard_with_pagination(2, 14, items, 'test') - + keyboard = create_keyboard_with_pagination(2, 14, items, "test") + # Проверяем callback_data для кнопки "Предыдущая" assert keyboard.inline_keyboard[2][0].callback_data == "page_1" - + # Проверяем callback_data для кнопки "Назад" assert keyboard.inline_keyboard[3][0].callback_data == "return" - + def test_pagination_exactly_items_per_page(self): """Тест когда количество элементов точно равно items_per_page""" items = [("User" + str(i), i) for i in range(1, 10)] # ровно 9 пользователей - keyboard = create_keyboard_with_pagination(1, 9, items, 'test') - + keyboard = create_keyboard_with_pagination(1, 9, items, "test") + # Должна быть только одна страница без кнопок навигации assert len(keyboard.inline_keyboard) == 4 # 3 ряда пользователей + назад assert len(keyboard.inline_keyboard[0]) == 3 # первый ряд: 3 пользователя @@ -478,5 +492,5 @@ class TestPagination: assert keyboard.inline_keyboard[3][0].text == "🏠 Назад" -if __name__ == '__main__': - pytest.main([__file__, '-v']) +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_message_repository.py b/tests/test_message_repository.py index a9d4a85..b8f2557 100644 --- a/tests/test_message_repository.py +++ b/tests/test_message_repository.py @@ -3,23 +3,24 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock import pytest + from database.models import UserMessage from database.repositories.message_repository import MessageRepository class TestMessageRepository: """Тесты для MessageRepository.""" - + @pytest.fixture def mock_db_path(self): """Фикстура для пути к тестовой БД.""" return ":memory:" - + @pytest.fixture def message_repository(self, mock_db_path): """Фикстура для MessageRepository.""" return MessageRepository(mock_db_path) - + @pytest.fixture def sample_message(self): """Фикстура для тестового сообщения.""" @@ -27,9 +28,9 @@ class TestMessageRepository: message_text="Тестовое сообщение", user_id=12345, telegram_message_id=67890, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) - + @pytest.fixture def sample_message_no_date(self): """Фикстура для тестового сообщения без даты.""" @@ -37,134 +38,145 @@ class TestMessageRepository: message_text="Тестовое сообщение без даты", user_id=12345, telegram_message_id=67891, - date=None + date=None, ) - + @pytest.mark.asyncio async def test_create_tables(self, message_repository): """Тест создания таблиц.""" # Мокаем _execute_query message_repository._execute_query = AsyncMock() - + await message_repository.create_tables() - + message_repository._execute_query.assert_called_once() call_args = message_repository._execute_query.call_args[0][0] assert "CREATE TABLE IF NOT EXISTS user_messages" in call_args assert "telegram_message_id INTEGER NOT NULL" in call_args assert "date INTEGER NOT NULL" in call_args - assert "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" in call_args - + assert ( + "FOREIGN KEY (user_id) REFERENCES our_users (user_id) ON DELETE CASCADE" + in call_args + ) + @pytest.mark.asyncio async def test_add_message_with_date(self, message_repository, sample_message): """Тест добавления сообщения с датой.""" # Мокаем _execute_query message_repository._execute_query = AsyncMock() - + await message_repository.add_message(sample_message) - + message_repository._execute_query.assert_called_once() call_args = message_repository._execute_query.call_args query = call_args[0][0] params = call_args[0][1] - + assert "INSERT INTO user_messages" in query assert "VALUES (?, ?, ?, ?)" in query assert params == ( sample_message.message_text, sample_message.user_id, sample_message.telegram_message_id, - sample_message.date + sample_message.date, ) - + @pytest.mark.asyncio - async def test_add_message_without_date(self, message_repository, sample_message_no_date): + async def test_add_message_without_date( + self, message_repository, sample_message_no_date + ): """Тест добавления сообщения без даты (должна генерироваться автоматически).""" # Мокаем _execute_query message_repository._execute_query = AsyncMock() - + await message_repository.add_message(sample_message_no_date) - + # Проверяем, что дата была установлена assert sample_message_no_date.date is not None assert isinstance(sample_message_no_date.date, int) assert sample_message_no_date.date > 0 - + message_repository._execute_query.assert_called_once() call_args = message_repository._execute_query.call_args params = call_args[0][1] - + assert params[3] == sample_message_no_date.date # date field - + @pytest.mark.asyncio async def test_add_message_logs_correctly(self, message_repository, sample_message): """Тест логирования при добавлении сообщения.""" # Мокаем _execute_query и logger message_repository._execute_query = AsyncMock() message_repository.logger = MagicMock() - + await message_repository.add_message(sample_message) - + message_repository.logger.info.assert_called_once() log_message = message_repository.logger.info.call_args[0][0] - assert f"telegram_message_id={sample_message.telegram_message_id}" in log_message - + assert ( + f"telegram_message_id={sample_message.telegram_message_id}" in log_message + ) + @pytest.mark.asyncio async def test_get_user_by_message_id_found(self, message_repository): """Тест получения пользователя по message_id (пользователь найден).""" message_id = 67890 expected_user_id = 12345 - + # Мокаем _execute_query_with_result message_repository._execute_query_with_result = AsyncMock( return_value=[[expected_user_id]] ) - + result = await message_repository.get_user_by_message_id(message_id) - + assert result == expected_user_id message_repository._execute_query_with_result.assert_called_once_with( "SELECT user_id FROM user_messages WHERE telegram_message_id = ?", - (message_id,) + (message_id,), ) - + @pytest.mark.asyncio async def test_get_user_by_message_id_not_found(self, message_repository): """Тест получения пользователя по message_id (пользователь не найден).""" message_id = 99999 - + # Мокаем _execute_query_with_result message_repository._execute_query_with_result = AsyncMock(return_value=[]) - + result = await message_repository.get_user_by_message_id(message_id) - + assert result is None message_repository._execute_query_with_result.assert_called_once_with( "SELECT user_id FROM user_messages WHERE telegram_message_id = ?", - (message_id,) + (message_id,), ) - + @pytest.mark.asyncio async def test_get_user_by_message_id_empty_result(self, message_repository): """Тест получения пользователя по message_id (пустой результат).""" message_id = 99999 - + # Мокаем _execute_query_with_result message_repository._execute_query_with_result = AsyncMock(return_value=[[]]) - + result = await message_repository.get_user_by_message_id(message_id) - + assert result is None - + @pytest.mark.asyncio - async def test_add_message_handles_exception(self, message_repository, sample_message): + async def test_add_message_handles_exception( + self, message_repository, sample_message + ): """Тест обработки исключений при добавлении сообщения.""" # Мокаем _execute_query для вызова исключения - message_repository._execute_query = AsyncMock(side_effect=Exception("Database error")) - + message_repository._execute_query = AsyncMock( + side_effect=Exception("Database error") + ) + with pytest.raises(Exception, match="Database error"): await message_repository.add_message(sample_message) - + @pytest.mark.asyncio async def test_get_user_by_message_id_handles_exception(self, message_repository): """Тест обработки исключений при получении пользователя.""" @@ -172,10 +184,10 @@ class TestMessageRepository: message_repository._execute_query_with_result = AsyncMock( side_effect=Exception("Database error") ) - + with pytest.raises(Exception, match="Database error"): await message_repository.get_user_by_message_id(12345) - + @pytest.mark.asyncio async def test_add_message_with_zero_date(self, message_repository): """Тест добавления сообщения с датой равной 0 (должна генерироваться новая).""" @@ -183,19 +195,19 @@ class TestMessageRepository: message_text="Тестовое сообщение с нулевой датой", user_id=12345, telegram_message_id=67892, - date=0 + date=0, ) - + # Мокаем _execute_query message_repository._execute_query = AsyncMock() - + await message_repository.add_message(message) - + # Проверяем, что дата была изменена с 0 (теперь это происходит только если date is None) # В текущей реализации дата 0 считается валидной и не изменяется assert isinstance(message.date, int) assert message.date >= 0 - + message_repository._execute_query.assert_called_once() params = message_repository._execute_query.call_args[0][1] assert params[3] == message.date # date field diff --git a/tests/test_message_repository_integration.py b/tests/test_message_repository_integration.py index d52b650..84c3b3a 100644 --- a/tests/test_message_repository_integration.py +++ b/tests/test_message_repository_integration.py @@ -4,17 +4,18 @@ import tempfile from datetime import datetime import pytest + from database.models import UserMessage from database.repositories.message_repository import MessageRepository class TestMessageRepositoryIntegration: """Интеграционные тесты для MessageRepository с реальной БД.""" - + async def _setup_test_database(self, message_repository): """Вспомогательная функция для настройки тестовой БД.""" # Сначала создаем таблицу our_users для тестов - await message_repository._execute_query(''' + await message_repository._execute_query(""" CREATE TABLE IF NOT EXISTS our_users ( user_id INTEGER NOT NULL PRIMARY KEY, first_name TEXT, @@ -28,36 +29,42 @@ class TestMessageRepositoryIntegration: date_changed INTEGER NOT NULL, voice_bot_welcome_received BOOLEAN DEFAULT 0 ) - ''') - + """) + # Добавляем тестового пользователя await message_repository._execute_query( "INSERT OR REPLACE INTO our_users (user_id, first_name, full_name, date_added, date_changed) VALUES (?, ?, ?, ?, ?)", - (12345, "Test", "Test User", int(datetime.now().timestamp()), int(datetime.now().timestamp())) + ( + 12345, + "Test", + "Test User", + int(datetime.now().timestamp()), + int(datetime.now().timestamp()), + ), ) - + # Теперь создаем таблицу user_messages await message_repository.create_tables() - + @pytest.fixture def temp_db_path(self): """Фикстура для временного пути к БД.""" - with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: temp_path = f.name - + yield temp_path - + # Очистка после тестов try: os.unlink(temp_path) except OSError: pass - + @pytest.fixture def message_repository(self, temp_db_path): """Фикстура для MessageRepository с реальной БД.""" return MessageRepository(temp_db_path) - + @pytest.fixture def sample_message(self): """Фикстура для тестового сообщения.""" @@ -65,9 +72,9 @@ class TestMessageRepositoryIntegration: message_text="Интеграционное тестовое сообщение", user_id=12345, telegram_message_id=67890, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) - + @pytest.fixture def sample_message_no_date(self): """Фикстура для тестового сообщения без даты.""" @@ -75,137 +82,155 @@ class TestMessageRepositoryIntegration: message_text="Интеграционное тестовое сообщение без даты", user_id=12345, telegram_message_id=67891, - date=None + date=None, ) - + @pytest.mark.asyncio async def test_create_tables_integration(self, message_repository): """Интеграционный тест создания таблиц.""" # Настраиваем тестовую БД await self._setup_test_database(message_repository) - + # Проверяем, что таблица создана, пытаясь добавить сообщение message = UserMessage( message_text="Тест создания таблиц", user_id=12345, telegram_message_id=67890, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) - + # Не должно вызывать ошибку await message_repository.add_message(message) - + @pytest.mark.asyncio - async def test_add_and_retrieve_message_integration(self, message_repository, sample_message): + async def test_add_and_retrieve_message_integration( + self, message_repository, sample_message + ): """Интеграционный тест добавления и получения сообщения.""" # Настраиваем тестовую БД await self._setup_test_database(message_repository) - + # Добавляем сообщение await message_repository.add_message(sample_message) - + # Получаем пользователя по message_id - user_id = await message_repository.get_user_by_message_id(sample_message.telegram_message_id) - + user_id = await message_repository.get_user_by_message_id( + sample_message.telegram_message_id + ) + # Проверяем результат assert user_id == sample_message.user_id - + @pytest.mark.asyncio - async def test_add_message_without_date_integration(self, message_repository, sample_message_no_date): + async def test_add_message_without_date_integration( + self, message_repository, sample_message_no_date + ): """Интеграционный тест добавления сообщения без даты.""" # Настраиваем тестовую БД await self._setup_test_database(message_repository) - + # Добавляем сообщение без даты await message_repository.add_message(sample_message_no_date) - + # Проверяем, что дата была установлена assert sample_message_no_date.date is not None assert isinstance(sample_message_no_date.date, int) assert sample_message_no_date.date > 0 - + # Проверяем, что сообщение можно найти - user_id = await message_repository.get_user_by_message_id(sample_message_no_date.telegram_message_id) + user_id = await message_repository.get_user_by_message_id( + sample_message_no_date.telegram_message_id + ) assert user_id == sample_message_no_date.user_id - + @pytest.mark.asyncio - async def test_get_user_by_message_id_not_found_integration(self, message_repository): + async def test_get_user_by_message_id_not_found_integration( + self, message_repository + ): """Интеграционный тест поиска несуществующего сообщения.""" # Настраиваем тестовую БД await self._setup_test_database(message_repository) - + # Ищем несуществующее сообщение user_id = await message_repository.get_user_by_message_id(99999) - + # Должно вернуть None assert user_id is None - + @pytest.mark.asyncio async def test_multiple_messages_integration(self, message_repository): """Интеграционный тест работы с несколькими сообщениями.""" # Настраиваем тестовую БД await self._setup_test_database(message_repository) - + # Добавляем несколько сообщений (используем существующий user_id 12345) messages = [ UserMessage( message_text=f"Сообщение {i}", user_id=12345, # Используем существующий user_id telegram_message_id=2000 + i, - date=int(datetime.now().timestamp()) + i + date=int(datetime.now().timestamp()) + i, ) for i in range(1, 4) ] - + for message in messages: await message_repository.add_message(message) - + # Проверяем, что все сообщения можно найти for message in messages: - user_id = await message_repository.get_user_by_message_id(message.telegram_message_id) + user_id = await message_repository.get_user_by_message_id( + message.telegram_message_id + ) assert user_id == message.user_id - + @pytest.mark.asyncio - async def test_message_with_special_characters_integration(self, message_repository): + async def test_message_with_special_characters_integration( + self, message_repository + ): """Интеграционный тест сообщения со специальными символами.""" # Настраиваем тестовую БД await self._setup_test_database(message_repository) - + # Сообщение со специальными символами special_message = UserMessage( message_text="Сообщение с 'кавычками' и \"двойными кавычками\" и эмодзи 😊", user_id=12345, telegram_message_id=67892, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) - + # Добавляем сообщение await message_repository.add_message(special_message) - + # Проверяем, что можно найти - user_id = await message_repository.get_user_by_message_id(special_message.telegram_message_id) + user_id = await message_repository.get_user_by_message_id( + special_message.telegram_message_id + ) assert user_id == special_message.user_id - + @pytest.mark.asyncio async def test_foreign_key_constraint_integration(self, message_repository): """Интеграционный тест ограничения внешнего ключа.""" # Настраиваем тестовую БД await self._setup_test_database(message_repository) - + # Пытаемся добавить сообщение с несуществующим user_id invalid_message = UserMessage( message_text="Сообщение с несуществующим пользователем", user_id=99999, # Несуществующий пользователь telegram_message_id=67893, - date=int(datetime.now().timestamp()) + date=int(datetime.now().timestamp()), ) - + # В SQLite с включенными внешними ключами это должно вызвать ошибку # Теперь у нас есть таблица our_users, поэтому внешний ключ должен работать try: await message_repository.add_message(invalid_message) # Если не вызвало ошибку, проверяем что сообщение не добавилось - user_id = await message_repository.get_user_by_message_id(invalid_message.telegram_message_id) + user_id = await message_repository.get_user_by_message_id( + invalid_message.telegram_message_id + ) assert user_id is None except Exception: # Ожидаемое поведение при нарушении внешнего ключа diff --git a/tests/test_post_repository.py b/tests/test_post_repository.py index 867a3be..797be28 100644 --- a/tests/test_post_repository.py +++ b/tests/test_post_repository.py @@ -3,23 +3,24 @@ from datetime import datetime from unittest.mock import AsyncMock, MagicMock import pytest + from database.models import MessageContentLink, PostContent, TelegramPost from database.repositories.post_repository import PostRepository class TestPostRepository: """Тесты для PostRepository.""" - + @pytest.fixture def mock_db_path(self): """Фикстура для пути к тестовой БД.""" return ":memory:" - + @pytest.fixture def post_repository(self, mock_db_path): """Фикстура для PostRepository.""" return PostRepository(mock_db_path) - + @pytest.fixture def sample_post(self): """Фикстура для тестового поста.""" @@ -28,9 +29,9 @@ class TestPostRepository: text="Тестовый пост", author_id=67890, helper_text_message_id=None, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) - + @pytest.fixture def sample_post_no_date(self): """Фикстура для тестового поста без даты.""" @@ -42,73 +43,85 @@ class TestPostRepository: created_at=None, status="suggest", ) - + @pytest.fixture def sample_post_content(self): """Фикстура для тестового контента поста.""" return PostContent( - message_id=12345, - content_name="/path/to/file.jpg", - content_type="photo" + message_id=12345, content_name="/path/to/file.jpg", content_type="photo" ) - + @pytest.fixture def sample_message_link(self): """Фикстура для тестовой связи сообщения с контентом.""" - return MessageContentLink( - post_id=12345, - message_id=67890 - ) - + return MessageContentLink(post_id=12345, message_id=67890) + @pytest.mark.asyncio async def test_create_tables(self, post_repository): """Тест создания таблиц.""" # Мокаем _execute_query и _execute_query_with_result post_repository._execute_query = AsyncMock() - post_repository._execute_query_with_result = AsyncMock(return_value=[]) # Для проверки столбца - + post_repository._execute_query_with_result = AsyncMock( + return_value=[] + ) # Для проверки столбца + await post_repository.create_tables() - + # Проверяем, что create_tables вызвался минимум 3 раза (для каждой таблицы) # Может быть больше из-за ALTER TABLE и индексов assert post_repository._execute_query.call_count >= 3 - + # Проверяем, что все нужные таблицы созданы (порядок может быть разным из-за ALTER TABLE) calls = post_repository._execute_query.call_args_list all_queries = [call[0][0] for call in calls] - + # Проверяем создание таблицы постов - post_table_queries = [q for q in all_queries if "CREATE TABLE IF NOT EXISTS post_from_telegram_suggest" in q] + post_table_queries = [ + q + for q in all_queries + if "CREATE TABLE IF NOT EXISTS post_from_telegram_suggest" in q + ] assert len(post_table_queries) > 0 assert "message_id INTEGER NOT NULL PRIMARY KEY" in post_table_queries[0] assert "created_at INTEGER NOT NULL" in post_table_queries[0] assert "status TEXT NOT NULL DEFAULT 'suggest'" in post_table_queries[0] assert "is_anonymous INTEGER" in post_table_queries[0] - assert "FOREIGN KEY (author_id) REFERENCES our_users (user_id) ON DELETE CASCADE" in post_table_queries[0] - + assert ( + "FOREIGN KEY (author_id) REFERENCES our_users (user_id) ON DELETE CASCADE" + in post_table_queries[0] + ) + # Проверяем создание таблицы контента - content_table_queries = [q for q in all_queries if "CREATE TABLE IF NOT EXISTS content_post_from_telegram" in q] + content_table_queries = [ + q + for q in all_queries + if "CREATE TABLE IF NOT EXISTS content_post_from_telegram" in q + ] assert len(content_table_queries) > 0 assert "PRIMARY KEY (message_id, content_name)" in content_table_queries[0] - + # Проверяем создание таблицы связей - link_table_queries = [q for q in all_queries if "CREATE TABLE IF NOT EXISTS message_link_to_content" in q] + link_table_queries = [ + q + for q in all_queries + if "CREATE TABLE IF NOT EXISTS message_link_to_content" in q + ] assert len(link_table_queries) > 0 assert "PRIMARY KEY (post_id, message_id)" in link_table_queries[0] - + @pytest.mark.asyncio async def test_add_post_with_date(self, post_repository, sample_post): """Тест добавления поста с датой.""" # Мокаем _execute_query post_repository._execute_query = AsyncMock() - + await post_repository.add_post(sample_post) - + post_repository._execute_query.assert_called_once() call_args = post_repository._execute_query.call_args query = call_args[0][0] params = call_args[0][1] - + assert "INSERT OR IGNORE INTO post_from_telegram_suggest" in query assert "status" in query assert "is_anonymous" in query @@ -120,63 +133,70 @@ class TestPostRepository: assert params[3] == sample_post.created_at assert params[4] == sample_post.status # is_anonymous преобразуется в int (None -> None, True -> 1, False -> 0) - expected_is_anonymous = None if sample_post.is_anonymous is None else (1 if sample_post.is_anonymous else 0) + expected_is_anonymous = ( + None + if sample_post.is_anonymous is None + else (1 if sample_post.is_anonymous else 0) + ) assert params[5] == expected_is_anonymous - + @pytest.mark.asyncio async def test_add_post_without_date(self, post_repository, sample_post_no_date): """Тест добавления поста без даты (должна генерироваться автоматически).""" # Мокаем _execute_query post_repository._execute_query = AsyncMock() - + await post_repository.add_post(sample_post_no_date) - + # Проверяем, что дата была установлена assert sample_post_no_date.created_at is not None assert isinstance(sample_post_no_date.created_at, int) assert sample_post_no_date.created_at > 0 - + post_repository._execute_query.assert_called_once() call_args = post_repository._execute_query.call_args params = call_args[0][1] - + assert params[3] == sample_post_no_date.created_at # created_at assert params[4] == sample_post_no_date.status # status (default suggest) # Проверяем is_anonymous (должен быть в параметрах) assert len(params) == 6 # Всего 6 параметров включая is_anonymous - + @pytest.mark.asyncio async def test_add_post_logs_correctly(self, post_repository, sample_post): """Тест логирования при добавлении поста.""" # Мокаем _execute_query и logger post_repository._execute_query = AsyncMock() post_repository.logger = MagicMock() - + await post_repository.add_post(sample_post) - + # Проверяем, что логирование вызвано с новым форматом сообщения post_repository.logger.info.assert_called_once() log_call = post_repository.logger.info.call_args[0][0] assert f"message_id={sample_post.message_id}" in log_call assert "Пост добавлен" in log_call or "уже существует" in log_call - + @pytest.mark.asyncio async def test_update_helper_message(self, post_repository): """Тест обновления helper сообщения.""" # Мокаем _execute_query post_repository._execute_query = AsyncMock() - + message_id = 12345 helper_message_id = 67890 - + await post_repository.update_helper_message(message_id, helper_message_id) - + post_repository._execute_query.assert_called_once() call_args = post_repository._execute_query.call_args query = call_args[0][0] params = call_args[0][1] - - assert "UPDATE post_from_telegram_suggest SET helper_text_message_id = ? WHERE message_id = ?" in query + + assert ( + "UPDATE post_from_telegram_suggest SET helper_text_message_id = ? WHERE message_id = ?" + in query + ) assert params == (helper_message_id, message_id) @pytest.mark.asyncio @@ -192,7 +212,7 @@ class TestPostRepository: mock_conn.execute = AsyncMock(return_value=mock_cur) post_repository._get_connection.return_value = mock_conn post_repository.logger = MagicMock() - + # Создаем таблицы await post_repository.create_tables() post_repository._execute_query.reset_mock() @@ -216,7 +236,10 @@ class TestPostRepository: # Проверяем, что после создания таблиц было вызвано логирование обновления статуса post_repository.logger.info.assert_called() log_calls = [str(call) for call in post_repository.logger.info.call_args_list] - assert any("Статус поста message_id=12345 обновлён на approved" in str(call) for call in post_repository.logger.info.call_args_list) + assert any( + "Статус поста message_id=12345 обновлён на approved" in str(call) + for call in post_repository.logger.info.call_args_list + ) @pytest.mark.asyncio async def test_update_status_for_media_group_by_helper_id(self, post_repository): @@ -231,7 +254,7 @@ class TestPostRepository: mock_conn.execute = AsyncMock(return_value=mock_cur) post_repository._get_connection.return_value = mock_conn post_repository.logger = MagicMock() - + # Создаем таблицы await post_repository.create_tables() post_repository._execute_query.reset_mock() @@ -257,7 +280,11 @@ class TestPostRepository: assert params == (status, helper_message_id, helper_message_id) # Проверяем, что после создания таблиц было вызвано логирование обновления статуса post_repository.logger.info.assert_called() - assert any("Статус медиагруппы helper_message_id=99999 обновлён на declined" in str(call) for call in post_repository.logger.info.call_args_list) + assert any( + "Статус медиагруппы helper_message_id=99999 обновлён на declined" + in str(call) + for call in post_repository.logger.info.call_args_list + ) @pytest.mark.asyncio async def test_add_post_content_success(self, post_repository): @@ -265,61 +292,67 @@ class TestPostRepository: # Мокаем _execute_query post_repository._execute_query = AsyncMock() post_repository.logger = MagicMock() - + post_id = 12345 message_id = 67890 content_name = "/path/to/file.jpg" content_type = "photo" - - result = await post_repository.add_post_content(post_id, message_id, content_name, content_type) - + + result = await post_repository.add_post_content( + post_id, message_id, content_name, content_type + ) + # Проверяем, что результат True assert result is True - + # Проверяем, что _execute_query вызвался 2 раза (для связи и контента) assert post_repository._execute_query.call_count == 2 - + # Проверяем вызов для связи link_call = post_repository._execute_query.call_args_list[0] link_query = link_call[0][0] link_params = link_call[0][1] assert "INSERT OR IGNORE INTO message_link_to_content" in link_query assert link_params == (post_id, message_id) - + # Проверяем вызов для контента content_call = post_repository._execute_query.call_args_list[1] content_query = content_call[0][0] content_params = content_call[0][1] assert "INSERT OR IGNORE INTO content_post_from_telegram" in content_query assert content_params == (message_id, content_name, content_type) - + # Проверяем логирование post_repository.logger.info.assert_called_once_with( f"Контент поста добавлен: post_id={post_id}, message_id={message_id}" ) - + @pytest.mark.asyncio async def test_add_post_content_exception(self, post_repository): """Тест обработки исключения при добавлении контента поста.""" # Мокаем _execute_query чтобы вызвать исключение - post_repository._execute_query = AsyncMock(side_effect=Exception("Database error")) + post_repository._execute_query = AsyncMock( + side_effect=Exception("Database error") + ) post_repository.logger = MagicMock() - + post_id = 12345 message_id = 67890 content_name = "/path/to/file.jpg" content_type = "photo" - - result = await post_repository.add_post_content(post_id, message_id, content_name, content_type) - + + result = await post_repository.add_post_content( + post_id, message_id, content_name, content_type + ) + # Проверяем, что результат False assert result is False - + # Проверяем логирование ошибки post_repository.logger.error.assert_called_once() error_call = post_repository.logger.error.call_args[0][0] assert "Ошибка при добавлении контента поста:" in error_call - + @pytest.mark.asyncio async def test_get_post_content_by_helper_id(self, post_repository): """Тест получения контента поста по helper ID.""" @@ -327,33 +360,33 @@ class TestPostRepository: mock_result = [ ("/path/to/photo1.jpg", "photo"), ("/path/to/video1.mp4", "video"), - ("/path/to/photo2.jpg", "photo") + ("/path/to/photo2.jpg", "photo"), ] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + helper_message_id = 67890 - + result = await post_repository.get_post_content_by_helper_id(helper_message_id) - + # Проверяем результат assert result == mock_result - + # Проверяем вызов _execute_query_with_result post_repository._execute_query_with_result.assert_called_once() call_args = post_repository._execute_query_with_result.call_args query = call_args[0][0] params = call_args[0][1] - + assert "SELECT cpft.content_name, cpft.content_type" in query assert "WHERE pft.helper_text_message_id = ?" in query assert params == (helper_message_id,) - + # Проверяем логирование post_repository.logger.info.assert_called_once_with( f"Получен контент поста: {len(mock_result)} элементов" ) - + @pytest.mark.asyncio async def test_get_post_text_by_helper_id_found(self, post_repository): """Тест получения текста поста по helper ID (пост найден).""" @@ -361,28 +394,31 @@ class TestPostRepository: mock_result = [("Тестовый текст поста",)] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + helper_message_id = 67890 - + result = await post_repository.get_post_text_by_helper_id(helper_message_id) - + # Проверяем результат assert result == "Тестовый текст поста" - + # Проверяем вызов _execute_query_with_result post_repository._execute_query_with_result.assert_called_once() call_args = post_repository._execute_query_with_result.call_args query = call_args[0][0] params = call_args[0][1] - - assert "SELECT text FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" in query + + assert ( + "SELECT text FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" + in query + ) assert params == (helper_message_id,) - + # Проверяем логирование post_repository.logger.info.assert_called_once_with( f"Получен текст поста для helper_message_id={helper_message_id}" ) - + @pytest.mark.asyncio async def test_get_post_text_by_helper_id_not_found(self, post_repository): """Тест получения текста поста по helper ID (пост не найден).""" @@ -390,17 +426,17 @@ class TestPostRepository: mock_result = [] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + helper_message_id = 67890 - + result = await post_repository.get_post_text_by_helper_id(helper_message_id) - + # Проверяем результат assert result is None - + # Проверяем, что logger.info не вызывался post_repository.logger.info.assert_not_called() - + @pytest.mark.asyncio async def test_get_post_ids_by_helper_id(self, post_repository): """Тест получения ID сообщений по helper ID.""" @@ -408,29 +444,29 @@ class TestPostRepository: mock_result = [(12345,), (67890,), (11111,)] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + helper_message_id = 67890 - + result = await post_repository.get_post_ids_by_helper_id(helper_message_id) - + # Проверяем результат assert result == [12345, 67890, 11111] - + # Проверяем вызов _execute_query_with_result post_repository._execute_query_with_result.assert_called_once() call_args = post_repository._execute_query_with_result.call_args query = call_args[0][0] params = call_args[0][1] - + assert "SELECT mltc.message_id" in query assert "WHERE pft.helper_text_message_id = ?" in query assert params == (helper_message_id,) - + # Проверяем логирование post_repository.logger.info.assert_called_once_with( f"Получены ID сообщений: {len(mock_result)} элементов" ) - + @pytest.mark.asyncio async def test_get_author_id_by_message_id_found(self, post_repository): """Тест получения ID автора по message ID (автор найден).""" @@ -438,28 +474,31 @@ class TestPostRepository: mock_result = [(67890,)] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + message_id = 12345 - + result = await post_repository.get_author_id_by_message_id(message_id) - + # Проверяем результат assert result == 67890 - + # Проверяем вызов _execute_query_with_result post_repository._execute_query_with_result.assert_called_once() call_args = post_repository._execute_query_with_result.call_args query = call_args[0][0] params = call_args[0][1] - - assert "SELECT author_id FROM post_from_telegram_suggest WHERE message_id = ?" in query + + assert ( + "SELECT author_id FROM post_from_telegram_suggest WHERE message_id = ?" + in query + ) assert params == (message_id,) - + # Проверяем логирование post_repository.logger.info.assert_called_once_with( f"Получен author_id: {67890} для message_id={message_id}" ) - + @pytest.mark.asyncio async def test_get_author_id_by_message_id_not_found(self, post_repository): """Тест получения ID автора по message ID (автор не найден).""" @@ -467,17 +506,17 @@ class TestPostRepository: mock_result = [] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + message_id = 12345 - + result = await post_repository.get_author_id_by_message_id(message_id) - + # Проверяем результат assert result is None - + # Проверяем, что logger.info не вызывался post_repository.logger.info.assert_not_called() - + @pytest.mark.asyncio async def test_get_author_id_by_helper_message_id_found(self, post_repository): """Тест получения ID автора по helper message ID (автор найден).""" @@ -485,28 +524,33 @@ class TestPostRepository: mock_result = [(67890,)] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + helper_message_id = 12345 - - result = await post_repository.get_author_id_by_helper_message_id(helper_message_id) - + + result = await post_repository.get_author_id_by_helper_message_id( + helper_message_id + ) + # Проверяем результат assert result == 67890 - + # Проверяем вызов _execute_query_with_result post_repository._execute_query_with_result.assert_called_once() call_args = post_repository._execute_query_with_result.call_args query = call_args[0][0] params = call_args[0][1] - - assert "SELECT author_id FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" in query + + assert ( + "SELECT author_id FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" + in query + ) assert params == (helper_message_id,) - + # Проверяем логирование post_repository.logger.info.assert_called_once_with( f"Получен author_id: {67890} для helper_message_id={helper_message_id}" ) - + @pytest.mark.asyncio async def test_get_author_id_by_helper_message_id_not_found(self, post_repository): """Тест получения ID автора по helper message ID (автор не найден).""" @@ -514,117 +558,145 @@ class TestPostRepository: mock_result = [] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + helper_message_id = 12345 - - result = await post_repository.get_author_id_by_helper_message_id(helper_message_id) - + + result = await post_repository.get_author_id_by_helper_message_id( + helper_message_id + ) + # Проверяем результат assert result is None - + # Проверяем, что logger.info не вызывался post_repository.logger.info.assert_not_called() - + @pytest.mark.asyncio - async def test_get_post_text_and_anonymity_by_message_id_found(self, post_repository): + async def test_get_post_text_and_anonymity_by_message_id_found( + self, post_repository + ): """Тест получения текста и is_anonymous по message_id (пост найден).""" # Мокаем _execute_query_with_result mock_result = [("Тестовый текст", 1)] # is_anonymous = 1 (True) post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + message_id = 12345 - - result = await post_repository.get_post_text_and_anonymity_by_message_id(message_id) - + + result = await post_repository.get_post_text_and_anonymity_by_message_id( + message_id + ) + # Проверяем результат text, is_anonymous = result assert text == "Тестовый текст" assert is_anonymous is True - + # Проверяем вызов _execute_query_with_result post_repository._execute_query_with_result.assert_called_once() call_args = post_repository._execute_query_with_result.call_args query = call_args[0][0] params = call_args[0][1] - - assert "SELECT text, is_anonymous FROM post_from_telegram_suggest WHERE message_id = ?" in query + + assert ( + "SELECT text, is_anonymous FROM post_from_telegram_suggest WHERE message_id = ?" + in query + ) assert params == (message_id,) - + @pytest.mark.asyncio - async def test_get_post_text_and_anonymity_by_message_id_with_false(self, post_repository): + async def test_get_post_text_and_anonymity_by_message_id_with_false( + self, post_repository + ): """Тест получения текста и is_anonymous по message_id (is_anonymous = False).""" # Мокаем _execute_query_with_result mock_result = [("Тестовый текст", 0)] # is_anonymous = 0 (False) post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) - + message_id = 12345 - - result = await post_repository.get_post_text_and_anonymity_by_message_id(message_id) - + + result = await post_repository.get_post_text_and_anonymity_by_message_id( + message_id + ) + # Проверяем результат text, is_anonymous = result assert text == "Тестовый текст" assert is_anonymous is False - + @pytest.mark.asyncio - async def test_get_post_text_and_anonymity_by_message_id_with_null(self, post_repository): + async def test_get_post_text_and_anonymity_by_message_id_with_null( + self, post_repository + ): """Тест получения текста и is_anonymous по message_id (is_anonymous = NULL).""" # Мокаем _execute_query_with_result mock_result = [("Тестовый текст", None)] # is_anonymous = NULL post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) - + message_id = 12345 - - result = await post_repository.get_post_text_and_anonymity_by_message_id(message_id) - + + result = await post_repository.get_post_text_and_anonymity_by_message_id( + message_id + ) + # Проверяем результат text, is_anonymous = result assert text == "Тестовый текст" assert is_anonymous is None - + @pytest.mark.asyncio - async def test_get_post_text_and_anonymity_by_message_id_not_found(self, post_repository): + async def test_get_post_text_and_anonymity_by_message_id_not_found( + self, post_repository + ): """Тест получения текста и is_anonymous по message_id (пост не найден).""" # Мокаем _execute_query_with_result mock_result = [] post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) - + message_id = 12345 - - result = await post_repository.get_post_text_and_anonymity_by_message_id(message_id) - + + result = await post_repository.get_post_text_and_anonymity_by_message_id( + message_id + ) + # Проверяем результат text, is_anonymous = result assert text is None assert is_anonymous is None - + @pytest.mark.asyncio - async def test_get_post_text_and_anonymity_by_helper_id_found(self, post_repository): + async def test_get_post_text_and_anonymity_by_helper_id_found( + self, post_repository + ): """Тест получения текста и is_anonymous по helper_message_id (пост найден).""" # Мокаем _execute_query_with_result mock_result = [("Тестовый текст", 1)] # is_anonymous = 1 (True) post_repository._execute_query_with_result = AsyncMock(return_value=mock_result) post_repository.logger = MagicMock() - + helper_message_id = 67890 - - result = await post_repository.get_post_text_and_anonymity_by_helper_id(helper_message_id) - + + result = await post_repository.get_post_text_and_anonymity_by_helper_id( + helper_message_id + ) + # Проверяем результат text, is_anonymous = result assert text == "Тестовый текст" assert is_anonymous is True - + # Проверяем вызов _execute_query_with_result post_repository._execute_query_with_result.assert_called_once() call_args = post_repository._execute_query_with_result.call_args query = call_args[0][0] params = call_args[0][1] - - assert "SELECT text, is_anonymous FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" in query + + assert ( + "SELECT text, is_anonymous FROM post_from_telegram_suggest WHERE helper_text_message_id = ?" + in query + ) assert params == (helper_message_id,) - + @pytest.mark.asyncio async def test_add_post_with_is_anonymous_true(self, post_repository): """Тест добавления поста с is_anonymous=True.""" @@ -633,19 +705,19 @@ class TestPostRepository: text="Тестовый пост анон", author_id=67890, created_at=int(datetime.now().timestamp()), - is_anonymous=True + is_anonymous=True, ) - + post_repository._execute_query = AsyncMock() - + await post_repository.add_post(post) - + call_args = post_repository._execute_query.call_args params = call_args[0][1] - + # Проверяем, что is_anonymous преобразован в 1 assert params[5] == 1 - + @pytest.mark.asyncio async def test_add_post_with_is_anonymous_false(self, post_repository): """Тест добавления поста с is_anonymous=False.""" @@ -654,19 +726,19 @@ class TestPostRepository: text="Тестовый пост неанон", author_id=67890, created_at=int(datetime.now().timestamp()), - is_anonymous=False + is_anonymous=False, ) - + post_repository._execute_query = AsyncMock() - + await post_repository.add_post(post) - + call_args = post_repository._execute_query.call_args params = call_args[0][1] - + # Проверяем, что is_anonymous преобразован в 0 assert params[5] == 0 - + @pytest.mark.asyncio async def test_add_post_with_is_anonymous_none(self, post_repository): """Тест добавления поста с is_anonymous=None.""" @@ -675,28 +747,30 @@ class TestPostRepository: text="Тестовый пост", author_id=67890, created_at=int(datetime.now().timestamp()), - is_anonymous=None + is_anonymous=None, ) - + post_repository._execute_query = AsyncMock() - + await post_repository.add_post(post) - + call_args = post_repository._execute_query.call_args params = call_args[0][1] - + # Проверяем, что is_anonymous остался None assert params[5] is None - + @pytest.mark.asyncio async def test_create_tables_logs_success(self, post_repository): """Тест логирования успешного создания таблиц.""" # Мокаем _execute_query, _execute_query_with_result и logger post_repository._execute_query = AsyncMock() - post_repository._execute_query_with_result = AsyncMock(return_value=[]) # Для проверки столбца + post_repository._execute_query_with_result = AsyncMock( + return_value=[] + ) # Для проверки столбца post_repository.logger = MagicMock() - + await post_repository.create_tables() - + # Проверяем, что финальное сообщение о создании таблиц было вызвано post_repository.logger.info.assert_any_call("Таблицы для постов созданы") diff --git a/tests/test_post_repository_integration.py b/tests/test_post_repository_integration.py index c6c21b6..d485fec 100644 --- a/tests/test_post_repository_integration.py +++ b/tests/test_post_repository_integration.py @@ -4,17 +4,18 @@ import tempfile from datetime import datetime import pytest + from database.models import MessageContentLink, PostContent, TelegramPost from database.repositories.post_repository import PostRepository class TestPostRepositoryIntegration: """Интеграционные тесты для PostRepository с реальной БД.""" - + async def _setup_test_database(self, post_repository): """Вспомогательная функция для настройки тестовой БД.""" # Сначала создаем таблицу our_users для тестов - await post_repository._execute_query(''' + await post_repository._execute_query(""" CREATE TABLE IF NOT EXISTS our_users ( user_id INTEGER NOT NULL PRIMARY KEY, first_name TEXT, @@ -28,40 +29,52 @@ class TestPostRepositoryIntegration: date_changed INTEGER NOT NULL, voice_bot_welcome_received BOOLEAN DEFAULT 0 ) - ''') - + """) + # Добавляем тестовых пользователей await post_repository._execute_query( "INSERT OR REPLACE INTO our_users (user_id, first_name, full_name, date_added, date_changed) VALUES (?, ?, ?, ?, ?)", - (67890, "Test", "Test User", int(datetime.now().timestamp()), int(datetime.now().timestamp())) + ( + 67890, + "Test", + "Test User", + int(datetime.now().timestamp()), + int(datetime.now().timestamp()), + ), ) await post_repository._execute_query( "INSERT OR REPLACE INTO our_users (user_id, first_name, full_name, date_added, date_changed) VALUES (?, ?, ?, ?, ?)", - (11111, "Test2", "Test User 2", int(datetime.now().timestamp()), int(datetime.now().timestamp())) + ( + 11111, + "Test2", + "Test User 2", + int(datetime.now().timestamp()), + int(datetime.now().timestamp()), + ), ) - + # Теперь создаем таблицы для постов await post_repository.create_tables() - + @pytest.fixture def temp_db_path(self): """Фикстура для временного файла БД.""" - with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file: db_path = tmp_file.name - + yield db_path - + # Очищаем временный файл после тестов try: os.unlink(db_path) except OSError: pass - + @pytest.fixture def post_repository(self, temp_db_path): """Фикстура для PostRepository с реальной БД.""" return PostRepository(temp_db_path) - + @pytest.fixture def sample_post(self): """Фикстура для тестового поста.""" @@ -70,9 +83,9 @@ class TestPostRepositoryIntegration: text="Тестовый пост для интеграционных тестов", author_id=67890, helper_text_message_id=None, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) - + @pytest.fixture def sample_post_2(self): """Фикстура для второго тестового поста.""" @@ -81,9 +94,9 @@ class TestPostRepositoryIntegration: text="Второй тестовый пост", author_id=67890, helper_text_message_id=None, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) - + @pytest.fixture def sample_post_with_helper(self): """Фикстура для тестового поста с helper сообщением.""" @@ -92,418 +105,485 @@ class TestPostRepositoryIntegration: text="Пост с helper сообщением", author_id=67890, helper_text_message_id=None, # Будет установлен позже - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) - + @pytest.mark.asyncio async def test_create_tables_integration(self, post_repository): """Интеграционный тест создания таблиц.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Проверяем, что таблицы созданы (попробуем вставить тестовые данные) test_post = TelegramPost( message_id=99999, text="Тест создания таблиц", author_id=67890, # Используем существующего пользователя - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) - + # Если таблицы созданы, то insert должен пройти успешно await post_repository.add_post(test_post) - + # Проверяем, что пост действительно добавлен author_id = await post_repository.get_author_id_by_message_id(99999) assert author_id == 67890 - + @pytest.mark.asyncio async def test_add_post_integration(self, post_repository, sample_post): """Интеграционный тест добавления поста.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post) - + # Проверяем, что пост добавлен - author_id = await post_repository.get_author_id_by_message_id(sample_post.message_id) + author_id = await post_repository.get_author_id_by_message_id( + sample_post.message_id + ) assert author_id == sample_post.author_id - + @pytest.mark.asyncio async def test_add_post_without_date_integration(self, post_repository): """Интеграционный тест добавления поста без даты.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + post_without_date = TelegramPost( message_id=12348, text="Пост без даты", author_id=67890, helper_text_message_id=None, - created_at=None + created_at=None, ) - + # Добавляем пост await post_repository.add_post(post_without_date) - + # Проверяем, что дата была установлена автоматически assert post_without_date.created_at is not None assert isinstance(post_without_date.created_at, int) assert post_without_date.created_at > 0 - + # Проверяем, что пост добавлен - author_id = await post_repository.get_author_id_by_message_id(post_without_date.message_id) + author_id = await post_repository.get_author_id_by_message_id( + post_without_date.message_id + ) assert author_id == post_without_date.author_id - + @pytest.mark.asyncio - async def test_update_helper_message_integration(self, post_repository, sample_post): + async def test_update_helper_message_integration( + self, post_repository, sample_post + ): """Интеграционный тест обновления helper сообщения.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post) - + # Обновляем helper сообщение helper_message_id = 88888 - await post_repository.update_helper_message(sample_post.message_id, helper_message_id) - + await post_repository.update_helper_message( + sample_post.message_id, helper_message_id + ) + # Проверяем, что helper сообщение обновлено # Для этого нужно получить пост и проверить helper_text_message_id # Но у нас нет метода для получения поста по ID, поэтому проверяем косвенно # через get_author_id_by_helper_message_id - author_id = await post_repository.get_author_id_by_helper_message_id(helper_message_id) + author_id = await post_repository.get_author_id_by_helper_message_id( + helper_message_id + ) assert author_id == sample_post.author_id - + @pytest.mark.asyncio async def test_add_post_content_integration(self, post_repository, sample_post): """Интеграционный тест добавления контента поста.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post) - + # Добавляем контент message_id = 11111 content_name = "/path/to/test/photo.jpg" content_type = "photo" - + # Сначала нужно добавить сообщение с этим message_id в post_from_telegram_suggest # или использовать существующий message_id content_post = TelegramPost( message_id=message_id, text="Сообщение с контентом", author_id=11111, # Используем существующего пользователя - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(content_post) - + result = await post_repository.add_post_content( sample_post.message_id, message_id, content_name, content_type ) - + # Проверяем, что контент добавлен успешно assert result is True - + # Проверяем, что контент действительно добавлен - post_content = await post_repository.get_post_content_by_helper_id(sample_post.message_id) + post_content = await post_repository.get_post_content_by_helper_id( + sample_post.message_id + ) # Поскольку у нас нет helper_message_id, контент не будет найден # Это нормальное поведение для данного теста assert isinstance(post_content, list) - + @pytest.mark.asyncio - async def test_add_post_content_with_helper_message_integration(self, post_repository, sample_post_with_helper): + async def test_add_post_content_with_helper_message_integration( + self, post_repository, sample_post_with_helper + ): """Интеграционный тест добавления контента поста с helper сообщением.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post_with_helper) - + # Создаем helper сообщение helper_message_id = 99999 helper_post = TelegramPost( message_id=helper_message_id, text="Helper сообщение", author_id=67890, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(helper_post) - + # Обновляем пост, чтобы он ссылался на helper сообщение - await post_repository.update_helper_message(sample_post_with_helper.message_id, helper_message_id) - + await post_repository.update_helper_message( + sample_post_with_helper.message_id, helper_message_id + ) + # Добавляем контент message_id = 22222 content_name = "/path/to/test/video.mp4" content_type = "video" - + # Сначала нужно добавить сообщение с этим message_id в post_from_telegram_suggest content_post = TelegramPost( message_id=message_id, text="Сообщение с видео контентом", author_id=11111, # Используем существующего пользователя - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(content_post) - + result = await post_repository.add_post_content( sample_post_with_helper.message_id, message_id, content_name, content_type ) - + # Проверяем, что контент добавлен успешно assert result is True - + # Проверяем, что контент действительно добавлен - post_content = await post_repository.get_post_content_by_helper_id(helper_message_id) + post_content = await post_repository.get_post_content_by_helper_id( + helper_message_id + ) assert len(post_content) == 1 assert post_content[0][0] == content_name assert post_content[0][1] == content_type - + @pytest.mark.asyncio - async def test_get_post_text_by_helper_id_integration(self, post_repository, sample_post_with_helper): + async def test_get_post_text_by_helper_id_integration( + self, post_repository, sample_post_with_helper + ): """Интеграционный тест получения текста поста по helper ID.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post_with_helper) - + # Создаем helper сообщение helper_message_id = 99999 helper_post = TelegramPost( message_id=helper_message_id, text="Helper сообщение", author_id=67890, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(helper_post) - + # Обновляем пост, чтобы он ссылался на helper сообщение - await post_repository.update_helper_message(sample_post_with_helper.message_id, helper_message_id) - + await post_repository.update_helper_message( + sample_post_with_helper.message_id, helper_message_id + ) + # Получаем текст поста post_text = await post_repository.get_post_text_by_helper_id(helper_message_id) - + # Проверяем результат assert post_text == sample_post_with_helper.text - + @pytest.mark.asyncio - async def test_get_post_text_by_helper_id_not_found_integration(self, post_repository): + async def test_get_post_text_by_helper_id_not_found_integration( + self, post_repository + ): """Интеграционный тест получения текста поста по несуществующему helper ID.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Пытаемся получить текст поста по несуществующему helper ID post_text = await post_repository.get_post_text_by_helper_id(99999) - + # Проверяем, что результат None assert post_text is None - + @pytest.mark.asyncio - async def test_get_post_ids_by_helper_id_integration(self, post_repository, sample_post_with_helper): + async def test_get_post_ids_by_helper_id_integration( + self, post_repository, sample_post_with_helper + ): """Интеграционный тест получения ID сообщений по helper ID.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post_with_helper) - + # Создаем helper сообщение helper_message_id = 99999 helper_post = TelegramPost( message_id=helper_message_id, text="Helper сообщение", author_id=67890, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(helper_post) - + # Обновляем пост, чтобы он ссылался на helper сообщение - await post_repository.update_helper_message(sample_post_with_helper.message_id, helper_message_id) - + await post_repository.update_helper_message( + sample_post_with_helper.message_id, helper_message_id + ) + # Добавляем несколько сообщений с контентом message_ids = [33333, 44444, 55555] - content_names = ["/path/to/photo1.jpg", "/path/to/photo2.jpg", "/path/to/video.mp4"] + content_names = [ + "/path/to/photo1.jpg", + "/path/to/photo2.jpg", + "/path/to/video.mp4", + ] content_types = ["photo", "photo", "video"] - - for i, (msg_id, content_name, content_type) in enumerate(zip(message_ids, content_names, content_types)): + + for i, (msg_id, content_name, content_type) in enumerate( + zip(message_ids, content_names, content_types) + ): # Сначала нужно добавить сообщение с этим message_id в post_from_telegram_suggest content_post = TelegramPost( message_id=msg_id, text=f"Сообщение с контентом {i+1}", author_id=11111, # Используем существующего пользователя - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(content_post) - + result = await post_repository.add_post_content( sample_post_with_helper.message_id, msg_id, content_name, content_type ) assert result is True - + # Получаем ID сообщений post_ids = await post_repository.get_post_ids_by_helper_id(helper_message_id) - + # Проверяем результат assert len(post_ids) == 3 for msg_id in message_ids: assert msg_id in post_ids - + @pytest.mark.asyncio - async def test_get_author_id_by_message_id_integration(self, post_repository, sample_post): + async def test_get_author_id_by_message_id_integration( + self, post_repository, sample_post + ): """Интеграционный тест получения ID автора по message ID.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post) - + # Получаем ID автора - author_id = await post_repository.get_author_id_by_message_id(sample_post.message_id) - + author_id = await post_repository.get_author_id_by_message_id( + sample_post.message_id + ) + # Проверяем результат assert author_id == sample_post.author_id - + @pytest.mark.asyncio - async def test_get_author_id_by_message_id_not_found_integration(self, post_repository): + async def test_get_author_id_by_message_id_not_found_integration( + self, post_repository + ): """Интеграционный тест получения ID автора по несуществующему message ID.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Пытаемся получить ID автора по несуществующему message ID author_id = await post_repository.get_author_id_by_message_id(99999) - + # Проверяем, что результат None assert author_id is None - + @pytest.mark.asyncio - async def test_get_author_id_by_helper_message_id_integration(self, post_repository, sample_post_with_helper): + async def test_get_author_id_by_helper_message_id_integration( + self, post_repository, sample_post_with_helper + ): """Интеграционный тест получения ID автора по helper message ID.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post_with_helper) - + # Создаем helper сообщение helper_message_id = 99999 helper_post = TelegramPost( message_id=helper_message_id, text="Helper сообщение", author_id=67890, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(helper_post) - + # Обновляем пост, чтобы он ссылался на helper сообщение - await post_repository.update_helper_message(sample_post_with_helper.message_id, helper_message_id) - + await post_repository.update_helper_message( + sample_post_with_helper.message_id, helper_message_id + ) + # Получаем ID автора - author_id = await post_repository.get_author_id_by_helper_message_id(helper_message_id) - + author_id = await post_repository.get_author_id_by_helper_message_id( + helper_message_id + ) + # Проверяем результат assert author_id == sample_post_with_helper.author_id - + @pytest.mark.asyncio - async def test_get_author_id_by_helper_message_id_not_found_integration(self, post_repository): + async def test_get_author_id_by_helper_message_id_not_found_integration( + self, post_repository + ): """Интеграционный тест получения ID автора по несуществующему helper message ID.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Пытаемся получить ID автора по несуществующему helper message ID author_id = await post_repository.get_author_id_by_helper_message_id(99999) - + # Проверяем, что результат None assert author_id is None - + @pytest.mark.asyncio - async def test_multiple_posts_integration(self, post_repository, sample_post, sample_post_2): + async def test_multiple_posts_integration( + self, post_repository, sample_post, sample_post_2 + ): """Интеграционный тест работы с несколькими постами.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем несколько постов await post_repository.add_post(sample_post) await post_repository.add_post(sample_post_2) - + # Проверяем, что оба поста добавлены - author_id_1 = await post_repository.get_author_id_by_message_id(sample_post.message_id) - author_id_2 = await post_repository.get_author_id_by_message_id(sample_post_2.message_id) - + author_id_1 = await post_repository.get_author_id_by_message_id( + sample_post.message_id + ) + author_id_2 = await post_repository.get_author_id_by_message_id( + sample_post_2.message_id + ) + assert author_id_1 == sample_post.author_id assert author_id_2 == sample_post_2.author_id - + # Проверяем, что посты имеют разные ID assert sample_post.message_id != sample_post_2.message_id assert sample_post.text != sample_post_2.text - + @pytest.mark.asyncio - async def test_post_content_relationships_integration(self, post_repository, sample_post_with_helper): + async def test_post_content_relationships_integration( + self, post_repository, sample_post_with_helper + ): """Интеграционный тест связей между постами и контентом.""" # Настраиваем тестовую БД await self._setup_test_database(post_repository) - + # Добавляем пост await post_repository.add_post(sample_post_with_helper) - + # Создаем helper сообщение helper_message_id = 99999 helper_post = TelegramPost( message_id=helper_message_id, text="Helper сообщение", author_id=67890, - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(helper_post) - + # Обновляем пост, чтобы он ссылался на helper сообщение - await post_repository.update_helper_message(sample_post_with_helper.message_id, helper_message_id) - + await post_repository.update_helper_message( + sample_post_with_helper.message_id, helper_message_id + ) + # Добавляем контент разных типов content_data = [ (11111, "/path/to/photo1.jpg", "photo"), (22222, "/path/to/video1.mp4", "video"), (33333, "/path/to/audio1.mp3", "audio"), - (44444, "/path/to/photo2.jpg", "photo") + (44444, "/path/to/photo2.jpg", "photo"), ] - + for message_id, content_name, content_type in content_data: # Сначала нужно добавить сообщение с этим message_id в post_from_telegram_suggest content_post = TelegramPost( message_id=message_id, text=f"Сообщение с контентом {content_type}", author_id=11111, # Используем существующего пользователя - created_at=int(datetime.now().timestamp()) + created_at=int(datetime.now().timestamp()), ) await post_repository.add_post(content_post) - + result = await post_repository.add_post_content( - sample_post_with_helper.message_id, message_id, content_name, content_type + sample_post_with_helper.message_id, + message_id, + content_name, + content_type, ) assert result is True - + # Проверяем, что весь контент добавлен - post_content = await post_repository.get_post_content_by_helper_id(helper_message_id) + post_content = await post_repository.get_post_content_by_helper_id( + helper_message_id + ) assert len(post_content) == 4 - + # Проверяем, что ID сообщений получены правильно post_ids = await post_repository.get_post_ids_by_helper_id(helper_message_id) assert len(post_ids) == 4 - + # Проверяем, что все ожидаемые ID присутствуют expected_message_ids = [11111, 22222, 33333, 44444] for expected_id in expected_message_ids: assert expected_id in post_ids @pytest.mark.asyncio - async def test_update_status_by_message_id_integration(self, post_repository, sample_post): + async def test_update_status_by_message_id_integration( + self, post_repository, sample_post + ): """Интеграционный тест обновления статуса одиночного поста.""" await self._setup_test_database(post_repository) await post_repository.add_post(sample_post) - await post_repository.update_status_by_message_id(sample_post.message_id, "approved") + await post_repository.update_status_by_message_id( + sample_post.message_id, "approved" + ) rows = await post_repository._execute_query_with_result( "SELECT status FROM post_from_telegram_suggest WHERE message_id = ?", diff --git a/tests/test_post_service.py b/tests/test_post_service.py index 03bcba3..ffabd2b 100644 --- a/tests/test_post_service.py +++ b/tests/test_post_service.py @@ -5,13 +5,14 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from aiogram import types + from database.models import TelegramPost, User from helper_bot.handlers.private.services import BotSettings, PostService class TestPostService: """Test class for PostService""" - + @pytest.fixture def mock_db(self): """Mock database""" @@ -21,7 +22,7 @@ class TestPostService: db.get_user_by_id = AsyncMock() db.add_message_link = AsyncMock() return db - + @pytest.fixture def mock_settings(self): """Mock bot settings""" @@ -33,14 +34,14 @@ class TestPostService: important_logs="test_important", preview_link="test_link", logs="test_logs_setting", - test="test_test_setting" + test="test_test_setting", ) - + @pytest.fixture def post_service(self, mock_db, mock_settings): """Create PostService instance""" return PostService(mock_db, mock_settings) - + @pytest.fixture def mock_message(self): """Mock Telegram message""" @@ -57,243 +58,451 @@ class TestPostService: message.chat = Mock() message.chat.id = 12345 return message - + @pytest.mark.asyncio - async def test_handle_text_post_saves_raw_text(self, post_service, mock_message, mock_db): + async def test_handle_text_post_saves_raw_text( + self, post_service, mock_message, mock_db + ): """Test that handle_text_post saves raw text to database""" mock_sent_message = Mock() mock_sent_message.message_id = 200 - - with patch('helper_bot.handlers.private.services.get_text_message', return_value="Formatted text"): - with patch('helper_bot.handlers.private.services.send_text_message', return_value=mock_sent_message): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=False): - + + with patch( + "helper_bot.handlers.private.services.get_text_message", + return_value="Formatted text", + ): + with patch( + "helper_bot.handlers.private.services.send_text_message", + return_value=mock_sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=False, + ): + await post_service.handle_text_post(mock_message, "Test") - + # Check that add_post was called mock_db.add_post.assert_called_once() call_args = mock_db.add_post.call_args[0][0] - + # Check that raw text is saved assert isinstance(call_args, TelegramPost) assert call_args.text == "Тестовый пост" # Raw text assert call_args.message_id == 200 assert call_args.author_id == 12345 assert call_args.is_anonymous is False - + @pytest.mark.asyncio - async def test_handle_text_post_determines_anonymity(self, post_service, mock_message, mock_db): + async def test_handle_text_post_determines_anonymity( + self, post_service, mock_message, mock_db + ): """Test that handle_text_post determines anonymity correctly""" mock_message.text = "Тестовый пост анон" mock_sent_message = Mock() mock_sent_message.message_id = 200 - - with patch('helper_bot.handlers.private.services.get_text_message', return_value="Formatted text"): - with patch('helper_bot.handlers.private.services.send_text_message', return_value=mock_sent_message): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=True): - + + with patch( + "helper_bot.handlers.private.services.get_text_message", + return_value="Formatted text", + ): + with patch( + "helper_bot.handlers.private.services.send_text_message", + return_value=mock_sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=True, + ): + await post_service.handle_text_post(mock_message, "Test") - + call_args = mock_db.add_post.call_args[0][0] assert call_args.is_anonymous is True - + @pytest.mark.asyncio - async def test_handle_photo_post_saves_raw_caption(self, post_service, mock_message, mock_db): + async def test_handle_photo_post_saves_raw_caption( + self, post_service, mock_message, mock_db + ): """Test that handle_photo_post saves raw caption to database""" mock_message.caption = "Тестовая подпись" mock_message.photo = [Mock()] mock_message.photo[-1].file_id = "photo_123" - + sent_message = Mock() sent_message.message_id = 201 sent_message.caption = "Formatted caption" - - with patch('helper_bot.handlers.private.services.get_text_message', return_value="Formatted caption"): - with patch('helper_bot.handlers.private.services.send_photo_message', return_value=sent_message): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=False): - with patch('helper_bot.handlers.private.services.add_in_db_media', return_value=True): - - await post_service.handle_photo_post(mock_message, "Test") - + + with patch( + "helper_bot.handlers.private.services.get_text_message", + return_value="Formatted caption", + ): + with patch( + "helper_bot.handlers.private.services.send_photo_message", + return_value=sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=False, + ): + with patch( + "helper_bot.handlers.private.services.add_in_db_media", + return_value=True, + ): + + await post_service.handle_photo_post( + mock_message, "Test" + ) + mock_db.add_post.assert_called_once() call_args = mock_db.add_post.call_args[0][0] - + # Check that raw caption is saved - assert call_args.text == "Тестовая подпись" # Raw caption + assert ( + call_args.text == "Тестовая подпись" + ) # Raw caption assert call_args.message_id == 201 assert call_args.is_anonymous is False - + @pytest.mark.asyncio - async def test_handle_photo_post_without_caption(self, post_service, mock_message, mock_db): + async def test_handle_photo_post_without_caption( + self, post_service, mock_message, mock_db + ): """Test that handle_photo_post handles missing caption""" mock_message.caption = None mock_message.photo = [Mock()] mock_message.photo[-1].file_id = "photo_123" - + sent_message = Mock() sent_message.message_id = 202 - - with patch('helper_bot.handlers.private.services.get_text_message', return_value=""): - with patch('helper_bot.handlers.private.services.send_photo_message', return_value=sent_message): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=False): - with patch('helper_bot.handlers.private.services.add_in_db_media', return_value=True): - - await post_service.handle_photo_post(mock_message, "Test") - + + with patch( + "helper_bot.handlers.private.services.get_text_message", return_value="" + ): + with patch( + "helper_bot.handlers.private.services.send_photo_message", + return_value=sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=False, + ): + with patch( + "helper_bot.handlers.private.services.add_in_db_media", + return_value=True, + ): + + await post_service.handle_photo_post( + mock_message, "Test" + ) + call_args = mock_db.add_post.call_args[0][0] - assert call_args.text == "" # Empty string for missing caption + assert ( + call_args.text == "" + ) # Empty string for missing caption assert call_args.is_anonymous is False - + @pytest.mark.asyncio - async def test_handle_video_post_saves_raw_caption(self, post_service, mock_message, mock_db): + async def test_handle_video_post_saves_raw_caption( + self, post_service, mock_message, mock_db + ): """Test that handle_video_post saves raw caption to database""" mock_message.caption = "Видео подпись" mock_message.video = Mock() mock_message.video.file_id = "video_123" - + sent_message = Mock() sent_message.message_id = 203 - - with patch('helper_bot.handlers.private.services.get_text_message', return_value="Formatted"): - with patch('helper_bot.handlers.private.services.send_video_message', return_value=sent_message): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=True): - with patch('helper_bot.handlers.private.services.add_in_db_media', return_value=True): - - await post_service.handle_video_post(mock_message, "Test") - + + with patch( + "helper_bot.handlers.private.services.get_text_message", + return_value="Formatted", + ): + with patch( + "helper_bot.handlers.private.services.send_video_message", + return_value=sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=True, + ): + with patch( + "helper_bot.handlers.private.services.add_in_db_media", + return_value=True, + ): + + await post_service.handle_video_post( + mock_message, "Test" + ) + call_args = mock_db.add_post.call_args[0][0] assert call_args.text == "Видео подпись" # Raw caption assert call_args.is_anonymous is True - + @pytest.mark.asyncio - async def test_handle_audio_post_saves_raw_caption(self, post_service, mock_message, mock_db): + async def test_handle_audio_post_saves_raw_caption( + self, post_service, mock_message, mock_db + ): """Test that handle_audio_post saves raw caption to database""" mock_message.caption = "Аудио подпись" mock_message.audio = Mock() mock_message.audio.file_id = "audio_123" - + sent_message = Mock() sent_message.message_id = 204 - - with patch('helper_bot.handlers.private.services.get_text_message', return_value="Formatted"): - with patch('helper_bot.handlers.private.services.send_audio_message', return_value=sent_message): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=False): - with patch('helper_bot.handlers.private.services.add_in_db_media', return_value=True): - - await post_service.handle_audio_post(mock_message, "Test") - + + with patch( + "helper_bot.handlers.private.services.get_text_message", + return_value="Formatted", + ): + with patch( + "helper_bot.handlers.private.services.send_audio_message", + return_value=sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=False, + ): + with patch( + "helper_bot.handlers.private.services.add_in_db_media", + return_value=True, + ): + + await post_service.handle_audio_post( + mock_message, "Test" + ) + call_args = mock_db.add_post.call_args[0][0] assert call_args.text == "Аудио подпись" # Raw caption assert call_args.is_anonymous is False - + @pytest.mark.asyncio - async def test_handle_video_note_post_saves_empty_string(self, post_service, mock_message, mock_db): + async def test_handle_video_note_post_saves_empty_string( + self, post_service, mock_message, mock_db + ): """Test that handle_video_note_post saves empty string""" mock_message.video_note = Mock() mock_message.video_note.file_id = "video_note_123" - + sent_message = Mock() sent_message.message_id = 205 - - with patch('helper_bot.handlers.private.services.send_video_note_message', return_value=sent_message): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=False): - with patch('helper_bot.handlers.private.services.add_in_db_media', return_value=True): - + + with patch( + "helper_bot.handlers.private.services.send_video_note_message", + return_value=sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=False, + ): + with patch( + "helper_bot.handlers.private.services.add_in_db_media", + return_value=True, + ): + await post_service.handle_video_note_post(mock_message) - + call_args = mock_db.add_post.call_args[0][0] assert call_args.text == "" # Empty string assert call_args.is_anonymous is False - + @pytest.mark.asyncio - async def test_handle_voice_post_saves_empty_string(self, post_service, mock_message, mock_db): + async def test_handle_voice_post_saves_empty_string( + self, post_service, mock_message, mock_db + ): """Test that handle_voice_post saves empty string""" mock_message.voice = Mock() mock_message.voice.file_id = "voice_123" - + sent_message = Mock() sent_message.message_id = 206 - - with patch('helper_bot.handlers.private.services.send_voice_message', return_value=sent_message): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=False): - with patch('helper_bot.handlers.private.services.add_in_db_media', return_value=True): - + + with patch( + "helper_bot.handlers.private.services.send_voice_message", + return_value=sent_message, + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=False, + ): + with patch( + "helper_bot.handlers.private.services.add_in_db_media", + return_value=True, + ): + await post_service.handle_voice_post(mock_message) - + call_args = mock_db.add_post.call_args[0][0] assert call_args.text == "" # Empty string assert call_args.is_anonymous is False - + @pytest.mark.asyncio - async def test_handle_media_group_post_saves_raw_caption(self, post_service, mock_message, mock_db): + async def test_handle_media_group_post_saves_raw_caption( + self, post_service, mock_message, mock_db + ): """Test that handle_media_group_post saves raw caption to database""" mock_message.message_id = 300 mock_message.media_group_id = 1 - + album = [Mock()] album[0].caption = "Медиагруппа подпись" - + mock_helper_message = Mock() mock_helper_message.message_id = 302 - - with patch('helper_bot.handlers.private.services.get_text_message', return_value="Formatted"): - with patch('helper_bot.handlers.private.services.prepare_media_group_from_middlewares', return_value=[]): - with patch('helper_bot.handlers.private.services.send_media_group_message_to_private_chat', return_value=[301]): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.send_text_message', return_value=mock_helper_message): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=True): - with patch('asyncio.sleep', new_callable=AsyncMock): - - await post_service.handle_media_group_post(mock_message, album, "Test") - + + with patch( + "helper_bot.handlers.private.services.get_text_message", + return_value="Formatted", + ): + with patch( + "helper_bot.handlers.private.services.prepare_media_group_from_middlewares", + return_value=[], + ): + with patch( + "helper_bot.handlers.private.services.send_media_group_message_to_private_chat", + return_value=[301], + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.send_text_message", + return_value=mock_helper_message, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=True, + ): + with patch("asyncio.sleep", new_callable=AsyncMock): + + await post_service.handle_media_group_post( + mock_message, album, "Test" + ) + # Check main post calls = mock_db.add_post.call_args_list main_post = calls[0][0][0] - - assert main_post.text == "Медиагруппа подпись" # Raw caption - assert main_post.message_id == 301 # Последний message_id из списка + + assert ( + main_post.text == "Медиагруппа подпись" + ) # Raw caption + assert ( + main_post.message_id == 301 + ) # Последний message_id из списка assert main_post.is_anonymous is True - + @pytest.mark.asyncio - async def test_handle_media_group_post_without_caption(self, post_service, mock_message, mock_db): + async def test_handle_media_group_post_without_caption( + self, post_service, mock_message, mock_db + ): """Test that handle_media_group_post handles missing caption""" mock_message.message_id = 301 mock_message.media_group_id = 1 - + album = [Mock()] album[0].caption = None - + mock_helper_message = Mock() mock_helper_message.message_id = 303 - - with patch('helper_bot.handlers.private.services.get_text_message', return_value=" "): - with patch('helper_bot.handlers.private.services.prepare_media_group_from_middlewares', return_value=[]): - with patch('helper_bot.handlers.private.services.send_media_group_message_to_private_chat', return_value=[302]): - with patch('helper_bot.handlers.private.services.get_first_name', return_value="Test"): - with patch('helper_bot.handlers.private.services.get_reply_keyboard_for_post', return_value=None): - with patch('helper_bot.handlers.private.services.send_text_message', return_value=mock_helper_message): - with patch('helper_bot.handlers.private.services.determine_anonymity', return_value=False): - with patch('asyncio.sleep', new_callable=AsyncMock): - - await post_service.handle_media_group_post(mock_message, album, "Test") - + + with patch( + "helper_bot.handlers.private.services.get_text_message", return_value=" " + ): + with patch( + "helper_bot.handlers.private.services.prepare_media_group_from_middlewares", + return_value=[], + ): + with patch( + "helper_bot.handlers.private.services.send_media_group_message_to_private_chat", + return_value=[302], + ): + with patch( + "helper_bot.handlers.private.services.get_first_name", + return_value="Test", + ): + with patch( + "helper_bot.handlers.private.services.get_reply_keyboard_for_post", + return_value=None, + ): + with patch( + "helper_bot.handlers.private.services.send_text_message", + return_value=mock_helper_message, + ): + with patch( + "helper_bot.handlers.private.services.determine_anonymity", + return_value=False, + ): + with patch("asyncio.sleep", new_callable=AsyncMock): + + await post_service.handle_media_group_post( + mock_message, album, "Test" + ) + calls = mock_db.add_post.call_args_list main_post = calls[0][0][0] - - assert main_post.text == "" # Empty string for missing caption + + assert ( + main_post.text == "" + ) # Empty string for missing caption assert main_post.is_anonymous is False diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py index 9ffbc80..ed3e7bd 100644 --- a/tests/test_rate_limiter.py +++ b/tests/test_rate_limiter.py @@ -1,25 +1,32 @@ """ Тесты для rate limiter """ + import asyncio import time from unittest.mock import AsyncMock, MagicMock, patch import pytest -from helper_bot.config.rate_limit_config import (RateLimitSettings, - get_rate_limit_config) -from helper_bot.utils.rate_limit_monitor import (RateLimitMonitor, - RateLimitStats, - record_rate_limit_request) -from helper_bot.utils.rate_limiter import (ChatRateLimiter, GlobalRateLimiter, - RateLimitConfig, RetryHandler, - TelegramRateLimiter, - send_with_rate_limit) + +from helper_bot.config.rate_limit_config import RateLimitSettings, get_rate_limit_config +from helper_bot.utils.rate_limit_monitor import ( + RateLimitMonitor, + RateLimitStats, + record_rate_limit_request, +) +from helper_bot.utils.rate_limiter import ( + ChatRateLimiter, + GlobalRateLimiter, + RateLimitConfig, + RetryHandler, + TelegramRateLimiter, + send_with_rate_limit, +) class TestRateLimitConfig: """Тесты для RateLimitConfig""" - + def test_default_config(self): """Тест создания конфигурации по умолчанию""" config = RateLimitConfig() @@ -31,91 +38,94 @@ class TestRateLimitConfig: class TestChatRateLimiter: """Тесты для ChatRateLimiter""" - + def test_initialization(self): """Тест инициализации""" config = RateLimitConfig(messages_per_second=1.0, burst_limit=2) limiter = ChatRateLimiter(config) - + assert limiter.config == config assert limiter.last_send_time == 0.0 assert limiter.burst_count == 0 assert limiter.retry_delay == 1.0 - + @pytest.mark.asyncio async def test_wait_if_needed_no_wait(self): """Тест что не ждет если не нужно""" config = RateLimitConfig(messages_per_second=10.0, burst_limit=10) limiter = ChatRateLimiter(config) - + start_time = time.time() await limiter.wait_if_needed() end_time = time.time() - + # Должно пройти очень быстро assert end_time - start_time < 0.1 - + @pytest.mark.asyncio async def test_wait_if_needed_with_wait(self): - """Тест что ждет если нужно""" - config = RateLimitConfig(messages_per_second=0.5, burst_limit=10) # 1 сообщение в 2 секунды + """Тест что ждет если нужно (sleep патчится, проверяем вызов с нужной длительностью).""" + config = RateLimitConfig( + messages_per_second=0.5, burst_limit=10 + ) # 1 сообщение в 2 секунды limiter = ChatRateLimiter(config) - - # Первый вызов не должен ждать - start_time = time.time() - await limiter.wait_if_needed() - first_call_time = time.time() - start_time - - # Второй вызов должен ждать - start_time = time.time() - await limiter.wait_if_needed() - second_call_time = time.time() - start_time - - assert first_call_time < 0.1 - assert second_call_time >= 1.8 # Должно ждать около 2 секунд - + + with patch( + "helper_bot.utils.rate_limiter.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + await limiter.wait_if_needed() + mock_sleep.assert_not_called() + + await limiter.wait_if_needed() + mock_sleep.assert_called_once() + # min_interval = 2.0, ждём ~2 сек + call_arg = mock_sleep.call_args[0][0] + assert 1.8 <= call_arg <= 2.2 + @pytest.mark.asyncio async def test_burst_limit(self): - """Тест ограничения burst""" + """Тест ограничения burst (sleep патчится, проверяем вызов на 3-м вызове).""" config = RateLimitConfig(messages_per_second=10.0, burst_limit=2) limiter = ChatRateLimiter(config) - - # Первые два вызова не должны ждать - start_time = time.time() - await limiter.wait_if_needed() - await limiter.wait_if_needed() - first_two_calls_time = time.time() - start_time - - # Третий вызов должен ждать - start_time = time.time() - await limiter.wait_if_needed() - third_call_time = time.time() - start_time - - assert first_two_calls_time < 0.2 # Более мягкое ограничение - assert third_call_time >= 0.8 # Должно ждать около 1 секунды (с учетом погрешности) + + with patch( + "helper_bot.utils.rate_limiter.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + await limiter.wait_if_needed() + await limiter.wait_if_needed() + mock_sleep.reset_mock() + + await limiter.wait_if_needed() + # Третий вызов: сначала sleep по burst (~1.0 с), затем по min_interval (~0.1 с) + assert mock_sleep.call_count >= 1 + args = [c[0][0] for c in mock_sleep.call_args_list] + burst_waits = [a for a in args if 0.8 <= a <= 1.2] + assert ( + len(burst_waits) >= 1 + ), f"Ожидался вызов sleep(~1.0) по burst, получены: {args}" class TestGlobalRateLimiter: """Тесты для GlobalRateLimiter""" - + def test_initialization(self): """Тест инициализации""" config = RateLimitConfig() limiter = GlobalRateLimiter(config) - + assert limiter.config == config assert limiter.chat_limiters == {} assert limiter.global_last_send == 0.0 - + def test_get_chat_limiter(self): """Тест получения limiter для чата""" config = RateLimitConfig() limiter = GlobalRateLimiter(config) - + chat_limiter = limiter.get_chat_limiter(123) assert isinstance(chat_limiter, ChatRateLimiter) assert limiter.chat_limiters[123] == chat_limiter - + # Повторный вызов должен вернуть тот же объект same_limiter = limiter.get_chat_limiter(123) assert same_limiter is chat_limiter @@ -123,113 +133,109 @@ class TestGlobalRateLimiter: class TestRetryHandler: """Тесты для RetryHandler""" - + def test_initialization(self): """Тест инициализации""" config = RateLimitConfig() handler = RetryHandler(config) assert handler.config == config - + @pytest.mark.asyncio async def test_execute_with_retry_success(self): """Тест успешного выполнения без retry""" config = RateLimitConfig() handler = RetryHandler(config) - + mock_func = AsyncMock(return_value="success") - + result = await handler.execute_with_retry(mock_func, 123) - + assert result == "success" mock_func.assert_called_once() - + @pytest.mark.asyncio async def test_execute_with_retry_retry_after(self): - """Тест retry после RetryAfter ошибки""" + """Тест retry после RetryAfter ошибки (sleep патчится, проверяем вызов).""" from aiogram.exceptions import TelegramRetryAfter - + config = RateLimitConfig(retry_after_multiplier=1.0, max_retry_delay=1.0) handler = RetryHandler(config) - + mock_func = AsyncMock() - # Создаем мок для TelegramRetryAfter from unittest.mock import MagicMock + retry_after_error = TelegramRetryAfter( - method=MagicMock(), - message="Flood control exceeded", - retry_after=1 # 1 секунда + method=MagicMock(), message="Flood control exceeded", retry_after=1 ) - - mock_func.side_effect = [ - retry_after_error, # Первый вызов - ошибка - "success" # Второй вызов - успех - ] - - start_time = time.time() - result = await handler.execute_with_retry(mock_func, 123, max_retries=1) - end_time = time.time() - + mock_func.side_effect = [retry_after_error, "success"] + + with patch( + "helper_bot.utils.rate_limiter.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + result = await handler.execute_with_retry(mock_func, 123, max_retries=1) + assert result == "success" assert mock_func.call_count == 2 - assert end_time - start_time >= 0.1 # Должно ждать + mock_sleep.assert_called_once() + assert mock_sleep.call_args[0][0] == 1.0 # retry_after class TestTelegramRateLimiter: """Тесты для TelegramRateLimiter""" - + def test_initialization(self): """Тест инициализации""" config = RateLimitConfig() limiter = TelegramRateLimiter(config) - + assert limiter.config == config assert isinstance(limiter.global_limiter, GlobalRateLimiter) assert isinstance(limiter.retry_handler, RetryHandler) - + @pytest.mark.asyncio async def test_send_with_rate_limit(self): """Тест отправки с rate limiting""" config = RateLimitConfig(messages_per_second=10.0, burst_limit=10) limiter = TelegramRateLimiter(config) - + mock_send_func = AsyncMock(return_value="sent") - + result = await limiter.send_with_rate_limit(mock_send_func, 123) - + assert result == "sent" mock_send_func.assert_called_once() class TestRateLimitMonitor: """Тесты для RateLimitMonitor""" - + def test_initialization(self): """Тест инициализации""" monitor = RateLimitMonitor() - + assert monitor.stats == {} assert isinstance(monitor.global_stats, RateLimitStats) assert monitor.max_history_size == 1000 - + def test_record_request_success(self): """Тест записи успешного запроса""" monitor = RateLimitMonitor() - + monitor.record_request(123, True, 0.5) - + assert 123 in monitor.stats chat_stats = monitor.stats[123] assert chat_stats.total_requests == 1 assert chat_stats.successful_requests == 1 assert chat_stats.failed_requests == 0 assert chat_stats.total_wait_time == 0.5 - + def test_record_request_failure(self): """Тест записи неудачного запроса""" monitor = RateLimitMonitor() - + monitor.record_request(123, False, 1.0, "RetryAfter") - + assert 123 in monitor.stats chat_stats = monitor.stats[123] assert chat_stats.total_requests == 1 @@ -237,58 +243,58 @@ class TestRateLimitMonitor: assert chat_stats.failed_requests == 1 assert chat_stats.retry_after_errors == 1 assert chat_stats.total_wait_time == 1.0 - + def test_get_chat_stats(self): """Тест получения статистики чата""" monitor = RateLimitMonitor() - + # Статистика для несуществующего чата assert monitor.get_chat_stats(999) is None - + # Записываем запрос monitor.record_request(123, True, 0.5) - + # Получаем статистику stats = monitor.get_chat_stats(123) assert stats is not None assert stats.chat_id == 123 assert stats.total_requests == 1 - + def test_success_rate_calculation(self): """Тест расчета процента успеха""" monitor = RateLimitMonitor() - + # 3 успешных, 1 неудачный monitor.record_request(123, True, 0.1) monitor.record_request(123, True, 0.2) monitor.record_request(123, True, 0.3) monitor.record_request(123, False, 0.4, "RetryAfter") - + stats = monitor.get_chat_stats(123) assert stats.success_rate == 0.75 # 3/4 - assert stats.error_rate == 0.25 # 1/4 + assert stats.error_rate == 0.25 # 1/4 class TestRateLimitConfig: """Тесты для конфигурации rate limiting""" - + def test_get_rate_limit_config(self): """Тест получения конфигурации""" # Тест production конфигурации prod_config = get_rate_limit_config("production") assert prod_config.messages_per_second == 0.5 assert prod_config.burst_limit == 2 - + # Тест development конфигурации dev_config = get_rate_limit_config("development") assert dev_config.messages_per_second == 1.0 assert dev_config.burst_limit == 3 - + # Тест strict конфигурации strict_config = get_rate_limit_config("strict") assert strict_config.messages_per_second == 0.3 assert strict_config.burst_limit == 1 - + # Тест неизвестной конфигурации (должна вернуть production) unknown_config = get_rate_limit_config("unknown") assert unknown_config.messages_per_second == 0.5 @@ -298,9 +304,9 @@ class TestRateLimitConfig: async def test_send_with_rate_limit_integration(): """Интеграционный тест для send_with_rate_limit""" mock_send_func = AsyncMock(return_value="message_sent") - + result = await send_with_rate_limit(mock_send_func, 123) - + assert result == "message_sent" mock_send_func.assert_called_once() diff --git a/tests/test_refactored_admin_handlers.py b/tests/test_refactored_admin_handlers.py index a054c28..5e9406b 100644 --- a/tests/test_refactored_admin_handlers.py +++ b/tests/test_refactored_admin_handlers.py @@ -3,43 +3,46 @@ from unittest.mock import AsyncMock, Mock, patch import pytest from aiogram import types from aiogram.fsm.context import FSMContext -from helper_bot.handlers.admin.exceptions import (InvalidInputError, - UserAlreadyBannedError, - UserNotFoundError) + +from helper_bot.handlers.admin.exceptions import ( + InvalidInputError, + UserAlreadyBannedError, + UserNotFoundError, +) from helper_bot.handlers.admin.services import AdminService, BannedUser, User class TestAdminService: """Тесты для AdminService""" - + def setup_method(self): """Настройка перед каждым тестом""" self.mock_db = Mock() self.admin_service = AdminService(self.mock_db) - + @pytest.mark.asyncio async def test_get_last_users_success(self): """Тест успешного получения списка последних пользователей""" # Arrange # Формат данных: кортежи (full_name, user_id) как возвращает БД mock_users_data = [ - ('User One', 1), # (full_name, user_id) - ('User Two', 2) # (full_name, user_id) + ("User One", 1), # (full_name, user_id) + ("User Two", 2), # (full_name, user_id) ] self.mock_db.get_last_users = AsyncMock(return_value=mock_users_data) - + # Act result = await self.admin_service.get_last_users() - + # Assert assert len(result) == 2 assert result[0].user_id == 1 - assert result[0].username == 'Неизвестно' # username не возвращается из БД - assert result[0].full_name == 'User One' + assert result[0].username == "Неизвестно" # username не возвращается из БД + assert result[0].full_name == "User One" assert result[1].user_id == 2 - assert result[1].username == 'Неизвестно' # username не возвращается из БД - assert result[1].full_name == 'User Two' - + assert result[1].username == "Неизвестно" # username не возвращается из БД + assert result[1].full_name == "User Two" + @pytest.mark.asyncio async def test_get_user_by_username_success(self): """Тест успешного получения пользователя по username""" @@ -49,95 +52,102 @@ class TestAdminService: full_name = "Test User" self.mock_db.get_user_id_by_username = AsyncMock(return_value=user_id) self.mock_db.get_full_name_by_id = AsyncMock(return_value=full_name) - + # Act result = await self.admin_service.get_user_by_username(username) - + # Assert assert result is not None assert result.user_id == user_id assert result.username == username assert result.full_name == full_name - + @pytest.mark.asyncio async def test_get_user_by_username_not_found(self): """Тест получения пользователя по несуществующему username""" # Arrange username = "nonexistent_user" self.mock_db.get_user_id_by_username = AsyncMock(return_value=None) - + # Act result = await self.admin_service.get_user_by_username(username) - + # Assert assert result is None - + @pytest.mark.asyncio async def test_get_user_by_id_success(self): """Тест успешного получения пользователя по ID""" # Arrange user_id = 123 from database.models import User as DBUser + user_info = DBUser( user_id=user_id, first_name="Test", - full_name="Test User", - username="test_user" + full_name="Test User", + username="test_user", ) self.mock_db.get_user_by_id = AsyncMock(return_value=user_info) - + # Act result = await self.admin_service.get_user_by_id(user_id) - + # Assert assert result is not None assert result.user_id == user_id - assert result.username == 'test_user' - assert result.full_name == 'Test User' - + assert result.username == "test_user" + assert result.full_name == "Test User" + @pytest.mark.asyncio async def test_get_user_by_id_not_found(self): """Тест получения пользователя по несуществующему ID""" # Arrange user_id = 999 self.mock_db.get_user_by_id = AsyncMock(return_value=None) - + # Act result = await self.admin_service.get_user_by_id(user_id) - + # Assert assert result is None - + @pytest.mark.asyncio async def test_validate_user_input_success(self): """Тест успешной валидации ID пользователя""" # Act result = await self.admin_service.validate_user_input("123") - + # Assert assert result == 123 - + @pytest.mark.asyncio async def test_validate_user_input_invalid_number(self): """Тест валидации некорректного ID""" # Act & Assert - with pytest.raises(InvalidInputError, match="ID пользователя должен быть числом"): + with pytest.raises( + InvalidInputError, match="ID пользователя должен быть числом" + ): await self.admin_service.validate_user_input("abc") - + @pytest.mark.asyncio async def test_validate_user_input_negative_number(self): """Тест валидации отрицательного ID""" # Act & Assert - with pytest.raises(InvalidInputError, match="ID пользователя должен быть положительным числом"): + with pytest.raises( + InvalidInputError, match="ID пользователя должен быть положительным числом" + ): await self.admin_service.validate_user_input("-1") - + @pytest.mark.asyncio async def test_validate_user_input_zero(self): """Тест валидации нулевого ID""" # Act & Assert - with pytest.raises(InvalidInputError, match="ID пользователя должен быть положительным числом"): + with pytest.raises( + InvalidInputError, match="ID пользователя должен быть положительным числом" + ): await self.admin_service.validate_user_input("0") - + @pytest.mark.asyncio async def test_ban_user_success(self): """Тест успешной блокировки пользователя""" @@ -146,17 +156,19 @@ class TestAdminService: username = "test_user" reason = "Test ban" ban_days = 7 - + self.mock_db.check_user_in_blacklist = AsyncMock(return_value=False) self.mock_db.set_user_blacklist = AsyncMock(return_value=None) - + # Act - await self.admin_service.ban_user(user_id, username, reason, ban_days, ban_author_id=999) - + await self.admin_service.ban_user( + user_id, username, reason, ban_days, ban_author_id=999 + ) + # Assert self.mock_db.check_user_in_blacklist.assert_called_once_with(user_id) self.mock_db.set_user_blacklist.assert_called_once() - + @pytest.mark.asyncio async def test_ban_user_already_banned(self): """Тест попытки заблокировать уже заблокированного пользователя""" @@ -165,13 +177,17 @@ class TestAdminService: username = "test_user" reason = "Test ban" ban_days = 7 - + self.mock_db.check_user_in_blacklist = AsyncMock(return_value=True) - + # Act & Assert - with pytest.raises(UserAlreadyBannedError, match=f"Пользователь {user_id} уже заблокирован"): - await self.admin_service.ban_user(user_id, username, reason, ban_days, ban_author_id=999) - + with pytest.raises( + UserAlreadyBannedError, match=f"Пользователь {user_id} уже заблокирован" + ): + await self.admin_service.ban_user( + user_id, username, reason, ban_days, ban_author_id=999 + ) + @pytest.mark.asyncio async def test_ban_user_permanent(self): """Тест постоянной блокировки пользователя""" @@ -180,26 +196,30 @@ class TestAdminService: username = "test_user" reason = "Permanent ban" ban_days = None - + self.mock_db.check_user_in_blacklist = AsyncMock(return_value=False) self.mock_db.set_user_blacklist = AsyncMock(return_value=None) - + # Act - await self.admin_service.ban_user(user_id, username, reason, ban_days, ban_author_id=999) - + await self.admin_service.ban_user( + user_id, username, reason, ban_days, ban_author_id=999 + ) + # Assert - self.mock_db.set_user_blacklist.assert_called_once_with(user_id, None, reason, None, ban_author=999) - + self.mock_db.set_user_blacklist.assert_called_once_with( + user_id, None, reason, None, ban_author=999 + ) + @pytest.mark.asyncio async def test_unban_user_success(self): """Тест успешной разблокировки пользователя""" # Arrange user_id = 123 self.mock_db.delete_user_blacklist = AsyncMock(return_value=None) - + # Act await self.admin_service.unban_user(user_id) - + # Assert self.mock_db.delete_user_blacklist.assert_called_once_with(user_id) @@ -251,12 +271,12 @@ class TestAdminService: class TestUser: """Тесты для модели User""" - + def test_user_creation(self): """Тест создания объекта User""" # Act user = User(user_id=123, username="test_user", full_name="Test User") - + # Assert assert user.user_id == 123 assert user.username == "test_user" @@ -265,17 +285,17 @@ class TestUser: class TestBannedUser: """Тесты для модели BannedUser""" - + def test_banned_user_creation(self): """Тест создания объекта BannedUser""" # Act banned_user = BannedUser( - user_id=123, - username="test_user", - reason="Test ban", - unban_date="2025-01-01" + user_id=123, + username="test_user", + reason="Test ban", + unban_date="2025-01-01", ) - + # Assert assert banned_user.user_id == 123 assert banned_user.username == "test_user" diff --git a/tests/test_refactored_group_handlers.py b/tests/test_refactored_group_handlers.py index c3d4cf6..464c95f 100644 --- a/tests/test_refactored_group_handlers.py +++ b/tests/test_refactored_group_handlers.py @@ -5,29 +5,34 @@ from unittest.mock import AsyncMock, MagicMock, Mock import pytest from aiogram import types from aiogram.fsm.context import FSMContext + from helper_bot.handlers.group.constants import ERROR_MESSAGES, FSM_STATES -from helper_bot.handlers.group.exceptions import (NoReplyToMessageError, - UserNotFoundError) -from helper_bot.handlers.group.group_handlers import (GroupHandlers, - create_group_handlers) +from helper_bot.handlers.group.exceptions import ( + NoReplyToMessageError, + UserNotFoundError, +) +from helper_bot.handlers.group.group_handlers import ( + GroupHandlers, + create_group_handlers, +) from helper_bot.handlers.group.services import AdminReplyService class TestGroupHandlers: """Test class for GroupHandlers""" - + @pytest.fixture def mock_db(self): """Mock database""" db = Mock() db.get_user_by_message_id = Mock() return db - + @pytest.fixture def mock_keyboard_markup(self): """Mock keyboard markup""" return Mock() - + @pytest.fixture def mock_message(self): """Mock Telegram message""" @@ -44,7 +49,7 @@ class TestGroupHandlers: message.bot = Mock() message.bot.send_message = AsyncMock() return message - + @pytest.fixture def mock_reply_message(self, mock_message): """Mock reply message""" @@ -52,21 +57,21 @@ class TestGroupHandlers: reply_message.message_id = 222 mock_message.reply_to_message = reply_message return mock_message - + @pytest.fixture def mock_state(self): """Mock FSM state""" state = Mock(spec=FSMContext) state.set_state = AsyncMock() return state - + def test_create_group_handlers(self, mock_db, mock_keyboard_markup): """Test creating group handlers instance""" handlers = create_group_handlers(mock_db, mock_keyboard_markup) assert isinstance(handlers, GroupHandlers) assert handlers.db == mock_db assert handlers.keyboard_markup == mock_keyboard_markup - + def test_group_handlers_initialization(self, mock_db, mock_keyboard_markup): """Test GroupHandlers initialization""" handlers = GroupHandlers(mock_db, mock_keyboard_markup) @@ -74,109 +79,121 @@ class TestGroupHandlers: assert handlers.keyboard_markup == mock_keyboard_markup assert handlers.admin_reply_service is not None assert handlers.router is not None - + @pytest.mark.asyncio - async def test_handle_message_success(self, mock_db, mock_keyboard_markup, mock_reply_message, mock_state): + async def test_handle_message_success( + self, mock_db, mock_keyboard_markup, mock_reply_message, mock_state + ): """Test successful message handling""" mock_db.get_user_by_message_id = AsyncMock(return_value=99999) - + handlers = create_group_handlers(mock_db, mock_keyboard_markup) - + # Mock the send_reply_to_user method handlers.admin_reply_service.send_reply_to_user = AsyncMock() - + await handlers.handle_message(mock_reply_message, mock_state) - + # Verify database call mock_db.get_user_by_message_id.assert_called_once_with(222) - + # Verify service call handlers.admin_reply_service.send_reply_to_user.assert_called_once_with( 99999, mock_reply_message, "test reply message", mock_keyboard_markup ) - + # Verify state was set mock_state.set_state.assert_called_once_with(FSM_STATES["CHAT"]) - + @pytest.mark.asyncio - async def test_handle_message_no_reply(self, mock_db, mock_keyboard_markup, mock_message, mock_state): + async def test_handle_message_no_reply( + self, mock_db, mock_keyboard_markup, mock_message, mock_state + ): """Test message handling without reply""" handlers = create_group_handlers(mock_db, mock_keyboard_markup) - + # Mock the send_reply_to_user method to prevent it from being called handlers.admin_reply_service.send_reply_to_user = AsyncMock() - + # Ensure reply_to_message is None mock_message.reply_to_message = None - + await handlers.handle_message(mock_message, mock_state) - + # Verify error message was sent - mock_message.answer.assert_called_once_with(ERROR_MESSAGES["NO_REPLY_TO_MESSAGE"]) - + mock_message.answer.assert_called_once_with( + ERROR_MESSAGES["NO_REPLY_TO_MESSAGE"] + ) + # Verify no database calls mock_db.get_user_by_message_id.assert_not_called() - + # Verify send_reply_to_user was not called handlers.admin_reply_service.send_reply_to_user.assert_not_called() - + # Verify state was not set mock_state.set_state.assert_not_called() - + @pytest.mark.asyncio - async def test_handle_message_user_not_found(self, mock_db, mock_keyboard_markup, mock_reply_message, mock_state): + async def test_handle_message_user_not_found( + self, mock_db, mock_keyboard_markup, mock_reply_message, mock_state + ): """Test message handling when user is not found""" mock_db.get_user_by_message_id = AsyncMock(return_value=None) - + handlers = create_group_handlers(mock_db, mock_keyboard_markup) - + await handlers.handle_message(mock_reply_message, mock_state) - + # Verify error message was sent - mock_reply_message.answer.assert_called_once_with(ERROR_MESSAGES["USER_NOT_FOUND"]) - + mock_reply_message.answer.assert_called_once_with( + ERROR_MESSAGES["USER_NOT_FOUND"] + ) + # Verify database call mock_db.get_user_by_message_id.assert_called_once_with(222) - + # Verify state was not set mock_state.set_state.assert_not_called() class TestAdminReplyService: """Test class for AdminReplyService""" - + @pytest.fixture def mock_db(self): """Mock database""" db = Mock() db.get_user_by_message_id = Mock() return db - + @pytest.fixture def service(self, mock_db): """Create service instance""" return AdminReplyService(mock_db) - + @pytest.mark.asyncio async def test_get_user_id_for_reply_success(self, service, mock_db): """Test successful user ID retrieval""" mock_db.get_user_by_message_id = AsyncMock(return_value=12345) - + result = await service.get_user_id_for_reply(111) - + assert result == 12345 mock_db.get_user_by_message_id.assert_called_once_with(111) - + @pytest.mark.asyncio async def test_get_user_id_for_reply_not_found(self, service, mock_db): """Test user ID retrieval when user not found""" mock_db.get_user_by_message_id = AsyncMock(return_value=None) - - with pytest.raises(UserNotFoundError, match="User not found for message_id: 111"): + + with pytest.raises( + UserNotFoundError, match="User not found for message_id: 111" + ): await service.get_user_id_for_reply(111) - + mock_db.get_user_by_message_id.assert_called_once_with(111) - + @pytest.mark.asyncio async def test_send_reply_to_user(self, service, mock_db): """Test sending reply to user""" @@ -184,12 +201,14 @@ class TestAdminReplyService: message.reply_to_message = Mock() message.reply_to_message.message_id = 222 markup = Mock() - + # Mock the send_text_message function with pytest.MonkeyPatch().context() as m: mock_send_text = AsyncMock() - m.setattr('helper_bot.handlers.group.services.send_text_message', mock_send_text) - + m.setattr( + "helper_bot.handlers.group.services.send_text_message", mock_send_text + ) + await service.send_reply_to_user(12345, message, "test reply", markup) - + mock_send_text.assert_called_once_with(12345, message, "test reply", markup) diff --git a/tests/test_refactored_private_handlers.py b/tests/test_refactored_private_handlers.py index daf9b81..057068f 100644 --- a/tests/test_refactored_private_handlers.py +++ b/tests/test_refactored_private_handlers.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from aiogram import types from aiogram.fsm.context import FSMContext + from helper_bot.handlers.private.constants import BUTTON_TEXTS, FSM_STATES from helper_bot.handlers.private.private_handlers import ( PrivateHandlers, @@ -15,7 +16,7 @@ from helper_bot.handlers.private.services import BotSettings class TestPrivateHandlers: """Test class for PrivateHandlers""" - + @pytest.fixture def mock_db(self): """Mock database""" @@ -29,7 +30,7 @@ class TestPrivateHandlers: db.update_helper_message = AsyncMock() db.update_user_activity = AsyncMock() return db - + @pytest.fixture def mock_settings(self): """Mock bot settings""" @@ -41,9 +42,9 @@ class TestPrivateHandlers: important_logs="test_important", preview_link="test_link", logs="test_logs_setting", - test="test_test_setting" + test="test_test_setting", ) - + @pytest.fixture def mock_message(self): """Mock Telegram message""" @@ -56,7 +57,7 @@ class TestPrivateHandlers: from_user.is_bot = False from_user.language_code = "ru" message.from_user = from_user - + message.text = "test message" message.message_id = 1 @@ -71,7 +72,7 @@ class TestPrivateHandlers: message.answer = AsyncMock() message.answer_sticker = AsyncMock() return message - + @pytest.fixture def mock_state(self): """Mock FSM state""" @@ -79,14 +80,14 @@ class TestPrivateHandlers: state.set_state = AsyncMock() state.get_state = AsyncMock(return_value=FSM_STATES["START"]) return state - + def test_create_private_handlers(self, mock_db, mock_settings): """Test creating private handlers instance""" handlers = create_private_handlers(mock_db, mock_settings) assert isinstance(handlers, PrivateHandlers) assert handlers.db == mock_db assert handlers.settings == mock_settings - + def test_private_handlers_initialization(self, mock_db, mock_settings): """Test PrivateHandlers initialization""" handlers = PrivateHandlers(mock_db, mock_settings) @@ -96,25 +97,32 @@ class TestPrivateHandlers: assert handlers.post_service is not None assert handlers.sticker_service is not None assert handlers.router is not None - + @pytest.mark.asyncio - async def test_handle_emoji_message(self, mock_db, mock_settings, mock_message, mock_state): + async def test_handle_emoji_message( + self, mock_db, mock_settings, mock_message, mock_state + ): """Test emoji message handler""" handlers = create_private_handlers(mock_db, mock_settings) - + # Mock the check_user_emoji function with pytest.MonkeyPatch().context() as m: mock_check_emoji = AsyncMock(return_value="😊") - m.setattr('helper_bot.handlers.private.private_handlers.check_user_emoji', mock_check_emoji) - + m.setattr( + "helper_bot.handlers.private.private_handlers.check_user_emoji", + mock_check_emoji, + ) + # Test the handler await handlers.handle_emoji_message(mock_message, mock_state) - + # Verify state was set mock_state.set_state.assert_called_once_with(FSM_STATES["START"]) - + # Verify message was logged - mock_message.forward.assert_called_once_with(chat_id=mock_settings.group_for_logs) + mock_message.forward.assert_called_once_with( + chat_id=mock_settings.group_for_logs + ) @pytest.mark.asyncio async def test_handle_emoji_message_no_emoji(self, mock_db, mock_settings, mock_message, mock_state): @@ -126,24 +134,39 @@ class TestPrivateHandlers: mock_state.set_state.assert_called_once_with(FSM_STATES["START"]) mock_message.answer.assert_not_called() + mock_message.forward.assert_called_once_with( + chat_id=mock_settings.group_for_logs + ) + @pytest.mark.asyncio - async def test_handle_start_message(self, mock_db, mock_settings, mock_message, mock_state): + async def test_handle_start_message( + self, mock_db, mock_settings, mock_message, mock_state + ): """Test start message handler""" handlers = create_private_handlers(mock_db, mock_settings) - + # Mock the get_first_name and messages functions with pytest.MonkeyPatch().context() as m: - m.setattr('helper_bot.handlers.private.private_handlers.get_first_name', lambda x: "Test") - m.setattr('helper_bot.handlers.private.private_handlers.messages.get_message', lambda x, y: "Hello Test!") + m.setattr( + "helper_bot.handlers.private.private_handlers.get_first_name", + lambda x: "Test", + ) + m.setattr( + "helper_bot.handlers.private.private_handlers.messages.get_message", + lambda x, y: "Hello Test!", + ) mock_keyboard = AsyncMock(return_value=Mock()) - m.setattr('helper_bot.handlers.private.private_handlers.get_reply_keyboard', mock_keyboard) - + m.setattr( + "helper_bot.handlers.private.private_handlers.get_reply_keyboard", + mock_keyboard, + ) + # Test the handler await handlers.handle_start_message(mock_message, mock_state) - + # Verify state was set mock_state.set_state.assert_called_once_with(FSM_STATES["START"]) - + # Verify user was ensured to exist mock_db.add_user.assert_called_once() mock_db.update_user_date.assert_called_once() @@ -241,7 +264,7 @@ class TestPrivateHandlers: class TestBotSettings: """Test class for BotSettings dataclass""" - + def test_bot_settings_creation(self): """Test creating BotSettings instance""" settings = BotSettings( @@ -252,9 +275,9 @@ class TestBotSettings: important_logs="important", preview_link="link", logs="logs_setting", - test="test_setting" + test="test_setting", ) - + assert settings.group_for_posts == "posts" assert settings.group_for_message == "message" assert settings.main_public == "public" @@ -267,14 +290,14 @@ class TestBotSettings: class TestConstants: """Test class for constants""" - + def test_fsm_states(self): """Test FSM states constants""" assert FSM_STATES["START"] == "START" assert FSM_STATES["SUGGEST"] == "SUGGEST" assert FSM_STATES["PRE_CHAT"] == "PRE_CHAT" assert FSM_STATES["CHAT"] == "CHAT" - + def test_button_texts(self): """Test button text constants""" assert BUTTON_TEXTS["SUGGEST_POST"] == "📢Предложить свой пост" diff --git a/tests/test_scoring_services.py b/tests/test_scoring_services.py index d827390..c7d427d 100644 --- a/tests/test_scoring_services.py +++ b/tests/test_scoring_services.py @@ -6,16 +6,19 @@ import json from unittest.mock import AsyncMock, MagicMock, patch import pytest + # Импорты для тестирования базовых классов from helper_bot.services.scoring.base import CombinedScore, ScoringResult -from helper_bot.services.scoring.exceptions import (InsufficientExamplesError, - ScoringError, - TextTooShortError) +from helper_bot.services.scoring.exceptions import ( + InsufficientExamplesError, + ScoringError, + TextTooShortError, +) class TestScoringResult: """Тесты для ScoringResult.""" - + def test_create_valid_score(self): """Тест создания валидного результата.""" result = ScoringResult( @@ -26,17 +29,17 @@ class TestScoringResult: assert result.score == 0.75 assert result.source == "rag" assert result.model == "test-model" - + def test_score_validation_lower_bound(self): """Тест валидации нижней границы скора.""" with pytest.raises(ValueError): ScoringResult(score=-0.1, source="test", model="test") - + def test_score_validation_upper_bound(self): """Тест валидации верхней границы скора.""" with pytest.raises(ValueError): ScoringResult(score=1.1, source="test", model="test") - + def test_to_dict(self): """Тест преобразования в словарь.""" result = ScoringResult( @@ -47,12 +50,12 @@ class TestScoringResult: timestamp=1234567890, ) d = result.to_dict() - + assert d["score"] == 0.7534 # Округлено до 4 знаков assert d["model"] == "test-model" assert d["ts"] == 1234567890 assert d["confidence"] == 0.85 - + def test_from_dict(self): """Тест создания из словаря.""" data = { @@ -62,7 +65,7 @@ class TestScoringResult: "confidence": 0.9, } result = ScoringResult.from_dict("rag", data) - + assert result.score == 0.75 assert result.source == "rag" assert result.model == "test-model" @@ -72,49 +75,55 @@ class TestScoringResult: class TestCombinedScore: """Тесты для CombinedScore.""" - + def test_empty_combined_score(self): """Тест пустого объединенного скора.""" score = CombinedScore() - + assert score.deepseek is None assert score.rag is None assert score.deepseek_score is None assert score.rag_score is None assert not score.has_any_score() - + def test_combined_score_with_rag(self): """Тест объединенного скора с RAG.""" rag_result = ScoringResult(score=0.8, source="rag", model="rubert") score = CombinedScore(rag=rag_result) - + assert score.rag_score == 0.8 assert score.deepseek_score is None assert score.has_any_score() - + def test_combined_score_with_both(self): """Тест объединенного скора с обоими сервисами.""" rag_result = ScoringResult(score=0.8, source="rag", model="rubert") - deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat") + deepseek_result = ScoringResult( + score=0.7, source="deepseek", model="deepseek-chat" + ) score = CombinedScore(rag=rag_result, deepseek=deepseek_result) - + assert score.rag_score == 0.8 assert score.deepseek_score == 0.7 assert score.has_any_score() - + def test_to_json_dict(self): """Тест преобразования в JSON словарь.""" - rag_result = ScoringResult(score=0.8, source="rag", model="rubert", timestamp=123) - deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat", timestamp=456) + rag_result = ScoringResult( + score=0.8, source="rag", model="rubert", timestamp=123 + ) + deepseek_result = ScoringResult( + score=0.7, source="deepseek", model="deepseek-chat", timestamp=456 + ) score = CombinedScore(rag=rag_result, deepseek=deepseek_result) - + d = score.to_json_dict() - + assert "rag" in d assert "deepseek" in d assert d["rag"]["score"] == 0.8 assert d["deepseek"]["score"] == 0.7 - + # Проверяем что это валидный JSON json_str = json.dumps(d) assert json_str @@ -122,38 +131,40 @@ class TestCombinedScore: class TestVectorStore: """Тесты для VectorStore (требует numpy).""" - + @pytest.fixture def vector_store(self): """Создает VectorStore для тестов.""" try: import numpy as np + from helper_bot.services.scoring.vector_store import VectorStore + return VectorStore(vector_dim=768, max_examples=100) except ImportError: pytest.skip("numpy не установлен") - + def test_add_positive_example(self, vector_store): """Тест добавления положительного примера.""" import numpy as np - + vector = np.random.randn(768).astype(np.float32) result = vector_store.add_positive(vector, "hash1") - + assert result is True assert vector_store.positive_count == 1 - + def test_add_duplicate_example(self, vector_store): """Тест добавления дубликата.""" import numpy as np - + vector = np.random.randn(768).astype(np.float32) vector_store.add_positive(vector, "hash1") result = vector_store.add_positive(vector, "hash1") # Дубликат - + assert result is False assert vector_store.positive_count == 1 - + def test_max_examples_limit(self, vector_store): """Тест ограничения максимального количества примеров.""" import numpy as np @@ -162,18 +173,18 @@ class TestVectorStore: for i in range(150): vector = np.random.randn(768).astype(np.float32) vector_store.add_positive(vector, f"hash_{i}") - + assert vector_store.positive_count == 100 # max_examples - + def test_calculate_similarity_no_examples(self, vector_store): """Тест расчета скора без примеров.""" import numpy as np - + vector = np.random.randn(768).astype(np.float32) - + with pytest.raises(InsufficientExamplesError): vector_store.calculate_similarity_score(vector) - + def test_calculate_similarity_with_examples(self, vector_store): """Тест расчета скора с примерами.""" import numpy as np @@ -182,86 +193,86 @@ class TestVectorStore: for i in range(10): vector = np.random.randn(768).astype(np.float32) vector_store.add_positive(vector, f"pos_{i}") - + # Добавляем отрицательные примеры for i in range(10): vector = np.random.randn(768).astype(np.float32) vector_store.add_negative(vector, f"neg_{i}") - + # Рассчитываем скор для нового вектора test_vector = np.random.randn(768).astype(np.float32) score, confidence = vector_store.calculate_similarity_score(test_vector) - + assert 0.0 <= score <= 1.0 assert 0.0 <= confidence <= 1.0 - + def test_compute_text_hash(self, vector_store): """Тест вычисления хеша текста.""" from helper_bot.services.scoring.vector_store import VectorStore - + hash1 = VectorStore.compute_text_hash("Привет мир") hash2 = VectorStore.compute_text_hash("Привет мир") hash3 = VectorStore.compute_text_hash("Другой текст") - + assert hash1 == hash2 assert hash1 != hash3 class TestDeepSeekService: """Тесты для DeepSeekService.""" - + @pytest.fixture def deepseek_service(self): """Создает DeepSeekService для тестов.""" - from helper_bot.services.scoring.deepseek_service import \ - DeepSeekService + from helper_bot.services.scoring.deepseek_service import DeepSeekService + return DeepSeekService( api_key="test_key", enabled=True, timeout=5, ) - + def test_service_disabled_without_key(self): """Тест отключения сервиса без API ключа.""" - from helper_bot.services.scoring.deepseek_service import \ - DeepSeekService + from helper_bot.services.scoring.deepseek_service import DeepSeekService + service = DeepSeekService(api_key=None, enabled=True) - + assert service.is_enabled is False - + def test_parse_score_response_valid(self, deepseek_service): """Тест парсинга валидного ответа.""" assert deepseek_service._parse_score_response("0.75") == 0.75 assert deepseek_service._parse_score_response("0.5") == 0.5 assert deepseek_service._parse_score_response("1.0") == 1.0 assert deepseek_service._parse_score_response("0") == 0.0 - + def test_parse_score_response_with_quotes(self, deepseek_service): """Тест парсинга ответа с кавычками.""" assert deepseek_service._parse_score_response('"0.75"') == 0.75 assert deepseek_service._parse_score_response("'0.8'") == 0.8 - + def test_parse_score_response_with_text(self, deepseek_service): """Тест парсинга ответа с текстом.""" # Сервис должен найти число в тексте assert deepseek_service._parse_score_response("Score: 0.75") == 0.75 - + def test_clean_text(self, deepseek_service): """Тест очистки текста.""" assert deepseek_service._clean_text(" hello world ") == "hello world" assert deepseek_service._clean_text("^") == "" assert deepseek_service._clean_text("") == "" - + @pytest.mark.asyncio async def test_calculate_score_disabled(self): """Тест расчета скора при отключенном сервисе.""" - from helper_bot.services.scoring.deepseek_service import \ - DeepSeekService + from helper_bot.services.scoring.deepseek_service import DeepSeekService + service = DeepSeekService(api_key=None, enabled=False) - + with pytest.raises(ScoringError): await service.calculate_score("Test text") - + @pytest.mark.asyncio async def test_calculate_score_short_text(self, deepseek_service): """Тест расчета скора для короткого текста.""" @@ -271,125 +282,143 @@ class TestDeepSeekService: class TestScoringManager: """Тесты для ScoringManager.""" - + @pytest.fixture def mock_rag_service(self): """Создает мок RAG сервиса.""" mock = AsyncMock() mock.is_enabled = True - mock.calculate_score = AsyncMock(return_value=ScoringResult( - score=0.8, - source="rag", - model="rubert", - )) + mock.calculate_score = AsyncMock( + return_value=ScoringResult( + score=0.8, + source="rag", + model="rubert", + ) + ) mock.add_positive_example = AsyncMock() mock.add_negative_example = AsyncMock() return mock - + @pytest.fixture def mock_deepseek_service(self): """Создает мок DeepSeek сервиса.""" mock = AsyncMock() mock.is_enabled = True - mock.calculate_score = AsyncMock(return_value=ScoringResult( - score=0.7, - source="deepseek", - model="deepseek-chat", - )) + mock.calculate_score = AsyncMock( + return_value=ScoringResult( + score=0.7, + source="deepseek", + model="deepseek-chat", + ) + ) mock.add_positive_example = AsyncMock() mock.add_negative_example = AsyncMock() return mock - + @pytest.mark.asyncio - async def test_score_post_both_services(self, mock_rag_service, mock_deepseek_service): + async def test_score_post_both_services( + self, mock_rag_service, mock_deepseek_service + ): """Тест скоринга с обоими сервисами.""" from helper_bot.services.scoring.scoring_manager import ScoringManager - + manager = ScoringManager( rag_client=mock_rag_service, deepseek_service=mock_deepseek_service, ) - + result = await manager.score_post("Тестовый пост") - + assert result.rag_score == 0.8 assert result.deepseek_score == 0.7 assert result.has_any_score() - + @pytest.mark.asyncio async def test_score_post_rag_only(self, mock_rag_service): """Тест скоринга только с RAG.""" from helper_bot.services.scoring.scoring_manager import ScoringManager - + manager = ScoringManager( rag_client=mock_rag_service, deepseek_service=None, ) - + result = await manager.score_post("Тестовый пост") - + assert result.rag_score == 0.8 assert result.deepseek_score is None - + @pytest.mark.asyncio async def test_score_post_empty_text(self, mock_rag_service): """Тест скоринга пустого текста.""" from helper_bot.services.scoring.scoring_manager import ScoringManager - + manager = ScoringManager(rag_client=mock_rag_service) - + result = await manager.score_post("") - + assert not result.has_any_score() mock_rag_service.calculate_score.assert_not_called() - + @pytest.mark.asyncio - async def test_score_post_service_error(self, mock_rag_service, mock_deepseek_service): + async def test_score_post_service_error( + self, mock_rag_service, mock_deepseek_service + ): """Тест обработки ошибки сервиса.""" from helper_bot.services.scoring.scoring_manager import ScoringManager # RAG выбрасывает ошибку - mock_rag_service.calculate_score = AsyncMock(side_effect=Exception("Test error")) - + mock_rag_service.calculate_score = AsyncMock( + side_effect=Exception("Test error") + ) + manager = ScoringManager( rag_client=mock_rag_service, deepseek_service=mock_deepseek_service, ) - + result = await manager.score_post("Тестовый пост") - + # DeepSeek должен вернуть результат assert result.deepseek_score == 0.7 # RAG должен быть None с ошибкой assert result.rag_score is None assert "rag" in result.errors - + @pytest.mark.asyncio async def test_on_post_published(self, mock_rag_service, mock_deepseek_service): """Тест обучения на опубликованном посте.""" from helper_bot.services.scoring.scoring_manager import ScoringManager - + manager = ScoringManager( rag_client=mock_rag_service, deepseek_service=mock_deepseek_service, ) - + await manager.on_post_published("Опубликованный пост") - - mock_rag_service.add_positive_example.assert_called_once_with("Опубликованный пост") - mock_deepseek_service.add_positive_example.assert_called_once_with("Опубликованный пост") - + + mock_rag_service.add_positive_example.assert_called_once_with( + "Опубликованный пост" + ) + mock_deepseek_service.add_positive_example.assert_called_once_with( + "Опубликованный пост" + ) + @pytest.mark.asyncio async def test_on_post_declined(self, mock_rag_service, mock_deepseek_service): """Тест обучения на отклоненном посте.""" from helper_bot.services.scoring.scoring_manager import ScoringManager - + manager = ScoringManager( rag_client=mock_rag_service, deepseek_service=mock_deepseek_service, ) - + await manager.on_post_declined("Отклоненный пост") - - mock_rag_service.add_negative_example.assert_called_once_with("Отклоненный пост") - mock_deepseek_service.add_negative_example.assert_called_once_with("Отклоненный пост") + + mock_rag_service.add_negative_example.assert_called_once_with( + "Отклоненный пост" + ) + mock_deepseek_service.add_negative_example.assert_called_once_with( + "Отклоненный пост" + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2953237..d511122 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,27 +2,47 @@ import os from datetime import datetime from unittest.mock import AsyncMock, Mock, patch -import helper_bot.utils.messages as messages # Import for patching constants import pytest + +import helper_bot.utils.messages as messages # Import for patching constants from database.async_db import AsyncBotDB -from helper_bot.utils.base_dependency_factory import (BaseDependencyFactory, - get_global_instance) +from helper_bot.utils.base_dependency_factory import ( + BaseDependencyFactory, + get_global_instance, +) from helper_bot.utils.helper_func import ( - add_days_to_date, add_in_db_media, add_in_db_media_mediagroup, - check_access, check_user_emoji, check_username_and_full_name, - delete_user_blacklist, determine_anonymity, download_file, - get_banned_users_buttons, get_banned_users_list, get_first_name, - get_random_emoji, get_text_message, prepare_media_group_from_middlewares, - safe_html_escape, send_audio_message, - send_media_group_message_to_private_chat, send_media_group_to_channel, - send_photo_message, send_text_message, send_video_message, - send_video_note_message, send_voice_message, update_user_info) + add_days_to_date, + add_in_db_media, + add_in_db_media_mediagroup, + check_access, + check_user_emoji, + check_username_and_full_name, + delete_user_blacklist, + determine_anonymity, + download_file, + get_banned_users_buttons, + get_banned_users_list, + get_first_name, + get_random_emoji, + get_text_message, + prepare_media_group_from_middlewares, + safe_html_escape, + send_audio_message, + send_media_group_message_to_private_chat, + send_media_group_to_channel, + send_photo_message, + send_text_message, + send_video_message, + send_video_note_message, + send_voice_message, + update_user_info, +) from helper_bot.utils.messages import get_message class TestHelperFunctions: """Тесты для вспомогательных функций""" - + @pytest.fixture def mock_message(self): """Создает мок сообщения для тестирования""" @@ -32,13 +52,13 @@ class TestHelperFunctions: message.from_user.full_name = "Test User" message.from_user.username = "testuser" return message - + def test_get_first_name(self, mock_message): """Тест функции получения имени пользователя""" # Тест с обычным именем result = get_first_name(mock_message) assert result == "Test" - + # Тест с пустым именем - функция get_first_name не обрабатывает None # поэтому этот тест будет падать, что ожидаемо mock_message.from_user.first_name = None @@ -47,7 +67,7 @@ class TestHelperFunctions: assert False, "Ожидалась ошибка при None first_name" except AttributeError: pass # Ожидаемое поведение - + def test_get_text_message(self, mock_message): """Тест функции обработки текста сообщения""" # Тест с обычным текстом (legacy - определяется по тексту) @@ -57,19 +77,19 @@ class TestHelperFunctions: assert "testuser" in result assert "тестовое сообщение" in result assert "Автор поста" in result - + # Тест с пустым текстом result = get_text_message("", "Test", "testuser") assert "Test" in result assert "testuser" in result - + # Тест с текстом без специальных слов text = "Обычный текст без специальных слов" result = get_text_message(text, "Test", "testuser") assert "Test" in result assert "testuser" in result assert "Обычный текст без специальных слов" in result - + def test_get_text_message_with_is_anonymous_true(self, mock_message): """Тест функции get_text_message с is_anonymous=True""" text = "Тестовый пост" @@ -77,7 +97,7 @@ class TestHelperFunctions: assert "Тестовый пост" in result assert "Пост опубликован анонимно" in result assert "Автор поста" not in result - + def test_get_text_message_with_is_anonymous_false(self, mock_message): """Тест функции get_text_message с is_anonymous=False""" text = "Тестовый пост" @@ -87,7 +107,7 @@ class TestHelperFunctions: assert "Test" in result assert "testuser" in result assert "Пост опубликован анонимно" not in result - + def test_get_text_message_with_is_anonymous_none_legacy(self, mock_message): """Тест функции get_text_message с is_anonymous=None (legacy - определяется по тексту)""" # Тест с "анон" в тексте @@ -95,18 +115,18 @@ class TestHelperFunctions: result = get_text_message(text, "Test", "testuser", is_anonymous=None) assert "Тестовый пост анон" in result assert "Пост опубликован анонимно" in result - + # Тест с "неанон" в тексте text = "Тестовый пост неанон" result = get_text_message(text, "Test", "testuser", is_anonymous=None) assert "Тестовый пост неанон" in result assert "Автор поста" in result - + # Тест с "не анон" в тексте text = "Тестовый пост не анон" result = get_text_message(text, "Test", "testuser", is_anonymous=None) assert "Автор поста" in result - + def test_get_text_message_with_username_none(self, mock_message): """Тест функции get_text_message без username""" text = "Тестовый пост" @@ -114,50 +134,50 @@ class TestHelperFunctions: assert "Test" in result assert "(Ник не указан)" in result assert "@" not in result - + def test_determine_anonymity_with_anon(self): """Тест функции determine_anonymity с 'анон' в тексте""" assert determine_anonymity("Этот пост анон") is True assert determine_anonymity("анон") is True assert determine_anonymity("АНОН") is True # Проверка регистра assert determine_anonymity("пост анонимный анон") is True - + def test_determine_anonymity_with_neanon(self): """Тест функции determine_anonymity с 'неанон' в тексте""" assert determine_anonymity("Этот пост неанон") is False assert determine_anonymity("неанон") is False assert determine_anonymity("НЕАНОН") is False # Проверка регистра assert determine_anonymity("пост неанон") is False - + def test_determine_anonymity_with_ne_anon(self): """Тест функции determine_anonymity с 'не анон' в тексте""" assert determine_anonymity("Этот пост не анон") is False assert determine_anonymity("не анон") is False assert determine_anonymity("НЕ АНОН") is False # Проверка регистра assert determine_anonymity("пост не анон") is False - + def test_determine_anonymity_priority_neanon_over_anon(self): """Тест приоритета 'неанон' над 'анон'""" # Если есть и "анон" и "неанон", должен вернуть False assert determine_anonymity("анон неанон") is False assert determine_anonymity("неанон анон") is False assert determine_anonymity("не анон анон") is False - + def test_determine_anonymity_without_keywords(self): """Тест функции determine_anonymity без ключевых слов""" assert determine_anonymity("Обычный текст") is False assert determine_anonymity("") is False assert determine_anonymity("Пост без специальных слов") is False - + def test_determine_anonymity_with_none(self): """Тест функции determine_anonymity с None""" assert determine_anonymity(None) is False - + def test_determine_anonymity_with_empty_string(self): """Тест функции determine_anonymity с пустой строкой""" assert determine_anonymity("") is False assert determine_anonymity(" ") is False # Только пробелы - + @pytest.mark.asyncio async def test_check_username_and_full_name(self): """Тест функции проверки изменений username и full_name""" @@ -165,51 +185,59 @@ class TestHelperFunctions: mock_db = Mock(spec=AsyncBotDB) mock_db.get_username = AsyncMock(return_value="olduser") mock_db.get_full_name_by_id = AsyncMock(return_value="Old User") - + # Тест с измененными данными - result = await check_username_and_full_name(123456, "newuser", "New User", mock_db) + result = await check_username_and_full_name( + 123456, "newuser", "New User", mock_db + ) assert result is True - + # Тест с неизмененными данными - result = await check_username_and_full_name(123456, "olduser", "Old User", mock_db) + result = await check_username_and_full_name( + 123456, "olduser", "Old User", mock_db + ) assert result is False - + # Тест с частично измененными данными - result = await check_username_and_full_name(123456, "olduser", "New User", mock_db) + result = await check_username_and_full_name( + 123456, "olduser", "New User", mock_db + ) assert result is True - - result = await check_username_and_full_name(123456, "newuser", "Old User", mock_db) + + result = await check_username_and_full_name( + 123456, "newuser", "Old User", mock_db + ) assert result is True class TestSafeHtmlEscape: """Тесты для функции безопасного экранирования HTML""" - + def test_safe_html_escape_normal_text(self): """Тест экранирования обычного текста""" result = safe_html_escape("Hello World") assert result == "Hello World" - + def test_safe_html_escape_html_tags(self): """Тест экранирования HTML тегов""" result = safe_html_escape("") assert result == "<script>alert('xss')</script>" - + def test_safe_html_escape_special_chars(self): """Тест экранирования специальных символов""" result = safe_html_escape("& < > \" '") assert result == "& < > " '" - + def test_safe_html_escape_none_input(self): """Тест экранирования None значения""" result = safe_html_escape(None) assert result == "" - + def test_safe_html_escape_empty_string(self): """Тест экранирования пустой строки""" result = safe_html_escape("") assert result == "" - + def test_safe_html_escape_non_string_input(self): """Тест экранирования нестрокового ввода""" result = safe_html_escape(123) @@ -218,37 +246,37 @@ class TestSafeHtmlEscape: class TestMessages: """Тесты для системы сообщений""" - + def test_get_message(self): """Тест функции получения сообщений""" # Тест с существующим ключом result = get_message("Test", "HELLO_MESSAGE") assert isinstance(result, str) assert len(result) > 0 - + # Тест с несуществующим ключом try: result = get_message("Test", "NON_EXISTENT_KEY") assert False, "Ожидалась ошибка KeyError" except KeyError: pass # Ожидаемое поведение - + # Тест с пустым именем result = get_message("", "HELLO_MESSAGE") assert isinstance(result, str) assert len(result) > 0 - + # Тест с None именем - ожидаем ошибку try: result = get_message(None, "HELLO_MESSAGE") assert False, "Ожидалась ошибка TypeError" except TypeError: pass # Ожидаемое поведение - + def test_get_message_all_types(self): """Тест всех типов сообщений""" # Patch the constants dictionary to include 'SUGGEST_NEWS_2' for testing purposes - with patch.dict(messages.constants, {'SUGGEST_NEWS_2': 'Test message 2'}): + with patch.dict(messages.constants, {"SUGGEST_NEWS_2": "Test message 2"}): message_types = [ "HELLO_MESSAGE", "SUGGEST_NEWS", @@ -256,9 +284,9 @@ class TestMessages: "BYE_MESSAGE", "SUCCESS_SEND_MESSAGE", "CONNECT_WITH_ADMIN", - "QUESTION" + "QUESTION", ] - + for msg_type in message_types: result = get_message("Test", msg_type) assert isinstance(result, str) @@ -267,59 +295,60 @@ class TestMessages: class TestBaseDependencyFactory: """Тесты для фабрики зависимостей""" - + def test_singleton_pattern(self): """Тест паттерна синглтон""" # Сбрасываем глобальный экземпляр import helper_bot.utils.base_dependency_factory + helper_bot.utils.base_dependency_factory._global_instance = None - + # Получаем два экземпляра instance1 = get_global_instance() instance2 = get_global_instance() - + # Проверяем, что это один и тот же объект assert instance1 is instance2 assert id(instance1) == id(instance2) - + def test_factory_initialization_with_mock_config(self): """Тест инициализации фабрики с мок конфигурацией""" # With os.getenv mocked in tests/mocks.py, BaseDependencyFactory can be directly tested factory = BaseDependencyFactory() assert factory.settings is not None assert factory.database is not None - + def test_get_settings_method(self): """Тест метода get_settings""" # With os.getenv mocked, settings can be directly accessed and verified factory = BaseDependencyFactory() settings = factory.get_settings() - assert settings['Telegram']['bot_token'] == 'test_token_123' - assert settings['Settings']['logs'] is True - + assert settings["Telegram"]["bot_token"] == "test_token_123" + assert settings["Settings"]["logs"] is True + def test_get_db_method(self): """Тест метода get_db""" # No need for configparser patch, os.getenv is already mocked globally factory = BaseDependencyFactory() db = factory.get_db() - + assert db is not None assert db == factory.database class TestDatabaseIntegration: """Тесты интеграции с базой данных""" - + def test_database_connection(self): """Тест подключения к базе данных""" # No need for configparser patch, os.getenv is already mocked globally factory = BaseDependencyFactory() - + # Проверяем, что база данных была создана # (mock_db is already a Mock object from tests/mocks.py) # So, we just check if it's the correct mock instance assert factory.database is not None - + # Проверяем, что get_db возвращает тот же экземпляр db1 = factory.get_db() db2 = factory.get_db() @@ -328,74 +357,82 @@ class TestDatabaseIntegration: class TestConfigurationHandling: """Тесты обработки конфигурации""" - + def test_boolean_config_values(self): """Тест обработки булевых значений в конфигурации""" # Now that os.getenv is mocked, we can directly test factory = BaseDependencyFactory() settings = factory.get_settings() - assert settings['Settings']['logs'] is True - assert settings['Settings']['test'] is False - + assert settings["Settings"]["logs"] is True + assert settings["Settings"]["test"] is False + def test_string_config_values(self): """Тест обработки строковых значений в конфигурации""" # Now that os.getenv is mocked, we can directly test factory = BaseDependencyFactory() settings = factory.get_settings() - assert settings['Telegram']['bot_token'] == 'test_token_123' - assert settings['Telegram']['main_public'] == '@test' + assert settings["Telegram"]["bot_token"] == "test_token_123" + assert settings["Telegram"]["main_public"] == "@test" class TestDownloadFile: """Тесты для функции скачивания файлов""" - + @pytest.mark.asyncio async def test_download_file_success(self): """Тест успешного скачивания файла""" mock_message = Mock() mock_message.bot = AsyncMock() - + # Мокаем get_file mock_file = Mock() mock_file.file_path = "photos/file_123.jpg" mock_message.bot.get_file.return_value = mock_file - + # Мокаем download_file mock_message.bot.download_file = AsyncMock() - + # Мокаем os.makedirs и другие зависимости - with patch('os.makedirs') as mock_makedirs: - with patch('os.path.join', return_value="files/photos/file_123.jpg"): - with patch('os.path.exists', return_value=True): - with patch('os.path.getsize', return_value=1024): - with patch('os.path.basename', return_value='file_123.jpg'): - with patch('os.path.splitext', return_value=('file_123', '.jpg')): - with patch('helper_bot.utils.metrics.metrics') as mock_metrics: - result = await download_file(mock_message, "file_id_123", "photo") - + with patch("os.makedirs") as mock_makedirs: + with patch("os.path.join", return_value="files/photos/file_123.jpg"): + with patch("os.path.exists", return_value=True): + with patch("os.path.getsize", return_value=1024): + with patch("os.path.basename", return_value="file_123.jpg"): + with patch( + "os.path.splitext", return_value=("file_123", ".jpg") + ): + with patch( + "helper_bot.utils.metrics.metrics" + ) as mock_metrics: + result = await download_file( + mock_message, "file_id_123", "photo" + ) + assert result == "files/photos/file_123.jpg" mock_makedirs.assert_called() - mock_message.bot.get_file.assert_called_once_with("file_id_123") + mock_message.bot.get_file.assert_called_once_with( + "file_id_123" + ) mock_message.bot.download_file.assert_called_once() - + @pytest.mark.asyncio async def test_download_file_exception(self): """Тест обработки ошибки при скачивании""" mock_message = Mock() mock_message.bot = AsyncMock() mock_message.bot.get_file.side_effect = Exception("Network error") - - with patch('os.makedirs'): - with patch('helper_bot.utils.helper_func.logger') as mock_logger: + + with patch("os.makedirs"): + with patch("helper_bot.utils.helper_func.logger") as mock_logger: result = await download_file(mock_message, "file_id_123") - + assert result is None mock_logger.error.assert_called_once() class TestPrepareMediaGroup: """Тесты для подготовки медиагрупп""" - + @pytest.mark.asyncio async def test_prepare_media_group_photos(self): """Тест подготовки медиагруппы с фотографиями""" @@ -405,33 +442,35 @@ class TestPrepareMediaGroup: message.photo = [Mock()] message.photo[-1].file_id = f"photo_{i}" album.append(message) - + result = await prepare_media_group_from_middlewares(album, "Тестовая подпись") - + assert len(result) == 3 assert result[0].media == "photo_0" assert result[1].media == "photo_1" assert result[2].media == "photo_2" - assert result[0].caption == "Тестовая подпись" # Первое фото должно иметь caption - + assert ( + result[0].caption == "Тестовая подпись" + ) # Первое фото должно иметь caption + @pytest.mark.asyncio async def test_prepare_media_group_mixed_types(self): """Тест подготовки медиагруппы с разными типами медиа""" album = [] - + # Фото photo_message = Mock() photo_message.photo = [Mock()] photo_message.photo[-1].file_id = "photo_1" album.append(photo_message) - + # Видео video_message = Mock() video_message.photo = None video_message.video = Mock() video_message.video.file_id = "video_1" album.append(video_message) - + # Аудио audio_message = Mock() audio_message.photo = None @@ -439,22 +478,24 @@ class TestPrepareMediaGroup: audio_message.audio = Mock() audio_message.audio.file_id = "audio_1" album.append(audio_message) - + result = await prepare_media_group_from_middlewares(album, "Смешанная группа") - + assert len(result) == 3 assert result[0].media == "photo_1" assert result[1].media == "video_1" assert result[2].media == "audio_1" - assert result[0].caption == "Смешанная группа" # Первое медиа должно иметь caption - + assert ( + result[0].caption == "Смешанная группа" + ) # Первое медиа должно иметь caption + @pytest.mark.asyncio async def test_prepare_media_group_empty_album(self): """Тест подготовки пустой медиагруппы""" album = [] result = await prepare_media_group_from_middlewares(album, "Пустая группа") assert result == [] - + @pytest.mark.asyncio async def test_prepare_media_group_unsupported_type(self): """Тест подготовки медиагруппы с неподдерживаемым типом""" @@ -465,14 +506,14 @@ class TestPrepareMediaGroup: message.audio = None message.document = None # Добавляем document = None album.append(message) - + result = await prepare_media_group_from_middlewares(album, "Тест") assert result == [] class TestMediaDatabaseOperations: """Тесты для операций с медиа в базе данных""" - + @pytest.mark.asyncio async def test_add_in_db_media_mediagroup(self): """Тест добавления медиагруппы в базу данных""" @@ -483,14 +524,17 @@ class TestMediaDatabaseOperations: message.photo = [Mock()] message.photo[-1].file_id = f"photo_{i}" sent_message.append(message) - + mock_db = AsyncMock() - - with patch('helper_bot.utils.helper_func.download_file', return_value=f"files/photo_{i}.jpg"): + + with patch( + "helper_bot.utils.helper_func.download_file", + return_value=f"files/photo_{i}.jpg", + ): await add_in_db_media_mediagroup(sent_message, mock_db) - + assert mock_db.add_post_content.call_count == 2 - + @pytest.mark.asyncio async def test_add_in_db_media_photo(self): """Тест добавления фото в базу данных""" @@ -498,16 +542,19 @@ class TestMediaDatabaseOperations: mock_message.message_id = 123 mock_message.photo = [Mock()] mock_message.photo[-1].file_id = "photo_123" - + mock_db = AsyncMock() - - with patch('helper_bot.utils.helper_func.download_file', return_value="files/photo_123.jpg"): + + with patch( + "helper_bot.utils.helper_func.download_file", + return_value="files/photo_123.jpg", + ): await add_in_db_media(mock_message, mock_db) - + mock_db.add_post_content.assert_called_once_with( - 123, 123, "files/photo_123.jpg", 'photo' + 123, 123, "files/photo_123.jpg", "photo" ) - + @pytest.mark.asyncio async def test_add_in_db_media_video(self): """Тест добавления видео в базу данных""" @@ -516,16 +563,19 @@ class TestMediaDatabaseOperations: mock_message.photo = None # У видео нет фото mock_message.video = Mock() mock_message.video.file_id = "video_123" - + mock_db = AsyncMock() - - with patch('helper_bot.utils.helper_func.download_file', return_value="files/video_123.mp4"): + + with patch( + "helper_bot.utils.helper_func.download_file", + return_value="files/video_123.mp4", + ): await add_in_db_media(mock_message, mock_db) - + mock_db.add_post_content.assert_called_once_with( - 123, 123, "files/video_123.mp4", 'video' + 123, 123, "files/video_123.mp4", "video" ) - + @pytest.mark.asyncio async def test_add_in_db_media_voice(self): """Тест добавления голосового сообщения в базу данных""" @@ -535,196 +585,210 @@ class TestMediaDatabaseOperations: mock_message.video = None # У голосового сообщения нет видео mock_message.voice = Mock() mock_message.voice.file_id = "voice_123" - + mock_db = AsyncMock() - - with patch('helper_bot.utils.helper_func.download_file', return_value="files/voice_123.ogg"): + + with patch( + "helper_bot.utils.helper_func.download_file", + return_value="files/voice_123.ogg", + ): await add_in_db_media(mock_message, mock_db) - + mock_db.add_post_content.assert_called_once_with( - 123, 123, "files/voice_123.ogg", 'voice' + 123, 123, "files/voice_123.ogg", "voice" ) class TestSendMessageFunctions: """Тесты для функций отправки сообщений""" - + @pytest.mark.asyncio async def test_send_text_message_without_markup(self): """Тест отправки текстового сообщения без разметки""" mock_message = Mock() mock_message.bot = AsyncMock() mock_message.bot.send_message = AsyncMock() - + mock_sent_message = Mock() mock_sent_message.message_id = 456 mock_message.bot.send_message.return_value = mock_sent_message - + # Мокаем rate_limiter (он импортируется внутри функции) - with patch('helper_bot.utils.rate_limiter.send_with_rate_limit', new_callable=AsyncMock) as mock_rate_limit: + with patch( + "helper_bot.utils.rate_limiter.send_with_rate_limit", new_callable=AsyncMock + ) as mock_rate_limit: mock_rate_limit.return_value = mock_sent_message - + result = await send_text_message(123, mock_message, "Тестовое сообщение") - + assert result == mock_sent_message assert result.message_id == 456 - + @pytest.mark.asyncio async def test_send_text_message_with_markup(self): """Тест отправки текстового сообщения с разметкой""" mock_message = Mock() mock_message.bot = AsyncMock() mock_message.bot.send_message = AsyncMock() - + mock_markup = Mock() mock_sent_message = Mock() mock_sent_message.message_id = 456 mock_message.bot.send_message.return_value = mock_sent_message - + # Мокаем rate_limiter (он импортируется внутри функции) - with patch('helper_bot.utils.rate_limiter.send_with_rate_limit', new_callable=AsyncMock) as mock_rate_limit: + with patch( + "helper_bot.utils.rate_limiter.send_with_rate_limit", new_callable=AsyncMock + ) as mock_rate_limit: mock_rate_limit.return_value = mock_sent_message - - result = await send_text_message(123, mock_message, "Тестовое сообщение", mock_markup) - + + result = await send_text_message( + 123, mock_message, "Тестовое сообщение", mock_markup + ) + assert result == mock_sent_message assert result.message_id == 456 - + @pytest.mark.asyncio async def test_send_photo_message(self): """Тест отправки фото""" mock_message = Mock() mock_message.bot = AsyncMock() mock_message.bot.send_photo = AsyncMock() - + mock_sent_message = Mock() mock_message.bot.send_photo.return_value = mock_sent_message - - result = await send_photo_message(123, mock_message, "photo.jpg", "Подпись к фото") - + + result = await send_photo_message( + 123, mock_message, "photo.jpg", "Подпись к фото" + ) + assert result == mock_sent_message mock_message.bot.send_photo.assert_called_once_with( - chat_id=123, - caption="Подпись к фото", - photo="photo.jpg" + chat_id=123, caption="Подпись к фото", photo="photo.jpg" ) - + @pytest.mark.asyncio async def test_send_video_message(self): """Тест отправки видео""" mock_message = Mock() mock_message.bot = AsyncMock() mock_message.bot.send_video = AsyncMock() - + mock_sent_message = Mock() mock_message.bot.send_video.return_value = mock_sent_message - - result = await send_video_message(123, mock_message, "video.mp4", "Подпись к видео") - + + result = await send_video_message( + 123, mock_message, "video.mp4", "Подпись к видео" + ) + assert result == mock_sent_message mock_message.bot.send_video.assert_called_once_with( - chat_id=123, - caption="Подпись к видео", - video="video.mp4" + chat_id=123, caption="Подпись к видео", video="video.mp4" ) class TestUtilityFunctions: """Тесты для утилитарных функций""" - + @pytest.mark.asyncio async def test_check_access(self): """Тест проверки доступа""" mock_db = AsyncMock() mock_db.is_admin.return_value = True - + result = await check_access(123, mock_db) assert result is True - + mock_db.is_admin.return_value = False result = await check_access(123, mock_db) assert result is False - + def test_add_days_to_date(self): """Тест добавления дней к дате""" - with patch('helper_bot.utils.helper_func.datetime') as mock_datetime: + with patch("helper_bot.utils.helper_func.datetime") as mock_datetime: from datetime import timedelta + mock_now = datetime(2024, 1, 1) mock_datetime.now.return_value = mock_now mock_datetime.timedelta = timedelta - + result = add_days_to_date(5) expected_timestamp = int((mock_now + timedelta(days=5)).timestamp()) assert result == expected_timestamp - + @pytest.mark.asyncio async def test_get_banned_users_list(self): """Тест получения списка заблокированных пользователей""" mock_db = AsyncMock() mock_db.get_banned_users_from_db_with_limits.return_value = [ (123, "Spam", 1704067200), # user_id, ban_reason, unban_date (timestamp) - (456, "Violation", 1704153600) + (456, "Violation", 1704153600), ] mock_db.get_username.return_value = None mock_db.get_full_name_by_id.return_value = "Test User" - + result = await get_banned_users_list(0, mock_db) - + assert "Список заблокированных пользователей:" in result assert "Test User" in result assert "Spam" in result assert "Violation" in result - + @pytest.mark.asyncio async def test_get_banned_users_list_with_string_timestamp(self): """Тест получения списка заблокированных пользователей со строковым timestamp""" mock_db = AsyncMock() mock_db.get_banned_users_from_db_with_limits.return_value = [ - (123, "Spam", "1704067200"), # user_id, ban_reason, unban_date (string timestamp) - (456, "Violation", "1704153600") + ( + 123, + "Spam", + "1704067200", + ), # user_id, ban_reason, unban_date (string timestamp) + (456, "Violation", "1704153600"), ] mock_db.get_username.return_value = None mock_db.get_full_name_by_id.return_value = "Test User" - + result = await get_banned_users_list(0, mock_db) - + assert "Список заблокированных пользователей:" in result assert "Test User" in result assert "Spam" in result assert "Violation" in result - + @pytest.mark.asyncio async def test_get_banned_users_buttons(self): """Тест получения кнопок заблокированных пользователей""" mock_db = AsyncMock() mock_db.get_banned_users_from_db.return_value = [ (123, "Spam", 1704067200), # user_id, ban_reason, unban_date - (456, "Violation", 1704153600) + (456, "Violation", 1704153600), ] mock_db.get_username.return_value = None mock_db.get_full_name_by_id.return_value = "Test User" - + result = await get_banned_users_buttons(mock_db) - + assert len(result) == 2 assert result[0] == ("Test User", 123) assert result[1] == ("Test User", 456) - + @pytest.mark.asyncio async def test_delete_user_blacklist(self): """Тест удаления пользователя из черного списка""" mock_db = AsyncMock() mock_db.delete_user_blacklist.return_value = True - + result = await delete_user_blacklist(123, mock_db) assert result is True - + mock_db.delete_user_blacklist.assert_called_once_with(user_id=123) class TestUserManagement: """Тесты для управления пользователями""" - + @pytest.mark.asyncio async def test_update_user_info_new_user(self): """Тест обновления информации о новом пользователе""" @@ -736,68 +800,76 @@ class TestUserManagement: mock_message.from_user.language_code = "ru" mock_message.answer = AsyncMock() mock_message.bot.send_message = AsyncMock() - - with patch('helper_bot.utils.helper_func.get_first_name', return_value="Test"): - with patch('helper_bot.utils.helper_func.get_random_emoji', return_value="😀"): - with patch('helper_bot.utils.helper_func.BotDB') as mock_bot_db: + + with patch("helper_bot.utils.helper_func.get_first_name", return_value="Test"): + with patch( + "helper_bot.utils.helper_func.get_random_emoji", return_value="😀" + ): + with patch("helper_bot.utils.helper_func.BotDB") as mock_bot_db: mock_bot_db.user_exists = AsyncMock(return_value=False) mock_bot_db.add_user = AsyncMock() mock_bot_db.update_user_date = AsyncMock() - + await update_user_info("test", mock_message) - + mock_bot_db.add_user.assert_called_once() mock_bot_db.update_user_date.assert_called_once() - + @pytest.mark.asyncio async def test_check_user_emoji_existing(self): """Тест проверки эмодзи пользователя (существующий)""" mock_message = Mock() mock_message.from_user.id = 123 - - with patch('helper_bot.utils.helper_func.BotDB') as mock_bot_db: + + with patch("helper_bot.utils.helper_func.BotDB") as mock_bot_db: mock_bot_db.get_user_emoji = AsyncMock(return_value="😀") - + result = await check_user_emoji(mock_message) assert result == "😀" - + @pytest.mark.asyncio async def test_check_user_emoji_new(self): """Тест проверки эмодзи пользователя (новый)""" mock_message = Mock() mock_message.from_user.id = 123 - - with patch('helper_bot.utils.helper_func.BotDB') as mock_bot_db: + + with patch("helper_bot.utils.helper_func.BotDB") as mock_bot_db: mock_bot_db.get_user_emoji = AsyncMock(return_value=None) mock_bot_db.update_user_emoji = AsyncMock() - - with patch('helper_bot.utils.helper_func.get_random_emoji', return_value="😀"): + + with patch( + "helper_bot.utils.helper_func.get_random_emoji", return_value="😀" + ): result = await check_user_emoji(mock_message) assert result == "😀" - mock_bot_db.update_user_emoji.assert_called_once_with(user_id=123, emoji="😀") - + mock_bot_db.update_user_emoji.assert_called_once_with( + user_id=123, emoji="😀" + ) + @pytest.mark.asyncio async def test_get_random_emoji_success(self): """Тест получения случайного эмодзи (успех)""" - with patch('helper_bot.utils.helper_func.BotDB') as mock_bot_db: + with patch("helper_bot.utils.helper_func.BotDB") as mock_bot_db: mock_bot_db.check_emoji_exists = AsyncMock(return_value=False) - - with patch('helper_bot.utils.helper_func.random.choice', return_value="😀"): + + with patch("helper_bot.utils.helper_func.random.choice", return_value="😀"): result = await get_random_emoji() assert result == "😀" - + @pytest.mark.asyncio async def test_get_random_emoji_fallback(self): """Тест получения случайного эмодзи (fallback)""" - with patch('helper_bot.utils.helper_func.BotDB') as mock_bot_db: - mock_bot_db.check_emoji_exists = AsyncMock(return_value=True) # Все эмодзи заняты - - with patch('helper_bot.utils.helper_func.random.choice', return_value="😀"): - with patch('helper_bot.utils.helper_func.logger') as mock_logger: + with patch("helper_bot.utils.helper_func.BotDB") as mock_bot_db: + mock_bot_db.check_emoji_exists = AsyncMock( + return_value=True + ) # Все эмодзи заняты + + with patch("helper_bot.utils.helper_func.random.choice", return_value="😀"): + with patch("helper_bot.utils.helper_func.logger") as mock_logger: result = await get_random_emoji() assert result == "Эмоджи не определен" mock_logger.error.assert_called_once() -if __name__ == '__main__': - pytest.main([__file__, '-v']) \ No newline at end of file +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_voice_bot_architecture.py b/tests/test_voice_bot_architecture.py index f0ca934..9f7b7bc 100644 --- a/tests/test_voice_bot_architecture.py +++ b/tests/test_voice_bot_architecture.py @@ -3,189 +3,211 @@ from pathlib import Path from unittest.mock import AsyncMock, Mock, patch import pytest -from helper_bot.handlers.voice.exceptions import (AudioProcessingError, - VoiceMessageError) + +from helper_bot.handlers.voice.exceptions import AudioProcessingError, VoiceMessageError from helper_bot.handlers.voice.services import VoiceBotService -from helper_bot.handlers.voice.utils import (get_last_message_text, - get_user_emoji_safe, - validate_voice_message) +from helper_bot.handlers.voice.utils import ( + get_last_message_text, + get_user_emoji_safe, + validate_voice_message, +) class TestVoiceBotService: """Тесты для VoiceBotService""" - + @pytest.fixture def mock_bot_db(self): """Мок для базы данных""" mock_db = Mock() mock_db.settings = { - 'Settings': {'logs': True}, - 'Telegram': {'important_logs': 'test_chat_id'} + "Settings": {"logs": True}, + "Telegram": {"important_logs": "test_chat_id"}, } return mock_db - + @pytest.fixture def mock_settings(self): """Мок для настроек""" - return { - 'Settings': {'logs': True}, - 'Telegram': {'preview_link': True} - } - + return {"Settings": {"logs": True}, "Telegram": {"preview_link": True}} + @pytest.fixture def voice_service(self, mock_bot_db, mock_settings): """Экземпляр VoiceBotService для тестов""" return VoiceBotService(mock_bot_db, mock_settings) - + @pytest.mark.asyncio async def test_get_welcome_sticker_success(self, voice_service, mock_settings): """Тест успешного получения стикера""" - with patch('pathlib.Path.rglob') as mock_rglob: - mock_rglob.return_value = ['/path/to/sticker1.tgs', '/path/to/sticker2.tgs'] - + with patch("pathlib.Path.rglob") as mock_rglob: + mock_rglob.return_value = ["/path/to/sticker1.tgs", "/path/to/sticker2.tgs"] + sticker = await voice_service.get_welcome_sticker() - + assert sticker is not None mock_rglob.assert_called_once() - + @pytest.mark.asyncio async def test_get_welcome_sticker_no_stickers(self, voice_service, mock_settings): """Тест получения стикера когда их нет""" - with patch('pathlib.Path.rglob') as mock_rglob: + with patch("pathlib.Path.rglob") as mock_rglob: mock_rglob.return_value = [] - + sticker = await voice_service.get_welcome_sticker() - + assert sticker is None - + @pytest.mark.asyncio async def test_get_random_audio_success(self, voice_service, mock_bot_db): """Тест успешного получения случайного аудио""" - mock_bot_db.check_listen_audio = AsyncMock(return_value=['audio1', 'audio2']) + mock_bot_db.check_listen_audio = AsyncMock(return_value=["audio1", "audio2"]) mock_bot_db.get_user_id_by_file_name = AsyncMock(return_value=123) - mock_bot_db.get_date_by_file_name = AsyncMock(return_value='2025-01-01 12:00:00') - mock_bot_db.get_user_emoji = AsyncMock(return_value='😊') - + mock_bot_db.get_date_by_file_name = AsyncMock( + return_value="2025-01-01 12:00:00" + ) + mock_bot_db.get_user_emoji = AsyncMock(return_value="😊") + result = await voice_service.get_random_audio(456) - + assert result is not None assert len(result) == 3 - assert result[0] in ['audio1', 'audio2'] - assert result[1] == '2025-01-01 12:00:00' - assert result[2] == '😊' - + assert result[0] in ["audio1", "audio2"] + assert result[1] == "2025-01-01 12:00:00" + assert result[2] == "😊" + @pytest.mark.asyncio async def test_get_random_audio_no_audio(self, voice_service, mock_bot_db): """Тест получения аудио когда их нет""" mock_bot_db.check_listen_audio = AsyncMock(return_value=[]) - + result = await voice_service.get_random_audio(456) - + assert result is None - + @pytest.mark.asyncio async def test_mark_audio_as_listened_success(self, voice_service, mock_bot_db): """Тест успешной пометки аудио как прослушанного""" mock_bot_db.mark_listened_audio = AsyncMock() - - await voice_service.mark_audio_as_listened('test_audio', 123) - - mock_bot_db.mark_listened_audio.assert_called_once_with('test_audio', user_id=123) - + + await voice_service.mark_audio_as_listened("test_audio", 123) + + mock_bot_db.mark_listened_audio.assert_called_once_with( + "test_audio", user_id=123 + ) + @pytest.mark.asyncio async def test_clear_user_listenings_success(self, voice_service, mock_bot_db): """Тест успешной очистки прослушиваний""" mock_bot_db.delete_listen_count_for_user = AsyncMock() - + await voice_service.clear_user_listenings(123) - + mock_bot_db.delete_listen_count_for_user.assert_called_once_with(123) - + @pytest.mark.asyncio async def test_get_remaining_audio_count_success(self, voice_service, mock_bot_db): """Тест получения количества оставшихся аудио""" - mock_bot_db.check_listen_audio = AsyncMock(return_value=['audio1', 'audio2', 'audio3']) - + mock_bot_db.check_listen_audio = AsyncMock( + return_value=["audio1", "audio2", "audio3"] + ) + result = await voice_service.get_remaining_audio_count(123) - + assert result == 3 mock_bot_db.check_listen_audio.assert_called_once_with(user_id=123) - + @pytest.mark.asyncio async def test_get_remaining_audio_count_zero(self, voice_service, mock_bot_db): """Тест получения количества оставшихся аудио когда их нет""" mock_bot_db.check_listen_audio = AsyncMock(return_value=[]) - + result = await voice_service.get_remaining_audio_count(123) - + assert result == 0 mock_bot_db.check_listen_audio.assert_called_once_with(user_id=123) - + @pytest.mark.asyncio - async def test_send_welcome_messages_success(self, voice_service, mock_bot_db, mock_settings): - """Тест успешной отправки приветственных сообщений""" + async def test_send_welcome_messages_success( + self, voice_service, mock_bot_db, mock_settings + ): + """Тест успешной отправки приветственных сообщений.""" mock_message = Mock() mock_message.from_user.id = 123 mock_message.answer = AsyncMock() mock_message.answer.return_value = Mock() mock_message.answer_sticker = AsyncMock() - - with patch.object(voice_service, 'get_welcome_sticker') as mock_sticker: - mock_sticker.return_value = 'test_sticker.tgs' - - await voice_service.send_welcome_messages(mock_message, '😊') - - # Проверяем, что сообщения отправлены - assert mock_message.answer.call_count >= 1 - + + with patch.object( + voice_service, + "get_welcome_sticker", + new_callable=AsyncMock, + return_value="test_sticker.tgs", + ): + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", + new_callable=AsyncMock, + ): + await voice_service.send_welcome_messages(mock_message, "😊") + + assert mock_message.answer.call_count >= 1 + @pytest.mark.asyncio - async def test_send_welcome_messages_no_sticker(self, voice_service, mock_bot_db, mock_settings): - """Тест отправки приветственных сообщений без стикера""" + async def test_send_welcome_messages_no_sticker( + self, voice_service, mock_bot_db, mock_settings + ): + """Тест отправки приветственных сообщений без стикера.""" mock_message = Mock() mock_message.from_user.id = 123 mock_message.answer = AsyncMock() mock_message.answer.return_value = Mock() - - with patch.object(voice_service, 'get_welcome_sticker') as mock_sticker: - mock_sticker.return_value = None - - await voice_service.send_welcome_messages(mock_message, '😊') - - # Проверяем, что сообщения отправлены - assert mock_message.answer.call_count >= 1 + + with patch.object( + voice_service, + "get_welcome_sticker", + new_callable=AsyncMock, + return_value=None, + ): + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", + new_callable=AsyncMock, + ): + await voice_service.send_welcome_messages(mock_message, "😊") + + assert mock_message.answer.call_count >= 1 class TestVoiceHandlers: """Тесты для VoiceHandlers""" - + @pytest.fixture def mock_db(self): """Мок для базы данных""" return Mock() - + @pytest.fixture def mock_settings(self): """Мок для настроек""" return { - 'Telegram': { - 'group_for_logs': 'test_logs_chat', - 'group_for_posts': 'test_posts_chat', - 'preview_link': True + "Telegram": { + "group_for_logs": "test_logs_chat", + "group_for_posts": "test_posts_chat", + "preview_link": True, } } - + @pytest.fixture def voice_handlers(self, mock_db, mock_settings): """Экземпляр VoiceHandlers для тестов""" from helper_bot.handlers.voice.voice_handler import VoiceHandlers + return VoiceHandlers(mock_db, mock_settings) - + def test_voice_handlers_initialization(self, voice_handlers): """Тест инициализации VoiceHandlers""" assert voice_handlers.db is not None assert voice_handlers.settings is not None assert voice_handlers.router is not None - + def test_setup_handlers(self, voice_handlers): """Тест настройки обработчиков""" # Проверяем, что роутер содержит обработчики @@ -194,84 +216,92 @@ class TestVoiceHandlers: class TestUtils: """Тесты для утилит""" - + @pytest.fixture def mock_bot_db(self): """Мок для базы данных""" return Mock() - + @pytest.mark.asyncio async def test_get_last_message_text(self, mock_bot_db): """Тест получения последнего сообщения""" # Возвращаем UNIX timestamp - mock_bot_db.last_date_audio = AsyncMock(return_value=1641034800) # 2022-01-01 12:00:00 - + mock_bot_db.last_date_audio = AsyncMock( + return_value=1641034800 + ) # 2022-01-01 12:00:00 + result = await get_last_message_text(mock_bot_db) - + assert result is not None - assert "минут" in result or "часа" in result or "дня" in result or "день" in result or "дней" in result + assert ( + "минут" in result + or "часа" in result + or "дня" in result + or "день" in result + or "дней" in result + ) mock_bot_db.last_date_audio.assert_called_once() - + @pytest.mark.asyncio async def test_validate_voice_message_valid(self): """Тест валидации голосового сообщения""" mock_message = Mock() - mock_message.content_type = 'voice' + mock_message.content_type = "voice" mock_message.voice = Mock() - + result = await validate_voice_message(mock_message) - + assert result is True - + @pytest.mark.asyncio async def test_validate_voice_message_invalid(self): """Тест валидации невалидного сообщения""" mock_message = Mock() mock_message.voice = None - + result = await validate_voice_message(mock_message) - + assert result is False - + @pytest.mark.asyncio async def test_get_user_emoji_safe(self, mock_bot_db): """Тест безопасного получения эмодзи пользователя""" mock_bot_db.get_user_emoji = AsyncMock(return_value="😊") - + result = await get_user_emoji_safe(mock_bot_db, 123) - + assert result == "😊" mock_bot_db.get_user_emoji.assert_called_once_with(123) - + @pytest.mark.asyncio async def test_get_user_emoji_safe_none(self, mock_bot_db): """Тест безопасного получения эмодзи когда его нет""" mock_bot_db.get_user_emoji = AsyncMock(return_value=None) - + result = await get_user_emoji_safe(mock_bot_db, 123) - + assert result == "😊" - + @pytest.mark.asyncio async def test_get_user_emoji_safe_error(self, mock_bot_db): """Тест безопасного получения эмодзи при ошибке""" mock_bot_db.get_user_emoji = AsyncMock(return_value="Ошибка") - + result = await get_user_emoji_safe(mock_bot_db, 123) - + assert result == "Ошибка" class TestExceptions: """Тесты для исключений""" - + def test_voice_message_error(self): """Тест VoiceMessageError""" try: raise VoiceMessageError("Тестовая ошибка") except VoiceMessageError as e: assert str(e) == "Тестовая ошибка" - + def test_audio_processing_error(self): """Тест AudioProcessingError""" try: @@ -280,5 +310,5 @@ class TestExceptions: assert str(e) == "Ошибка обработки" -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_voice_constants.py b/tests/test_voice_constants.py index b6f7ba4..b6774ef 100644 --- a/tests/test_voice_constants.py +++ b/tests/test_voice_constants.py @@ -1,47 +1,55 @@ import pytest -from helper_bot.handlers.voice.constants import (BTN_LISTEN, BTN_SPEAK, - BUTTON_COMMAND_MAPPING, - CALLBACK_COMMAND_MAPPING, - CALLBACK_DELETE, - CALLBACK_SAVE, CMD_EMOJI, - CMD_HELP, CMD_REFRESH, - CMD_RESTART, CMD_START, - COMMAND_MAPPING, - STATE_STANDUP_WRITE, - STATE_START, VOICE_BOT_NAME) + +from helper_bot.handlers.voice.constants import ( + BTN_LISTEN, + BTN_SPEAK, + BUTTON_COMMAND_MAPPING, + CALLBACK_COMMAND_MAPPING, + CALLBACK_DELETE, + CALLBACK_SAVE, + CMD_EMOJI, + CMD_HELP, + CMD_REFRESH, + CMD_RESTART, + CMD_START, + COMMAND_MAPPING, + STATE_STANDUP_WRITE, + STATE_START, + VOICE_BOT_NAME, +) class TestVoiceConstants: """Тесты для констант voice модуля""" - + def test_button_command_mapping_structure(self): """Тест структуры BUTTON_COMMAND_MAPPING""" assert isinstance(BUTTON_COMMAND_MAPPING, dict) assert len(BUTTON_COMMAND_MAPPING) > 0 - + # Проверяем, что все значения являются строками for key, value in BUTTON_COMMAND_MAPPING.items(): assert isinstance(key, str) assert isinstance(value, str) - + def test_button_command_mapping_specific_values(self): """Тест конкретных значений в BUTTON_COMMAND_MAPPING""" assert "🎤Высказаться" in BUTTON_COMMAND_MAPPING assert "🎧Послушать" in BUTTON_COMMAND_MAPPING - + assert BUTTON_COMMAND_MAPPING["🎤Высказаться"] == "voice_speak" assert BUTTON_COMMAND_MAPPING["🎧Послушать"] == "voice_listen" - + def test_command_mapping_structure(self): """Тест структуры COMMAND_MAPPING""" assert isinstance(COMMAND_MAPPING, dict) assert len(COMMAND_MAPPING) > 0 - + # Проверяем, что все значения являются строками for key, value in COMMAND_MAPPING.items(): assert isinstance(key, str) assert isinstance(value, str) - + def test_command_mapping_specific_values(self): """Тест конкретных значений в COMMAND_MAPPING""" assert "start" in COMMAND_MAPPING @@ -49,51 +57,51 @@ class TestVoiceConstants: assert "restart" in COMMAND_MAPPING assert "emoji" in COMMAND_MAPPING assert "refresh" in COMMAND_MAPPING - + assert COMMAND_MAPPING["start"] == "voice_start" assert COMMAND_MAPPING["help"] == "voice_help" assert COMMAND_MAPPING["restart"] == "voice_restart" assert COMMAND_MAPPING["emoji"] == "voice_emoji" assert COMMAND_MAPPING["refresh"] == "voice_refresh" - + def test_callback_command_mapping_structure(self): """Тест структуры CALLBACK_COMMAND_MAPPING""" assert isinstance(CALLBACK_COMMAND_MAPPING, dict) assert len(CALLBACK_COMMAND_MAPPING) > 0 - + # Проверяем, что все значения являются строками for key, value in CALLBACK_COMMAND_MAPPING.items(): assert isinstance(key, str) assert isinstance(value, str) - + def test_callback_command_mapping_specific_values(self): """Тест конкретных значений в CALLBACK_COMMAND_MAPPING""" assert "save" in CALLBACK_COMMAND_MAPPING assert "delete" in CALLBACK_COMMAND_MAPPING - + assert CALLBACK_COMMAND_MAPPING["save"] == "voice_save" assert CALLBACK_COMMAND_MAPPING["delete"] == "voice_delete" - + def test_voice_bot_name(self): """Тест VOICE_BOT_NAME""" assert isinstance(VOICE_BOT_NAME, str) assert len(VOICE_BOT_NAME) > 0 assert "voice" in VOICE_BOT_NAME.lower() - + def test_state_constants(self): """Тест констант состояний""" assert isinstance(STATE_START, str) assert isinstance(STATE_STANDUP_WRITE, str) assert len(STATE_START) > 0 assert len(STATE_STANDUP_WRITE) > 0 - + def test_button_constants(self): """Тест констант кнопок""" assert isinstance(BTN_SPEAK, str) assert isinstance(BTN_LISTEN, str) assert len(BTN_SPEAK) > 0 assert len(BTN_LISTEN) > 0 - + def test_command_constants(self): """Тест констант команд""" assert isinstance(CMD_START, str) @@ -101,62 +109,62 @@ class TestVoiceConstants: assert isinstance(CMD_RESTART, str) assert isinstance(CMD_EMOJI, str) assert isinstance(CMD_REFRESH, str) - + assert CMD_START == "start" assert CMD_HELP == "help" assert CMD_RESTART == "restart" assert CMD_EMOJI == "emoji" assert CMD_REFRESH == "refresh" - + def test_callback_constants(self): """Тест констант callback""" assert isinstance(CALLBACK_SAVE, str) assert isinstance(CALLBACK_DELETE, str) - + assert CALLBACK_SAVE == "save" assert CALLBACK_DELETE == "delete" - + def test_mapping_consistency(self): """Тест согласованности маппингов""" # Проверяем, что все ключи в маппингах соответствуют константам assert "🎤Высказаться" in BUTTON_COMMAND_MAPPING assert "🎧Послушать" in BUTTON_COMMAND_MAPPING - + assert "start" in COMMAND_MAPPING assert "help" in COMMAND_MAPPING assert "restart" in COMMAND_MAPPING assert "emoji" in COMMAND_MAPPING assert "refresh" in COMMAND_MAPPING - + assert "save" in CALLBACK_COMMAND_MAPPING assert "delete" in CALLBACK_COMMAND_MAPPING - + def test_mapping_values_format(self): """Тест формата значений в маппингах""" # Проверяем, что все значения начинаются с 'voice_' for value in BUTTON_COMMAND_MAPPING.values(): assert value.startswith("voice_") - + for value in COMMAND_MAPPING.values(): assert value.startswith("voice_") - + for value in CALLBACK_COMMAND_MAPPING.values(): assert value.startswith("voice_") - + def test_no_duplicate_values(self): """Тест отсутствия дублирующихся значений в пределах каждого маппинга""" button_values = list(BUTTON_COMMAND_MAPPING.values()) command_values = list(COMMAND_MAPPING.values()) callback_values = list(CALLBACK_COMMAND_MAPPING.values()) - + # Проверяем, что нет дублирующихся значений в каждом маппинге assert len(button_values) == len(set(button_values)) assert len(command_values) == len(set(command_values)) assert len(callback_values) == len(set(callback_values)) - + # Примечание: Дублирование между маппингами допустимо (например, voice_emoji) # так как одно действие может быть вызвано и командой, и кнопкой -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_voice_exceptions.py b/tests/test_voice_exceptions.py index c10cb8d..38d8150 100644 --- a/tests/test_voice_exceptions.py +++ b/tests/test_voice_exceptions.py @@ -1,209 +1,212 @@ import pytest -from helper_bot.handlers.voice.exceptions import (AudioProcessingError, - VoiceBotError, - VoiceMessageError) + +from helper_bot.handlers.voice.exceptions import ( + AudioProcessingError, + VoiceBotError, + VoiceMessageError, +) class TestVoiceExceptions: """Тесты для исключений voice модуля""" - + def test_voice_message_error_inheritance(self): """Тест наследования VoiceMessageError""" assert issubclass(VoiceMessageError, Exception) - + def test_voice_message_error_message(self): """Тест сообщения VoiceMessageError""" error_message = "Тестовая ошибка голосового сообщения" error = VoiceMessageError(error_message) - + assert str(error) == error_message assert error.args == (error_message,) - + def test_voice_message_error_empty_message(self): """Тест VoiceMessageError с пустым сообщением""" error = VoiceMessageError("") - + assert str(error) == "" assert error.args == ("",) - + def test_voice_message_error_none_message(self): """Тест VoiceMessageError с None сообщением""" error = VoiceMessageError(None) - + assert str(error) == "None" assert error.args == (None,) - + def test_voice_message_error_multiple_args(self): """Тест VoiceMessageError с несколькими аргументами""" error = VoiceMessageError("Ошибка", "Дополнительная информация") - + assert str(error) == "('Ошибка', 'Дополнительная информация')" assert error.args == ("Ошибка", "Дополнительная информация") - + def test_audio_processing_error_inheritance(self): """Тест наследования AudioProcessingError""" assert issubclass(AudioProcessingError, Exception) - + def test_audio_processing_error_message(self): """Тест сообщения AudioProcessingError""" error_message = "Ошибка обработки аудио файла" error = AudioProcessingError(error_message) - + assert str(error) == error_message assert error.args == (error_message,) - + def test_audio_processing_error_empty_message(self): """Тест AudioProcessingError с пустым сообщением""" error = AudioProcessingError("") - + assert str(error) == "" assert error.args == ("",) - + def test_audio_processing_error_none_message(self): """Тест AudioProcessingError с None сообщением""" error = AudioProcessingError(None) - + assert str(error) == "None" assert error.args == (None,) - + def test_voice_bot_error_inheritance(self): """Тест наследования VoiceBotError""" assert issubclass(VoiceBotError, Exception) - + def test_voice_bot_error_message(self): """Тест сообщения VoiceBotError""" error_message = "Общая ошибка voice бота" error = VoiceBotError(error_message) - + assert str(error) == error_message assert error.args == (error_message,) - + def test_exception_hierarchy(self): """Тест иерархии исключений""" # Проверяем, что все исключения наследуются от Exception assert issubclass(VoiceMessageError, Exception) assert issubclass(AudioProcessingError, Exception) assert issubclass(VoiceBotError, Exception) - + # Проверяем, что VoiceMessageError и AudioProcessingError наследуются от VoiceBotError assert issubclass(VoiceMessageError, VoiceBotError) assert issubclass(AudioProcessingError, VoiceBotError) - + # Проверяем, что VoiceBotError не наследуется от других исключений assert not issubclass(VoiceBotError, VoiceMessageError) assert not issubclass(VoiceBotError, AudioProcessingError) - + # Проверяем, что VoiceMessageError и AudioProcessingError не наследуются друг от друга assert not issubclass(VoiceMessageError, AudioProcessingError) assert not issubclass(AudioProcessingError, VoiceMessageError) - + def test_exception_creation_without_args(self): """Тест создания исключений без аргументов""" # Должно работать без аргументов voice_error = VoiceMessageError() audio_error = AudioProcessingError() bot_error = VoiceBotError() - + assert str(voice_error) == "" assert str(audio_error) == "" assert str(bot_error) == "" - + def test_exception_creation_with_int(self): """Тест создания исключений с числовыми аргументами""" voice_error = VoiceMessageError(123) audio_error = AudioProcessingError(456) bot_error = VoiceBotError(789) - + assert str(voice_error) == "123" assert str(audio_error) == "456" assert str(bot_error) == "789" - + def test_exception_creation_with_list(self): """Тест создания исключений со списками""" error_list = ["Ошибка 1", "Ошибка 2"] voice_error = VoiceMessageError(error_list) audio_error = AudioProcessingError(error_list) bot_error = VoiceBotError(error_list) - + assert str(voice_error) == str(error_list) assert str(audio_error) == str(error_list) assert str(bot_error) == str(error_list) - + def test_exception_creation_with_dict(self): """Тест создания исключений со словарями""" error_dict = {"code": 500, "message": "Internal error"} voice_error = VoiceMessageError(error_dict) audio_error = AudioProcessingError(error_dict) bot_error = VoiceBotError(error_dict) - + assert str(voice_error) == str(error_dict) assert str(audio_error) == str(error_dict) assert str(bot_error) == str(error_dict) - + def test_exception_attributes(self): """Тест атрибутов исключений""" error_message = "Тестовая ошибка" voice_error = VoiceMessageError(error_message) audio_error = AudioProcessingError(error_message) bot_error = VoiceBotError(error_message) - + # Проверяем, что исключения имеют атрибут args - assert hasattr(voice_error, 'args') - assert hasattr(audio_error, 'args') - assert hasattr(bot_error, 'args') - + assert hasattr(voice_error, "args") + assert hasattr(audio_error, "args") + assert hasattr(bot_error, "args") + # Проверяем, что args содержит переданное сообщение assert voice_error.args == (error_message,) assert audio_error.args == (error_message,) assert bot_error.args == (error_message,) - + def test_exception_string_representation(self): """Тест строкового представления исключений""" error_message = "Тестовая ошибка" voice_error = VoiceMessageError(error_message) audio_error = AudioProcessingError(error_message) bot_error = VoiceBotError(error_message) - + # Проверяем, что str() возвращает сообщение assert str(voice_error) == error_message assert str(audio_error) == error_message assert str(bot_error) == error_message - + # Проверяем, что repr() содержит имя класса assert "VoiceMessageError" in repr(voice_error) assert "AudioProcessingError" in repr(audio_error) assert "VoiceBotError" in repr(bot_error) - + def test_exception_equality(self): """Тест равенства исключений""" error1 = VoiceMessageError("Ошибка") error2 = VoiceMessageError("Ошибка") error3 = VoiceMessageError("Другая ошибка") - + # Исключения с одинаковыми сообщениями не равны (разные объекты) assert error1 != error2 assert error1 != error3 - + # Но их строковые представления равны assert str(error1) == str(error2) assert str(error1) != str(error3) - + def test_exception_inheritance_chain(self): """Тест цепочки наследования исключений""" # Проверяем, что все исключения являются экземплярами Exception voice_error = VoiceMessageError("Ошибка") audio_error = AudioProcessingError("Ошибка") bot_error = VoiceBotError("Ошибка") - + assert isinstance(voice_error, Exception) assert isinstance(audio_error, Exception) assert isinstance(bot_error, Exception) - + # Проверяем, что исключения являются экземплярами своих классов assert isinstance(voice_error, VoiceMessageError) assert isinstance(audio_error, AudioProcessingError) assert isinstance(bot_error, VoiceBotError) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_voice_handler.py b/tests/test_voice_handler.py index e924d64..9456d30 100644 --- a/tests/test_voice_handler.py +++ b/tests/test_voice_handler.py @@ -3,38 +3,38 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from aiogram import types from aiogram.fsm.context import FSMContext -from helper_bot.handlers.voice.constants import (STATE_STANDUP_WRITE, - STATE_START) + +from helper_bot.handlers.voice.constants import STATE_STANDUP_WRITE, STATE_START from helper_bot.handlers.voice.voice_handler import VoiceHandlers class TestVoiceHandler: """Тесты для VoiceHandler""" - + @pytest.fixture def mock_db(self): """Мок для базы данных""" mock_db = Mock() mock_db.settings = { - 'Telegram': { - 'group_for_logs': 'test_logs_chat', - 'group_for_posts': 'test_posts_chat', - 'preview_link': True + "Telegram": { + "group_for_logs": "test_logs_chat", + "group_for_posts": "test_posts_chat", + "preview_link": True, } } return mock_db - + @pytest.fixture def mock_settings(self): """Мок для настроек""" return { - 'Telegram': { - 'group_for_logs': 'test_logs_chat', - 'group_for_posts': 'test_posts_chat', - 'preview_link': True + "Telegram": { + "group_for_logs": "test_logs_chat", + "group_for_posts": "test_posts_chat", + "preview_link": True, } } - + @pytest.fixture def mock_message(self): """Мок для сообщения""" @@ -48,71 +48,98 @@ class TestVoiceHandler: message.answer = AsyncMock() message.forward = AsyncMock() return message - + @pytest.fixture def mock_state(self): """Мок для состояния FSM""" state = Mock(spec=FSMContext) state.set_state = AsyncMock() return state - + @pytest.fixture def voice_handler(self, mock_db, mock_settings): """Экземпляр VoiceHandler для тестов""" return VoiceHandlers(mock_db, mock_settings) - + @pytest.mark.asyncio - async def test_voice_bot_button_handler_welcome_received(self, voice_handler, mock_message, mock_state, mock_db, mock_settings): + async def test_voice_bot_button_handler_welcome_received( + self, voice_handler, mock_message, mock_state, mock_db, mock_settings + ): """Тест обработчика кнопки когда приветствие уже получено""" from unittest.mock import AsyncMock + mock_db.check_voice_bot_welcome_received = AsyncMock(return_value=True) - - with patch.object(voice_handler, 'restart_function') as mock_restart: - with patch('helper_bot.handlers.voice.voice_handler.update_user_info') as mock_update_user: + + with patch.object(voice_handler, "restart_function") as mock_restart: + with patch( + "helper_bot.handlers.voice.voice_handler.update_user_info" + ) as mock_update_user: mock_update_user.return_value = None - - await voice_handler.voice_bot_button_handler(mock_message, mock_state, mock_db, mock_settings) - + + await voice_handler.voice_bot_button_handler( + mock_message, mock_state, mock_db, mock_settings + ) + mock_db.check_voice_bot_welcome_received.assert_called_once_with(123) - mock_restart.assert_called_once_with(mock_message, mock_state, mock_db, mock_settings) - + mock_restart.assert_called_once_with( + mock_message, mock_state, mock_db, mock_settings + ) + @pytest.mark.asyncio - async def test_voice_bot_button_handler_welcome_not_received(self, voice_handler, mock_message, mock_state, mock_db, mock_settings): + async def test_voice_bot_button_handler_welcome_not_received( + self, voice_handler, mock_message, mock_state, mock_db, mock_settings + ): """Тест обработчика кнопки когда приветствие не получено""" from unittest.mock import AsyncMock + mock_db.check_voice_bot_welcome_received = AsyncMock(return_value=False) - - with patch.object(voice_handler, 'start') as mock_start: - await voice_handler.voice_bot_button_handler(mock_message, mock_state, mock_db, mock_settings) - + + with patch.object(voice_handler, "start") as mock_start: + await voice_handler.voice_bot_button_handler( + mock_message, mock_state, mock_db, mock_settings + ) + mock_db.check_voice_bot_welcome_received.assert_called_once_with(123) - mock_start.assert_called_once_with(mock_message, mock_state, mock_db, mock_settings) - + mock_start.assert_called_once_with( + mock_message, mock_state, mock_db, mock_settings + ) + @pytest.mark.asyncio - async def test_voice_bot_button_handler_exception(self, voice_handler, mock_message, mock_state, mock_db, mock_settings): + async def test_voice_bot_button_handler_exception( + self, voice_handler, mock_message, mock_state, mock_db, mock_settings + ): """Тест обработчика кнопки при исключении""" mock_db.check_voice_bot_welcome_received.side_effect = Exception("Test error") - - with patch.object(voice_handler, 'start') as mock_start: - await voice_handler.voice_bot_button_handler(mock_message, mock_state, mock_db, mock_settings) - - mock_start.assert_called_once_with(mock_message, mock_state, mock_db, mock_settings) - + + with patch.object(voice_handler, "start") as mock_start: + await voice_handler.voice_bot_button_handler( + mock_message, mock_state, mock_db, mock_settings + ) + + mock_start.assert_called_once_with( + mock_message, mock_state, mock_db, mock_settings + ) + # Упрощенные тесты для основных функций @pytest.mark.asyncio - async def test_standup_write(self, voice_handler, mock_message, mock_state, mock_db, mock_settings): + async def test_standup_write( + self, voice_handler, mock_message, mock_state, mock_db, mock_settings + ): """Тест функции standup_write""" - with patch('helper_bot.handlers.voice.voice_handler.messages.get_message') as mock_get_message: + with patch( + "helper_bot.handlers.voice.voice_handler.messages.get_message" + ) as mock_get_message: mock_get_message.return_value = "Record voice message" - - await voice_handler.standup_write(mock_message, mock_state, mock_db, mock_settings) - + + await voice_handler.standup_write( + mock_message, mock_state, mock_db, mock_settings + ) + mock_state.set_state.assert_called_once_with(STATE_STANDUP_WRITE) mock_message.answer.assert_called_once_with( - text="Record voice message", - reply_markup=types.ReplyKeyboardRemove() + text="Record voice message", reply_markup=types.ReplyKeyboardRemove() ) - + def test_setup_handlers(self, voice_handler): """Тест настройки обработчиков""" # Проверяем, что роутер содержит обработчики @@ -228,5 +255,5 @@ class TestVoiceHandler: mock_message.answer.assert_called_once_with(f'Твоя эмодзя - 😊', parse_mode='HTML') -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_voice_services.py b/tests/test_voice_services.py index 7349d7f..515f7d1 100644 --- a/tests/test_voice_services.py +++ b/tests/test_voice_services.py @@ -3,240 +3,276 @@ from pathlib import Path from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest -from helper_bot.handlers.voice.exceptions import (AudioProcessingError, - VoiceMessageError) + +from helper_bot.handlers.voice.exceptions import AudioProcessingError, VoiceMessageError from helper_bot.handlers.voice.services import VoiceBotService class TestVoiceBotService: """Тесты для VoiceBotService""" - + @pytest.fixture def mock_bot_db(self): """Мок для базы данных""" mock_db = Mock() mock_db.settings = { - 'Settings': {'logs': True}, - 'Telegram': {'important_logs': 'test_chat_id'} + "Settings": {"logs": True}, + "Telegram": {"important_logs": "test_chat_id"}, } return mock_db - + @pytest.fixture def mock_settings(self): """Мок для настроек""" - return { - 'Settings': {'logs': True}, - 'Telegram': {'preview_link': True} - } - + return {"Settings": {"logs": True}, "Telegram": {"preview_link": True}} + @pytest.fixture def voice_service(self, mock_bot_db, mock_settings): """Экземпляр VoiceBotService для тестов""" return VoiceBotService(mock_bot_db, mock_settings) - + @pytest.mark.asyncio async def test_get_welcome_sticker_success(self, voice_service, mock_settings): """Тест успешного получения стикера""" - with patch('pathlib.Path.rglob') as mock_rglob: - mock_rglob.return_value = ['/path/to/sticker1.tgs', '/path/to/sticker2.tgs'] - + with patch("pathlib.Path.rglob") as mock_rglob: + mock_rglob.return_value = ["/path/to/sticker1.tgs", "/path/to/sticker2.tgs"] + sticker = await voice_service.get_welcome_sticker() - + assert sticker is not None mock_rglob.assert_called_once() - + @pytest.mark.asyncio async def test_get_welcome_sticker_no_stickers(self, voice_service, mock_settings): """Тест получения стикера когда их нет""" - with patch('pathlib.Path.rglob') as mock_rglob: + with patch("pathlib.Path.rglob") as mock_rglob: mock_rglob.return_value = [] - + sticker = await voice_service.get_welcome_sticker() - + assert sticker is None - + @pytest.mark.asyncio - async def test_get_welcome_sticker_only_webp_files(self, voice_service, mock_settings): + async def test_get_welcome_sticker_only_webp_files( + self, voice_service, mock_settings + ): """Тест получения стикера когда есть только webp файлы""" - with patch('pathlib.Path.rglob') as mock_rglob: - mock_rglob.return_value = ['/path/to/sticker1.webp', '/path/to/sticker2.webp'] - + with patch("pathlib.Path.rglob") as mock_rglob: + mock_rglob.return_value = [ + "/path/to/sticker1.webp", + "/path/to/sticker2.webp", + ] + sticker = await voice_service.get_welcome_sticker() - + # Проверяем, что стикер не None (метод ищет файлы по паттерну Hello_*) assert sticker is not None - + @pytest.mark.asyncio async def test_get_welcome_sticker_mixed_files(self, voice_service, mock_settings): """Тест получения стикера когда есть смешанные файлы""" - with patch('pathlib.Path.rglob') as mock_rglob: + with patch("pathlib.Path.rglob") as mock_rglob: mock_rglob.return_value = [ - '/path/to/sticker1.webp', - '/path/to/sticker2.tgs', - '/path/to/sticker3.webp' + "/path/to/sticker1.webp", + "/path/to/sticker2.tgs", + "/path/to/sticker3.webp", ] - + sticker = await voice_service.get_welcome_sticker() - + assert sticker is not None # Проверяем, что стикер не None (метод возвращает FSInputFile объект) - + @pytest.mark.asyncio async def test_get_random_audio_success(self, voice_service, mock_bot_db): """Тест успешного получения случайного аудио""" - mock_bot_db.check_listen_audio = AsyncMock(return_value=['audio1', 'audio2']) + mock_bot_db.check_listen_audio = AsyncMock(return_value=["audio1", "audio2"]) mock_bot_db.get_user_id_by_file_name = AsyncMock(return_value=123) - mock_bot_db.get_date_by_file_name = AsyncMock(return_value='2025-01-01 12:00:00') - mock_bot_db.get_user_emoji = AsyncMock(return_value='😊') - + mock_bot_db.get_date_by_file_name = AsyncMock( + return_value="2025-01-01 12:00:00" + ) + mock_bot_db.get_user_emoji = AsyncMock(return_value="😊") + result = await voice_service.get_random_audio(456) - + assert result is not None assert len(result) == 3 # Проверяем, что результат содержит ожидаемые данные, но не проверяем точное значение audio - assert result[0] in ['audio1', 'audio2'] - assert result[1] == '2025-01-01 12:00:00' - assert result[2] == '😊' - + assert result[0] in ["audio1", "audio2"] + assert result[1] == "2025-01-01 12:00:00" + assert result[2] == "😊" + @pytest.mark.asyncio async def test_get_random_audio_no_audio(self, voice_service, mock_bot_db): """Тест получения аудио когда их нет""" mock_bot_db.check_listen_audio = AsyncMock(return_value=[]) - + result = await voice_service.get_random_audio(456) - + assert result is None - + @pytest.mark.asyncio async def test_get_random_audio_single_audio(self, voice_service, mock_bot_db): """Тест получения аудио когда есть только одно""" - mock_bot_db.check_listen_audio = AsyncMock(return_value=['audio1']) + mock_bot_db.check_listen_audio = AsyncMock(return_value=["audio1"]) mock_bot_db.get_user_id_by_file_name = AsyncMock(return_value=123) - mock_bot_db.get_date_by_file_name = AsyncMock(return_value='2025-01-01 12:00:00') - mock_bot_db.get_user_emoji = AsyncMock(return_value='😊') - + mock_bot_db.get_date_by_file_name = AsyncMock( + return_value="2025-01-01 12:00:00" + ) + mock_bot_db.get_user_emoji = AsyncMock(return_value="😊") + result = await voice_service.get_random_audio(456) - + assert result is not None assert len(result) == 3 - assert result[0] == 'audio1' - + assert result[0] == "audio1" + @pytest.mark.asyncio async def test_mark_audio_as_listened_success(self, voice_service, mock_bot_db): """Тест успешной пометки аудио как прослушанного""" mock_bot_db.mark_listened_audio = AsyncMock() - - await voice_service.mark_audio_as_listened('test_audio', 123) - - mock_bot_db.mark_listened_audio.assert_called_once_with('test_audio', user_id=123) - + + await voice_service.mark_audio_as_listened("test_audio", 123) + + mock_bot_db.mark_listened_audio.assert_called_once_with( + "test_audio", user_id=123 + ) + @pytest.mark.asyncio async def test_clear_user_listenings_success(self, voice_service, mock_bot_db): """Тест успешной очистки прослушиваний""" mock_bot_db.delete_listen_count_for_user = AsyncMock() - + await voice_service.clear_user_listenings(123) - + mock_bot_db.delete_listen_count_for_user.assert_called_once_with(123) - + @pytest.mark.asyncio async def test_get_remaining_audio_count_success(self, voice_service, mock_bot_db): """Тест получения количества оставшихся аудио""" - mock_bot_db.check_listen_audio = AsyncMock(return_value=['audio1', 'audio2', 'audio3']) - + mock_bot_db.check_listen_audio = AsyncMock( + return_value=["audio1", "audio2", "audio3"] + ) + result = await voice_service.get_remaining_audio_count(123) - + assert result == 3 mock_bot_db.check_listen_audio.assert_called_once_with(user_id=123) - + @pytest.mark.asyncio async def test_get_remaining_audio_count_zero(self, voice_service, mock_bot_db): """Тест получения количества оставшихся аудио когда их нет""" mock_bot_db.check_listen_audio = AsyncMock(return_value=[]) - + result = await voice_service.get_remaining_audio_count(123) - + assert result == 0 mock_bot_db.check_listen_audio.assert_called_once_with(user_id=123) - + @pytest.mark.asyncio - async def test_send_welcome_messages_success(self, voice_service, mock_bot_db, mock_settings): - """Тест успешной отправки приветственных сообщений""" + async def test_send_welcome_messages_success( + self, voice_service, mock_bot_db, mock_settings + ): + """Тест успешной отправки приветственных сообщений.""" mock_message = Mock() mock_message.from_user.id = 123 mock_message.answer = AsyncMock() mock_message.answer.return_value = Mock() mock_message.answer_sticker = AsyncMock() - - with patch.object(voice_service, 'get_welcome_sticker') as mock_sticker: - mock_sticker.return_value = 'test_sticker.tgs' - - await voice_service.send_welcome_messages(mock_message, '😊') - - # Проверяем, что сообщения отправлены - assert mock_message.answer.call_count >= 1 - + + with patch.object( + voice_service, + "get_welcome_sticker", + new_callable=AsyncMock, + return_value="test_sticker.tgs", + ): + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", + new_callable=AsyncMock, + ): + await voice_service.send_welcome_messages(mock_message, "😊") + + assert mock_message.answer.call_count >= 1 + @pytest.mark.asyncio - async def test_send_welcome_messages_no_sticker(self, voice_service, mock_bot_db, mock_settings): - """Тест отправки приветственных сообщений без стикера""" + async def test_send_welcome_messages_no_sticker( + self, voice_service, mock_bot_db, mock_settings + ): + """Тест отправки приветственных сообщений без стикера.""" mock_message = Mock() mock_message.from_user.id = 123 mock_message.answer = AsyncMock() mock_message.answer.return_value = Mock() - - with patch.object(voice_service, 'get_welcome_sticker') as mock_sticker: - mock_sticker.return_value = None - - await voice_service.send_welcome_messages(mock_message, '😊') - - # Проверяем, что сообщения отправлены - assert mock_message.answer.call_count >= 1 - + + with patch.object( + voice_service, + "get_welcome_sticker", + new_callable=AsyncMock, + return_value=None, + ): + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", + new_callable=AsyncMock, + ): + await voice_service.send_welcome_messages(mock_message, "😊") + + assert mock_message.answer.call_count >= 1 + @pytest.mark.asyncio - async def test_send_welcome_messages_with_sticker(self, voice_service, mock_bot_db, mock_settings): - """Тест отправки приветственных сообщений со стикером""" + async def test_send_welcome_messages_with_sticker( + self, voice_service, mock_bot_db, mock_settings + ): + """Тест отправки приветственных сообщений со стикером.""" mock_message = Mock() mock_message.from_user.id = 123 mock_message.answer = AsyncMock() mock_message.answer.return_value = Mock() mock_message.answer_sticker = AsyncMock() - - with patch.object(voice_service, 'get_welcome_sticker') as mock_sticker: - mock_sticker.return_value = 'test_sticker.tgs' - - await voice_service.send_welcome_messages(mock_message, '😊') - - # Проверяем, что сообщения отправлены - assert mock_message.answer.call_count >= 1 - + + with patch.object( + voice_service, + "get_welcome_sticker", + new_callable=AsyncMock, + return_value="test_sticker.tgs", + ): + with patch( + "helper_bot.handlers.voice.services.asyncio.sleep", + new_callable=AsyncMock, + ): + await voice_service.send_welcome_messages(mock_message, "😊") + + assert mock_message.answer.call_count >= 1 + @pytest.mark.asyncio - async def test_get_welcome_sticker_with_tgs_files(self, voice_service, mock_settings): + async def test_get_welcome_sticker_with_tgs_files( + self, voice_service, mock_settings + ): """Тест получения стикера когда есть .tgs файлы""" - with patch('pathlib.Path.rglob') as mock_rglob: - mock_rglob.return_value = ['/path/to/sticker1.tgs', '/path/to/sticker2.tgs'] - + with patch("pathlib.Path.rglob") as mock_rglob: + mock_rglob.return_value = ["/path/to/sticker1.tgs", "/path/to/sticker2.tgs"] + sticker = await voice_service.get_welcome_sticker() - + assert sticker is not None # Проверяем, что стикер не None (метод возвращает FSInputFile объект) - + def test_service_initialization(self, mock_bot_db, mock_settings): """Тест инициализации сервиса""" service = VoiceBotService(mock_bot_db, mock_settings) - + assert service.bot_db == mock_bot_db assert service.settings == mock_settings - + def test_service_attributes(self, voice_service): """Тест атрибутов сервиса""" - assert hasattr(voice_service, 'bot_db') - assert hasattr(voice_service, 'settings') - assert hasattr(voice_service, 'get_welcome_sticker') - assert hasattr(voice_service, 'get_random_audio') - assert hasattr(voice_service, 'mark_audio_as_listened') - assert hasattr(voice_service, 'clear_user_listenings') - assert hasattr(voice_service, 'get_remaining_audio_count') - assert hasattr(voice_service, 'send_welcome_messages') + assert hasattr(voice_service, "bot_db") + assert hasattr(voice_service, "settings") + assert hasattr(voice_service, "get_welcome_sticker") + assert hasattr(voice_service, "get_random_audio") + assert hasattr(voice_service, "mark_audio_as_listened") + assert hasattr(voice_service, "clear_user_listenings") + assert hasattr(voice_service, "get_remaining_audio_count") + assert hasattr(voice_service, "send_welcome_messages") @pytest.mark.asyncio async def test_get_welcome_sticker_exception_returns_none(self, voice_service, mock_settings): @@ -324,5 +360,5 @@ class TestVoiceBotService: mock_send.assert_awaited_once() -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_voice_utils.py b/tests/test_voice_utils.py index 4949090..50de559 100644 --- a/tests/test_voice_utils.py +++ b/tests/test_voice_utils.py @@ -3,26 +3,26 @@ from unittest.mock import Mock, patch import pytest from aiogram import types -from helper_bot.handlers.voice.utils import (format_time_ago, - get_last_message_text, - get_user_emoji_safe, plural_time, - validate_voice_message) + +from helper_bot.handlers.voice.utils import ( + format_time_ago, + get_last_message_text, + get_user_emoji_safe, + plural_time, + validate_voice_message, +) class TestVoiceUtils: """Тесты для утилит voice модуля""" - + @pytest.fixture def mock_bot_db(self): """Мок для базы данных""" mock_db = Mock() - mock_db.settings = { - 'Telegram': { - 'group_for_logs': 'test_logs_chat' - } - } + mock_db.settings = {"Telegram": {"group_for_logs": "test_logs_chat"}} return mock_db - + @pytest.fixture def mock_message(self): """Мок для сообщения""" @@ -35,134 +35,149 @@ class TestVoiceUtils: message.from_user.language_code = "ru" message.chat.id = 456 return message - + @pytest.mark.asyncio async def test_get_last_message_text(self, mock_bot_db): """Тест получения последнего сообщения""" # Возвращаем UNIX timestamp from unittest.mock import AsyncMock - mock_bot_db.last_date_audio = AsyncMock(return_value=1641034800) # 2022-01-01 12:00:00 - + + mock_bot_db.last_date_audio = AsyncMock( + return_value=1641034800 + ) # 2022-01-01 12:00:00 + result = await get_last_message_text(mock_bot_db) - + assert result is not None - assert "минут" in result or "часа" in result or "дня" in result or "день" in result or "дней" in result + assert ( + "минут" in result + or "часа" in result + or "дня" in result + or "день" in result + or "дней" in result + ) mock_bot_db.last_date_audio.assert_called_once() - + @pytest.mark.asyncio async def test_validate_voice_message_valid(self): """Тест валидации голосового сообщения""" mock_message = Mock() - mock_message.content_type = 'voice' + mock_message.content_type = "voice" mock_message.voice = Mock() - + result = await validate_voice_message(mock_message) - + assert result is True - + @pytest.mark.asyncio async def test_validate_voice_message_invalid(self): """Тест валидации невалидного сообщения""" mock_message = Mock() mock_message.voice = None - + result = await validate_voice_message(mock_message) - + assert result is False - + @pytest.mark.asyncio async def test_get_user_emoji_safe_with_emoji(self, mock_bot_db): """Тест безопасного получения эмодзи пользователя когда эмодзи есть""" from unittest.mock import AsyncMock + mock_bot_db.get_user_emoji = AsyncMock(return_value="😊") - + result = await get_user_emoji_safe(mock_bot_db, 123) - + assert result == "😊" mock_bot_db.get_user_emoji.assert_called_once_with(123) - + @pytest.mark.asyncio async def test_get_user_emoji_safe_without_emoji(self, mock_bot_db): """Тест безопасного получения эмодзи пользователя когда эмодзи нет""" from unittest.mock import AsyncMock + mock_bot_db.get_user_emoji = AsyncMock(return_value=None) - + result = await get_user_emoji_safe(mock_bot_db, 123) - + assert result == "😊" mock_bot_db.get_user_emoji.assert_called_once_with(123) - + @pytest.mark.asyncio async def test_get_user_emoji_safe_with_empty_emoji(self, mock_bot_db): """Тест безопасного получения эмодзи пользователя с пустым эмодзи""" from unittest.mock import AsyncMock + mock_bot_db.get_user_emoji = AsyncMock(return_value="") - + result = await get_user_emoji_safe(mock_bot_db, 123) - + assert result == "😊" mock_bot_db.get_user_emoji.assert_called_once_with(123) - + @pytest.mark.asyncio async def test_get_user_emoji_safe_with_error(self, mock_bot_db): """Тест безопасного получения эмодзи пользователя при ошибке""" from unittest.mock import AsyncMock + mock_bot_db.get_user_emoji = AsyncMock(return_value="Ошибка") - + result = await get_user_emoji_safe(mock_bot_db, 123) - + assert result == "Ошибка" mock_bot_db.get_user_emoji.assert_called_once_with(123) - + def test_format_time_ago_minutes(self): """Тест форматирования времени в минутах""" from datetime import datetime, timedelta # Создаем дату 30 минут назад - test_date = (datetime.now() - timedelta(minutes=30)).strftime("%Y-%m-%d %H:%M:%S") - + test_date = (datetime.now() - timedelta(minutes=30)).strftime( + "%Y-%m-%d %H:%M:%S" + ) + result = format_time_ago(test_date) - + assert result is not None assert "минут" in result assert "30" in result or "29" in result or "31" in result - + def test_format_time_ago_hours(self): """Тест форматирования времени в часах""" from datetime import datetime, timedelta # Создаем дату 2 часа назад test_date = (datetime.now() - timedelta(hours=2)).strftime("%Y-%m-%d %H:%M:%S") - + result = format_time_ago(test_date) - + assert result is not None assert "часа" in result or "часов" in result - + def test_format_time_ago_days(self): """Тест форматирования времени в днях""" from datetime import datetime, timedelta # Создаем дату 3 дня назад test_date = (datetime.now() - timedelta(days=3)).strftime("%Y-%m-%d %H:%M:%S") - + result = format_time_ago(test_date) - + assert result is not None assert "дня" in result or "дней" in result - + def test_format_time_ago_none(self): """Тест форматирования времени с None""" result = format_time_ago(None) - + assert result is None - + def test_format_time_ago_invalid_format(self): """Тест форматирования времени с неверным форматом""" result = format_time_ago("invalid_date_format") - + assert result is None - + def test_plural_time_minutes(self): """Тест множественного числа для минут""" assert "1 минуту" in plural_time(1, 1) @@ -170,7 +185,7 @@ class TestVoiceUtils: assert "5 минут" in plural_time(1, 5) assert "11 минут" in plural_time(1, 11) assert "21 минуту" in plural_time(1, 21) - + def test_plural_time_hours(self): """Тест множественного числа для часов""" assert "1 час" in plural_time(2, 1) @@ -178,7 +193,7 @@ class TestVoiceUtils: assert "5 часов" in plural_time(2, 5) assert "11 часов" in plural_time(2, 11) assert "21 час" in plural_time(2, 21) - + def test_plural_time_days(self): """Тест множественного числа для дней""" assert "1 день" in plural_time(3, 1) @@ -186,25 +201,25 @@ class TestVoiceUtils: assert "5 дней" in plural_time(3, 5) assert "11 дней" in plural_time(3, 11) assert "21 день" in plural_time(3, 21) - + def test_plural_time_invalid_type(self): """Тест множественного числа с неверным типом""" result = plural_time(4, 5) - + assert result == "5" - + def test_plural_time_edge_cases(self): """Тест граничных случаев для множественного числа""" # Тест для 0 assert "0 минут" in plural_time(1, 0) assert "0 часов" in plural_time(2, 0) assert "0 дней" in plural_time(3, 0) - + # Тест для больших чисел assert "100 минут" in plural_time(1, 100) assert "100 часов" in plural_time(2, 100) assert "100 дней" in plural_time(3, 100) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__])