From 7f6f0f028cb3a6cc0a8f9ab70effd4f3d9475128 Mon Sep 17 00:00:00 2001 From: Andrey Date: Mon, 26 Jan 2026 18:40:38 +0300 Subject: [PATCH] =?UTF-8?q?feat:=20=D0=B8=D0=BD=D1=82=D0=B5=D0=B3=D1=80?= =?UTF-8?q?=D0=B0=D1=86=D0=B8=D1=8F=20ML-=D1=81=D0=BA=D0=BE=D1=80=D0=B8?= =?UTF-8?q?=D0=BD=D0=B3=D0=B0=20=D1=81=20=D0=B8=D1=81=D0=BF=D0=BE=D0=BB?= =?UTF-8?q?=D1=8C=D0=B7=D0=BE=D0=B2=D0=B0=D0=BD=D0=B8=D0=B5=D0=BC=20RAG=20?= =?UTF-8?q?=D0=B8=20DeepSeek?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Обновлен Dockerfile для установки необходимых зависимостей. - Добавлены новые переменные окружения для настройки ML-скоринга в env.example. - Реализованы методы для получения и обновления ML-скоров в AsyncBotDB и PostRepository. - Обновлены обработчики публикации постов для интеграции ML-скоринга. - Добавлен новый обработчик для получения статистики ML-скоринга в админ-панели. - Обновлены функции для форматирования сообщений с учетом ML-скоров. --- .github/workflows/deploy.yml | 5 +- Dockerfile | 36 +- database/async_db.py | 17 + database/repositories/post_repository.py | 123 +++++ env.example | 17 + helper_bot/handlers/admin/admin_handlers.py | 64 +++ .../handlers/callback/dependency_factory.py | 3 +- helper_bot/handlers/callback/services.py | 35 +- .../handlers/private/private_handlers.py | 20 +- helper_bot/handlers/private/services.py | 148 ++++- helper_bot/keyboards/keyboards.py | 3 + helper_bot/main.py | 16 + helper_bot/services/__init__.py | 5 + helper_bot/services/scoring/__init__.py | 42 ++ helper_bot/services/scoring/base.py | 155 ++++++ .../services/scoring/deepseek_service.py | 358 +++++++++++++ helper_bot/services/scoring/exceptions.py | 33 ++ helper_bot/services/scoring/rag_service.py | 507 ++++++++++++++++++ .../services/scoring/scoring_manager.py | 242 +++++++++ helper_bot/services/scoring/vector_store.py | 399 ++++++++++++++ helper_bot/utils/base_dependency_factory.py | 123 +++++ helper_bot/utils/helper_func.py | 43 +- requirements.txt | 8 +- scripts/add_ml_scores_columns.py | 93 ++++ tests/test_scoring_services.py | 390 ++++++++++++++ 25 files changed, 2833 insertions(+), 52 deletions(-) create mode 100644 helper_bot/services/__init__.py create mode 100644 helper_bot/services/scoring/__init__.py create mode 100644 helper_bot/services/scoring/base.py create mode 100644 helper_bot/services/scoring/deepseek_service.py create mode 100644 helper_bot/services/scoring/exceptions.py create mode 100644 helper_bot/services/scoring/rag_service.py create mode 100644 helper_bot/services/scoring/scoring_manager.py create mode 100644 helper_bot/services/scoring/vector_store.py create mode 100644 scripts/add_ml_scores_columns.py create mode 100644 tests/test_scoring_services.py diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 70e5df4..d23b853 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -165,9 +165,8 @@ jobs: 📦 Repository: telegram-helper-bot 🌿 Branch: main - 📝 Commit: ${{ github.event.pull_request.merge_commit_sha || github.sha }} - 👤 Author: ${{ github.event.pull_request.user.login || github.actor }} - ${{ github.event.pull_request.number && format('🔀 PR: #{0}', github.event.pull_request.number) || '' }} + 📝 Commit: ${{ github.sha }} + 👤 Author: ${{ github.actor }} ${{ job.status == 'success' && '✅ Deployment successful! Container restarted with migrations applied.' || '❌ Deployment failed! Check logs for details.' }} diff --git a/Dockerfile b/Dockerfile index 0c36fe8..c41ab93 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,15 +1,14 @@ ########################################### # Этап 1: Сборщик (Builder) ########################################### -FROM python:3.11.9-alpine as builder +FROM python:3.11.9-slim as builder -# Устанавливаем инструменты для компиляции + linux-headers для psutil -RUN apk add --no-cache \ +# Устанавливаем инструменты для компиляции +RUN apt-get update && apt-get install --no-install-recommends -y \ gcc \ g++ \ - musl-dev \ python3-dev \ - linux-headers # ← ЭТО КРИТИЧЕСКИ ВАЖНО ДЛЯ psutil + && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY requirements.txt . @@ -21,29 +20,34 @@ RUN pip install --no-cache-dir --target /install -r requirements.txt ########################################### # Этап 2: Финальный образ (Runtime) ########################################### -FROM python:3.11.9-alpine as runtime +FROM python:3.11.9-slim as runtime # Минимальные рантайм-зависимости -RUN apk add --no-cache \ - libstdc++ \ - sqlite-libs +RUN apt-get update && apt-get install --no-install-recommends -y \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* # Создаем пользователя -RUN addgroup -g 1001 deploy && adduser -D -u 1001 -G deploy deploy +RUN groupadd -g 1001 deploy && useradd -r -u 1001 -g deploy deploy WORKDIR /app # Копируем зависимости -COPY --from=builder --chown=1001:1001 /install /usr/local/lib/python3.11/site-packages +COPY --from=builder --chown=deploy:deploy /install /usr/local/lib/python3.11/site-packages -# Создаем структуру папок -RUN mkdir -p database logs voice_users && \ - chown -R 1001:1001 /app +# Создаем структуру папок (включая директории для ML моделей) +RUN mkdir -p database logs voice_users data/models && \ + chown -R deploy:deploy /app + +# Устанавливаем переменные для HuggingFace (кеш моделей внутри /app) +ENV HF_HOME=/app/data/models +ENV TRANSFORMERS_CACHE=/app/data/models +ENV HF_HUB_CACHE=/app/data/models # Копируем исходный код -COPY --chown=1001:1001 . . +COPY --chown=deploy:deploy . . -USER 1001 +USER deploy # Healthcheck HEALTHCHECK --interval=30s --timeout=15s --start-period=10s --retries=5 \ diff --git a/database/async_db.py b/database/async_db.py index e5e3403..e5f74d8 100644 --- a/database/async_db.py +++ b/database/async_db.py @@ -210,6 +210,23 @@ class AsyncBotDB: return await self.factory.posts.update_status_for_media_group_by_helper_id( helper_message_id, status ) + + # Методы для ML Scoring + async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]: + """Получает текст поста по message_id.""" + return await self.factory.posts.get_post_text_by_message_id(message_id) + + async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool: + """Обновляет ML-скоры для поста.""" + return await self.factory.posts.update_ml_scores(message_id, ml_scores_json) + + async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]: + """Получает тексты одобренных постов для обучения RAG.""" + return await self.factory.posts.get_approved_posts_texts(limit) + + async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]: + """Получает тексты отклоненных постов для обучения RAG.""" + return await self.factory.posts.get_declined_posts_texts(limit) # Методы для работы с черным списком async def set_user_blacklist( diff --git a/database/repositories/post_repository.py b/database/repositories/post_repository.py index daa4265..cae5ede 100644 --- a/database/repositories/post_repository.py +++ b/database/repositories/post_repository.py @@ -357,3 +357,126 @@ class PostRepository(DatabaseConnection): post_content = await self._execute_query_with_result(query, (published_message_id,)) self.logger.info(f"Получен контент опубликованного поста: {len(post_content)} элементов для published_message_id={published_message_id}") return post_content + + # ============================================ + # Методы для работы с ML-скорингом + # ============================================ + + async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool: + """ + Обновляет ML-скоры для поста. + + Args: + message_id: ID сообщения в группе модерации + ml_scores_json: JSON строка со скорами + + Returns: + True если обновлено успешно + """ + try: + query = "UPDATE post_from_telegram_suggest SET ml_scores = ? WHERE message_id = ?" + await self._execute_query(query, (ml_scores_json, message_id)) + self.logger.info(f"ML-скоры обновлены для message_id={message_id}") + return True + except Exception as e: + self.logger.error(f"Ошибка обновления ML-скоров для message_id={message_id}: {e}") + return False + + async def get_ml_scores_by_message_id(self, message_id: int) -> Optional[str]: + """ + Получает ML-скоры для поста. + + Args: + message_id: ID сообщения + + Returns: + JSON строка со скорами или None + """ + query = "SELECT ml_scores FROM post_from_telegram_suggest WHERE message_id = ?" + rows = await self._execute_query_with_result(query, (message_id,)) + if rows and rows[0][0]: + return rows[0][0] + return None + + async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]: + """ + Получает текст поста по message_id. + + Args: + message_id: ID сообщения + + Returns: + Текст поста или None + """ + query = "SELECT text FROM post_from_telegram_suggest WHERE message_id = ?" + rows = await self._execute_query_with_result(query, (message_id,)) + if rows and rows[0][0]: + return rows[0][0] + return None + + async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]: + """ + Получает тексты опубликованных постов для обучения RAG. + + Args: + limit: Максимальное количество постов + + Returns: + Список текстов + """ + query = """ + SELECT text FROM post_from_telegram_suggest + WHERE status = 'approved' + AND text IS NOT NULL + AND text != '' + AND text != '^' + ORDER BY created_at DESC + LIMIT ? + """ + rows = await self._execute_query_with_result(query, (limit,)) + texts = [row[0] for row in rows if row[0]] + self.logger.info(f"Получено {len(texts)} опубликованных постов для обучения") + return texts + + async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]: + """ + Получает тексты отклоненных постов для обучения RAG. + + Args: + limit: Максимальное количество постов + + Returns: + Список текстов + """ + query = """ + SELECT text FROM post_from_telegram_suggest + WHERE status = 'declined' + AND text IS NOT NULL + AND text != '' + AND text != '^' + ORDER BY created_at DESC + LIMIT ? + """ + rows = await self._execute_query_with_result(query, (limit,)) + texts = [row[0] for row in rows if row[0]] + self.logger.info(f"Получено {len(texts)} отклоненных постов для обучения") + return texts + + async def update_vector_hash(self, message_id: int, vector_hash: str) -> bool: + """ + Обновляет хеш вектора для поста (для кеширования). + + Args: + message_id: ID сообщения + vector_hash: Хеш вектора + + Returns: + True если обновлено успешно + """ + try: + query = "UPDATE post_from_telegram_suggest SET vector_hash = ? WHERE message_id = ?" + await self._execute_query(query, (vector_hash, message_id)) + return True + except Exception as e: + self.logger.error(f"Ошибка обновления vector_hash для message_id={message_id}: {e}") + return False diff --git a/env.example b/env.example index dbab9a9..ea06d24 100644 --- a/env.example +++ b/env.example @@ -35,3 +35,20 @@ METRICS_PORT=8080 # Logging LOG_LEVEL=INFO LOG_RETENTION_DAYS=30 + +# ML Scoring - RAG (ruBERT) +# Включает локальное векторное сравнение с использованием ruBERT +RAG_ENABLED=false +RAG_MODEL=DeepPavlov/rubert-base-cased +RAG_CACHE_DIR=data/models +RAG_VECTORS_PATH=data/vectors.npz +RAG_MAX_EXAMPLES=10000 +RAG_SCORE_MULTIPLIER=5 + +# ML Scoring - DeepSeek API +# Включает оценку постов через DeepSeek API +DEEPSEEK_ENABLED=false +DEEPSEEK_API_KEY=your_deepseek_api_key_here +DEEPSEEK_API_URL=https://api.deepseek.com/v1/chat/completions +DEEPSEEK_MODEL=deepseek-chat +DEEPSEEK_TIMEOUT=30 diff --git a/helper_bot/handlers/admin/admin_handlers.py b/helper_bot/handlers/admin/admin_handlers.py index 50d041a..31ed534 100644 --- a/helper_bot/handlers/admin/admin_handlers.py +++ b/helper_bot/handlers/admin/admin_handlers.py @@ -16,6 +16,7 @@ from helper_bot.keyboards.keyboards import (create_keyboard_for_approve_ban, create_keyboard_for_ban_reason, create_keyboard_with_pagination, get_reply_keyboard_admin) +from helper_bot.utils.base_dependency_factory import get_global_instance # Local imports - metrics from helper_bot.utils.metrics import db_query_time, track_errors, track_time from logs.custom_logger import logger @@ -137,6 +138,69 @@ async def get_banned_users( await handle_admin_error(message, e, state, "get_banned_users") +@admin_router.message( + ChatTypeFilter(chat_type=["private"]), + StateFilter("ADMIN"), + F.text == '📊 ML Статистика' +) +@track_time("get_ml_stats", "admin_handlers") +@track_errors("admin_handlers", "get_ml_stats") +async def get_ml_stats( + message: types.Message, + state: FSMContext, + **kwargs + ): + """Получение статистики ML-скоринга""" + try: + logger.info(f"Запрос ML статистики от пользователя: {message.from_user.full_name}") + + bdf = get_global_instance() + scoring_manager = bdf.get_scoring_manager() + + if not scoring_manager: + await message.answer("📊 ML Scoring отключен\n\nДля включения установите RAG_ENABLED=true или DEEPSEEK_ENABLED=true в .env") + return + + stats = scoring_manager.get_stats() + + # Формируем текст статистики + lines = ["📊 ML Scoring Статистика\n"] + + # RAG статистика + if "rag" in stats: + rag = stats["rag"] + lines.append("🤖 RAG (ruBERT):") + lines.append(f" • Статус: {'✅ Включен' if rag.get('enabled') else '❌ Отключен'}") + lines.append(f" • Модель: {rag.get('model_name', 'N/A')}") + lines.append(f" • Модель загружена: {'✅' if rag.get('model_loaded') else '❌'}") + + vs = rag.get("vector_store", {}) + lines.append(f" • Положительных примеров: {vs.get('positive_count', 0)}") + lines.append(f" • Отрицательных примеров: {vs.get('negative_count', 0)}") + lines.append(f" • Всего примеров: {vs.get('total_count', 0)}") + lines.append(f" • Макс. примеров: {vs.get('max_examples', 'N/A')}") + lines.append("") + + # DeepSeek статистика + if "deepseek" in stats: + ds = stats["deepseek"] + lines.append("🔮 DeepSeek API:") + lines.append(f" • Статус: {'✅ Включен' if ds.get('enabled') else '❌ Отключен'}") + lines.append(f" • Модель: {ds.get('model', 'N/A')}") + lines.append(f" • Таймаут: {ds.get('timeout', 'N/A')}с") + lines.append("") + + # Если ничего не включено + if "rag" not in stats and "deepseek" not in stats: + lines.append("⚠️ Ни один сервис не настроен") + + await message.answer("\n".join(lines), parse_mode="HTML") + + except Exception as e: + logger.error(f"Ошибка получения ML статистики: {e}") + await message.answer(f"❌ Ошибка получения статистики: {str(e)}") + + # ============================================================================ # ХЕНДЛЕРЫ ПРОЦЕССА БАНА # ============================================================================ diff --git a/helper_bot/handlers/callback/dependency_factory.py b/helper_bot/handlers/callback/dependency_factory.py index c6cdbb4..ec3f563 100644 --- a/helper_bot/handlers/callback/dependency_factory.py +++ b/helper_bot/handlers/callback/dependency_factory.py @@ -15,7 +15,8 @@ def get_post_publish_service() -> PostPublishService: db = bdf.get_db() settings = bdf.settings s3_storage = bdf.get_s3_storage() - return PostPublishService(None, db, settings, s3_storage) + scoring_manager = bdf.get_scoring_manager() + return PostPublishService(None, db, settings, s3_storage, scoring_manager) def get_ban_service() -> BanService: diff --git a/helper_bot/handlers/callback/services.py b/helper_bot/handlers/callback/services.py index e72d347..4620e7f 100644 --- a/helper_bot/handlers/callback/services.py +++ b/helper_bot/handlers/callback/services.py @@ -29,12 +29,13 @@ from .exceptions import (BanError, PostNotFoundError, PublishError, class PostPublishService: - def __init__(self, bot: Bot, db, settings: Dict[str, Any], s3_storage=None): + def __init__(self, bot: Bot, db, settings: Dict[str, Any], s3_storage=None, scoring_manager=None): # bot может быть None - в этом случае используем бота из контекста сообщения self.bot = bot self.db = db self.settings = settings self.s3_storage = s3_storage + self.scoring_manager = scoring_manager self.group_for_posts = settings['Telegram']['group_for_posts'] self.main_public = settings['Telegram']['main_public'] self.important_logs = settings['Telegram']['important_logs'] @@ -392,6 +393,9 @@ class PostPublishService: async def _decline_single_post(self, call: CallbackQuery) -> None: """Отклонение одиночного поста""" author_id = await self._get_author_id(call.message.message_id) + + # Обучаем RAG на отклоненном посте перед удалением + await self._train_on_declined(call.message.message_id) updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "declined") if updated_rows == 0: @@ -485,6 +489,9 @@ class PostPublishService: @track_errors("post_publish_service", "_delete_post_and_notify_author") async def _delete_post_and_notify_author(self, call: CallbackQuery, author_id: int) -> None: """Удаление поста и уведомление автора""" + # Получаем текст поста для обучения RAG перед удалением + await self._train_on_published(call.message.message_id) + await self._get_bot(call.message).delete_message(chat_id=self.group_for_posts, message_id=call.message.message_id) try: @@ -493,6 +500,32 @@ class PostPublishService: if str(e) == ERROR_BOT_BLOCKED: raise UserBlockedBotError("Пользователь заблокировал бота") raise + + async def _train_on_published(self, message_id: int) -> None: + """Обучает RAG на опубликованном посте.""" + if not self.scoring_manager: + return + + try: + text = await self.db.get_post_text_by_message_id(message_id) + if text and text.strip() and text != "^": + await self.scoring_manager.on_post_published(text) + logger.debug(f"RAG обучен на опубликованном посте: {message_id}") + except Exception as e: + logger.error(f"Ошибка обучения RAG на опубликованном посте {message_id}: {e}") + + async def _train_on_declined(self, message_id: int) -> None: + """Обучает RAG на отклоненном посте.""" + if not self.scoring_manager: + return + + try: + text = await self.db.get_post_text_by_message_id(message_id) + if text and text.strip() and text != "^": + await self.scoring_manager.on_post_declined(text) + logger.debug(f"RAG обучен на отклоненном посте: {message_id}") + except Exception as e: + logger.error(f"Ошибка обучения RAG на отклоненном посте {message_id}: {e}") @track_time("_delete_media_group_and_notify_author", "post_publish_service") @track_errors("post_publish_service", "_delete_media_group_and_notify_author") diff --git a/helper_bot/handlers/private/private_handlers.py b/helper_bot/handlers/private/private_handlers.py index f9f2646..af01457 100644 --- a/helper_bot/handlers/private/private_handlers.py +++ b/helper_bot/handlers/private/private_handlers.py @@ -35,11 +35,11 @@ sleep = asyncio.sleep class PrivateHandlers: """Main handler class for private messages""" - def __init__(self, db: AsyncBotDB, settings: BotSettings, s3_storage=None): + def __init__(self, db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None): self.db = db self.settings = settings self.user_service = UserService(db, settings) - self.post_service = PostService(db, settings, s3_storage) + self.post_service = PostService(db, settings, s3_storage, scoring_manager) self.sticker_service = StickerService(settings) self.router = Router() @@ -240,18 +240,24 @@ class PrivateHandlers: # Factory function to create handlers with dependencies -def create_private_handlers(db: AsyncBotDB, settings: BotSettings, s3_storage=None) -> PrivateHandlers: +def create_private_handlers(db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None) -> PrivateHandlers: """Create private handlers instance with dependencies""" - return PrivateHandlers(db, settings, s3_storage) + return PrivateHandlers(db, settings, s3_storage, scoring_manager) # Legacy router for backward compatibility private_router = Router() +# Флаг инициализации для защиты от повторного вызова +_legacy_router_initialized = False + # Initialize with global dependencies (for backward compatibility) def init_legacy_router(): """Initialize legacy router with global dependencies""" - global private_router + global private_router, _legacy_router_initialized + + if _legacy_router_initialized: + return from helper_bot.utils.base_dependency_factory import get_global_instance @@ -269,11 +275,13 @@ def init_legacy_router(): db = bdf.get_db() s3_storage = bdf.get_s3_storage() - handlers = create_private_handlers(db, settings, s3_storage) + scoring_manager = bdf.get_scoring_manager() + handlers = create_private_handlers(db, settings, s3_storage, scoring_manager) # Instead of trying to copy handlers, we'll use the new router directly # This maintains backward compatibility while using the new architecture private_router = handlers.router + _legacy_router_initialized = True # Initialize legacy router init_legacy_router() diff --git a/helper_bot/handlers/private/services.py b/helper_bot/handlers/private/services.py index 8f0c151..904dd60 100644 --- a/helper_bot/handlers/private/services.py +++ b/helper_bot/handlers/private/services.py @@ -128,10 +128,11 @@ class UserService: class PostService: """Service for post-related operations""" - def __init__(self, db: DatabaseProtocol, settings: BotSettings, s3_storage=None) -> None: + def __init__(self, db: DatabaseProtocol, settings: BotSettings, s3_storage=None, scoring_manager=None) -> None: self.db = db self.settings = settings self.s3_storage = s3_storage + self.scoring_manager = scoring_manager async def _save_media_background(self, sent_message: types.Message, bot_db: Any, s3_storage) -> None: """Сохраняет медиа в фоне, чтобы не блокировать ответ пользователю""" @@ -142,18 +143,65 @@ class PostService: except Exception as e: logger.error(f"_save_media_background: Ошибка при сохранении медиа для поста {sent_message.message_id}: {e}") + async def _get_scores(self, text: str) -> tuple: + """ + Получает скоры для текста поста. + + Returns: + Tuple (deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json) + """ + if not self.scoring_manager or not text or not text.strip(): + return None, None, None, None, None + + try: + scores = await self.scoring_manager.score_post(text) + + # Формируем JSON для сохранения в БД + import json + ml_scores_json = json.dumps(scores.to_json_dict()) if scores.has_any_score() else None + + # Получаем данные от RAG + rag_confidence = scores.rag.confidence if scores.rag else None + rag_score_pos_only = scores.rag.metadata.get("score_pos_only") if scores.rag else None + + return scores.deepseek_score, scores.rag_score, rag_confidence, rag_score_pos_only, ml_scores_json + except Exception as e: + logger.error(f"PostService: Ошибка получения скоров: {e}") + return None, None, None, None, None + + async def _save_scores_background(self, message_id: int, ml_scores_json: str) -> None: + """Сохраняет скоры в БД в фоне.""" + if ml_scores_json: + try: + await self.db.update_ml_scores(message_id, ml_scores_json) + except Exception as e: + logger.error(f"PostService: Ошибка сохранения скоров для {message_id}: {e}") + @track_time("handle_text_post", "post_service") @track_errors("post_service", "handle_text_post") @db_query_time("handle_text_post", "posts", "insert") async def handle_text_post(self, message: types.Message, first_name: str) -> None: """Handle text post submission""" - post_text = get_text_message(message.text.lower(), first_name, message.from_user.username) + raw_text = message.text or "" + + # Получаем скоры для текста + deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_text) + + # Формируем текст с учетом скоров + post_text = get_text_message( + message.text.lower(), + first_name, + message.from_user.username, + deepseek_score=deepseek_score, + rag_score=rag_score, + rag_confidence=rag_confidence, + rag_score_pos_only=rag_score_pos_only, + ) markup = get_reply_keyboard_for_post() sent_message = await send_text_message(self.settings.group_for_posts, message, post_text, markup) - # Сохраняем сырой текст и определяем анонимность - raw_text = message.text or "" + # Определяем анонимность is_anonymous = determine_anonymity(raw_text) post = TelegramPost( @@ -164,23 +212,39 @@ class PostService: is_anonymous=is_anonymous ) await self.db.add_post(post) + + # Сохраняем скоры в фоне + if ml_scores_json: + asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) @track_time("handle_photo_post", "post_service") @track_errors("post_service", "handle_photo_post") @db_query_time("handle_photo_post", "posts", "insert") async def handle_photo_post(self, message: types.Message, first_name: str) -> None: """Handle photo post submission""" + raw_caption = message.caption or "" + + # Получаем скоры для текста + deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + post_caption = "" if message.caption: - post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username) + post_caption = get_text_message( + message.caption.lower(), + first_name, + message.from_user.username, + deepseek_score=deepseek_score, + rag_score=rag_score, + rag_confidence=rag_confidence, + rag_score_pos_only=rag_score_pos_only, + ) markup = get_reply_keyboard_for_post() sent_message = await send_photo_message( self.settings.group_for_posts, message, message.photo[-1].file_id, post_caption, markup ) - # Сохраняем сырой caption и определяем анонимность - raw_caption = message.caption or "" + # Определяем анонимность is_anonymous = determine_anonymity(raw_caption) post = TelegramPost( @@ -191,25 +255,40 @@ class PostService: is_anonymous=is_anonymous ) await self.db.add_post(post) - # Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю + + # Сохраняем медиа и скоры в фоне asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) + if ml_scores_json: + asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) @track_time("handle_video_post", "post_service") @track_errors("post_service", "handle_video_post") @db_query_time("handle_video_post", "posts", "insert") async def handle_video_post(self, message: types.Message, first_name: str) -> None: """Handle video post submission""" + raw_caption = message.caption or "" + + # Получаем скоры для текста + deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + post_caption = "" if message.caption: - post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username) + post_caption = get_text_message( + message.caption.lower(), + first_name, + message.from_user.username, + deepseek_score=deepseek_score, + rag_score=rag_score, + rag_confidence=rag_confidence, + rag_score_pos_only=rag_score_pos_only, + ) markup = get_reply_keyboard_for_post() sent_message = await send_video_message( self.settings.group_for_posts, message, message.video.file_id, post_caption, markup ) - # Сохраняем сырой caption и определяем анонимность - raw_caption = message.caption or "" + # Определяем анонимность is_anonymous = determine_anonymity(raw_caption) post = TelegramPost( @@ -220,8 +299,11 @@ class PostService: is_anonymous=is_anonymous ) await self.db.add_post(post) - # Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю + + # Сохраняем медиа и скоры в фоне asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) + if ml_scores_json: + asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) @track_time("handle_video_note_post", "post_service") @track_errors("post_service", "handle_video_note_post") @@ -253,17 +335,29 @@ class PostService: @db_query_time("handle_audio_post", "posts", "insert") async def handle_audio_post(self, message: types.Message, first_name: str) -> None: """Handle audio post submission""" + raw_caption = message.caption or "" + + # Получаем скоры для текста + deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + post_caption = "" if message.caption: - post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username) + post_caption = get_text_message( + message.caption.lower(), + first_name, + message.from_user.username, + deepseek_score=deepseek_score, + rag_score=rag_score, + rag_confidence=rag_confidence, + rag_score_pos_only=rag_score_pos_only, + ) markup = get_reply_keyboard_for_post() sent_message = await send_audio_message( self.settings.group_for_posts, message, message.audio.file_id, post_caption, markup ) - # Сохраняем сырой caption и определяем анонимность - raw_caption = message.caption or "" + # Определяем анонимность is_anonymous = determine_anonymity(raw_caption) post = TelegramPost( @@ -274,8 +368,11 @@ class PostService: is_anonymous=is_anonymous ) await self.db.add_post(post) - # Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю + + # Сохраняем медиа и скоры в фоне asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) + if ml_scores_json: + asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json)) @track_time("handle_voice_post", "post_service") @track_errors("post_service", "handle_voice_post") @@ -310,10 +407,23 @@ class PostService: """Handle media group post submission""" post_caption = " " raw_caption = "" + ml_scores_json = None if album and album[0].caption: raw_caption = album[0].caption or "" - post_caption = get_text_message(album[0].caption.lower(), first_name, message.from_user.username) + + # Получаем скоры для текста + deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption) + + post_caption = get_text_message( + album[0].caption.lower(), + first_name, + message.from_user.username, + deepseek_score=deepseek_score, + rag_score=rag_score, + rag_confidence=rag_confidence, + rag_score_pos_only=rag_score_pos_only, + ) is_anonymous = determine_anonymity(raw_caption) media_group = await prepare_media_group_from_middlewares(album, post_caption) @@ -333,6 +443,10 @@ class PostService: ) await self.db.add_post(main_post) + # Сохраняем скоры в фоне + if ml_scores_json: + asyncio.create_task(self._save_scores_background(main_post_id, ml_scores_json)) + for msg_id in media_group_message_ids: await self.db.add_message_link(main_post_id, msg_id) diff --git a/helper_bot/keyboards/keyboards.py b/helper_bot/keyboards/keyboards.py index aeac9b8..3fd4f3c 100644 --- a/helper_bot/keyboards/keyboards.py +++ b/helper_bot/keyboards/keyboards.py @@ -47,6 +47,9 @@ def get_reply_keyboard_admin(): ) builder.row( types.KeyboardButton(text="Разбан (список)"), + types.KeyboardButton(text="📊 ML Статистика") + ) + builder.row( types.KeyboardButton(text="Вернуться в бота") ) markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True) diff --git a/helper_bot/main.py b/helper_bot/main.py index 0c1da15..b710d47 100644 --- a/helper_bot/main.py +++ b/helper_bot/main.py @@ -78,6 +78,22 @@ async def start_bot(bdf): await bot.delete_webhook(drop_pending_updates=True) + # Загружаем примеры для RAG из базы данных + scoring_manager = bdf.get_scoring_manager() + if scoring_manager and scoring_manager.rag_service and scoring_manager.rag_service.is_enabled: + try: + db = bdf.get_db() + positive_texts = await db.get_approved_posts_texts(limit=5000) + negative_texts = await db.get_declined_posts_texts(limit=5000) + + if positive_texts or negative_texts: + await scoring_manager.load_examples_from_db(positive_texts, negative_texts) + logging.info(f"RAG: Загружено {len(positive_texts)} положительных и {len(negative_texts)} отрицательных примеров") + else: + logging.warning("RAG: Нет примеров в базе данных для загрузки") + except Exception as e: + logging.error(f"Ошибка загрузки примеров для RAG: {e}") + # Запускаем HTTP сервер для метрик параллельно с ботом metrics_host = bdf.settings.get('Metrics', {}).get('host', '0.0.0.0') metrics_port = bdf.settings.get('Metrics', {}).get('port', 8080) diff --git a/helper_bot/services/__init__.py b/helper_bot/services/__init__.py new file mode 100644 index 0000000..50a732e --- /dev/null +++ b/helper_bot/services/__init__.py @@ -0,0 +1,5 @@ +""" +Сервисы приложения. + +Содержит бизнес-логику, не связанную напрямую с handlers. +""" diff --git a/helper_bot/services/scoring/__init__.py b/helper_bot/services/scoring/__init__.py new file mode 100644 index 0000000..a56b7fe --- /dev/null +++ b/helper_bot/services/scoring/__init__.py @@ -0,0 +1,42 @@ +""" +Сервисы для ML-скоринга постов. + +Включает: +- RAGService - локальное векторное сравнение с ruBERT +- DeepSeekService - интеграция с DeepSeek API +- ScoringManager - объединение всех сервисов скоринга +- VectorStore - in-memory хранилище векторов +""" + +from .base import ScoringResult, ScoringServiceProtocol, CombinedScore +from .exceptions import ( + ScoringError, + ModelNotLoadedError, + VectorStoreError, + DeepSeekAPIError, + InsufficientExamplesError, + TextTooShortError, +) +from .vector_store import VectorStore +from .rag_service import RAGService +from .deepseek_service import DeepSeekService +from .scoring_manager import ScoringManager + +__all__ = [ + # Базовые классы + "ScoringResult", + "ScoringServiceProtocol", + "CombinedScore", + # Исключения + "ScoringError", + "ModelNotLoadedError", + "VectorStoreError", + "DeepSeekAPIError", + "InsufficientExamplesError", + "TextTooShortError", + # Сервисы + "VectorStore", + "RAGService", + "DeepSeekService", + "ScoringManager", +] diff --git a/helper_bot/services/scoring/base.py b/helper_bot/services/scoring/base.py new file mode 100644 index 0000000..748afa2 --- /dev/null +++ b/helper_bot/services/scoring/base.py @@ -0,0 +1,155 @@ +""" +Базовые классы и протоколы для сервисов скоринга. +""" + +from dataclasses import dataclass, field +from typing import Optional, Protocol, Dict, Any +from datetime import datetime + + +@dataclass +class ScoringResult: + """ + Результат оценки поста от одного сервиса. + + Attributes: + score: Оценка от 0.0 до 1.0 (вероятность публикации) + source: Источник оценки ("deepseek", "rag", etc.) + model: Название используемой модели + confidence: Уверенность в оценке (опционально) + timestamp: Время получения оценки + metadata: Дополнительные данные + """ + score: float + source: str + model: str + confidence: Optional[float] = None + timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp())) + metadata: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + """Валидация score в диапазоне [0.0, 1.0].""" + if not 0.0 <= self.score <= 1.0: + raise ValueError(f"Score должен быть в диапазоне [0.0, 1.0], получено: {self.score}") + + def to_dict(self) -> Dict[str, Any]: + """Преобразует результат в словарь для сохранения в JSON.""" + result = { + "score": round(self.score, 4), + "model": self.model, + "ts": self.timestamp, + } + if self.confidence is not None: + result["confidence"] = round(self.confidence, 4) + if self.metadata: + result["metadata"] = self.metadata + return result + + @classmethod + def from_dict(cls, source: str, data: Dict[str, Any]) -> "ScoringResult": + """Создает ScoringResult из словаря.""" + return cls( + score=data["score"], + source=source, + model=data.get("model", "unknown"), + confidence=data.get("confidence"), + timestamp=data.get("ts", int(datetime.now().timestamp())), + metadata=data.get("metadata", {}), + ) + + +@dataclass +class CombinedScore: + """ + Объединенный результат от всех сервисов скоринга. + + Attributes: + deepseek: Результат от DeepSeek API (None если отключен/ошибка) + rag: Результат от RAG сервиса (None если отключен/ошибка) + errors: Словарь с ошибками по источникам + """ + deepseek: Optional[ScoringResult] = None + rag: Optional[ScoringResult] = None + errors: Dict[str, str] = field(default_factory=dict) + + @property + def deepseek_score(self) -> Optional[float]: + """Возвращает только числовой скор от DeepSeek.""" + return self.deepseek.score if self.deepseek else None + + @property + def rag_score(self) -> Optional[float]: + """Возвращает только числовой скор от RAG.""" + return self.rag.score if self.rag else None + + def to_json_dict(self) -> Dict[str, Any]: + """ + Преобразует в словарь для сохранения в ml_scores колонку. + + Формат: + { + "deepseek": {"score": 0.75, "model": "...", "ts": ...}, + "rag": {"score": 0.90, "model": "...", "ts": ...} + } + """ + result = {} + if self.deepseek: + result["deepseek"] = self.deepseek.to_dict() + if self.rag: + result["rag"] = self.rag.to_dict() + return result + + def has_any_score(self) -> bool: + """Проверяет, есть ли хотя бы один успешный скор.""" + return self.deepseek is not None or self.rag is not None + + +class ScoringServiceProtocol(Protocol): + """ + Протокол для сервисов скоринга. + + Любой сервис скоринга должен реализовывать эти методы. + """ + + @property + def source_name(self) -> str: + """Возвращает имя источника ("deepseek", "rag", etc.).""" + ... + + @property + def is_enabled(self) -> bool: + """Проверяет, включен ли сервис.""" + ... + + async def calculate_score(self, text: str) -> ScoringResult: + """ + Рассчитывает скор для текста поста. + + Args: + text: Текст поста для оценки + + Returns: + ScoringResult с оценкой + + Raises: + ScoringError: При ошибке расчета + """ + ... + + async def add_positive_example(self, text: str) -> None: + """ + Добавляет текст как положительный пример (опубликованный пост). + + Args: + text: Текст опубликованного поста + """ + ... + + async def add_negative_example(self, text: str) -> None: + """ + Добавляет текст как отрицательный пример (отклоненный пост). + + Args: + text: Текст отклоненного поста + """ + ... diff --git a/helper_bot/services/scoring/deepseek_service.py b/helper_bot/services/scoring/deepseek_service.py new file mode 100644 index 0000000..de45835 --- /dev/null +++ b/helper_bot/services/scoring/deepseek_service.py @@ -0,0 +1,358 @@ +""" +DeepSeek API сервис для скоринга постов. + +Использует DeepSeek API для семантической оценки релевантности поста. +""" + +import asyncio +import json +from typing import Optional, List + +import httpx + +from logs.custom_logger import logger +from helper_bot.utils.metrics import track_time, track_errors + +from .base import ScoringResult +from .exceptions import DeepSeekAPIError, ScoringError, TextTooShortError + + +class DeepSeekService: + """ + Сервис для оценки постов через DeepSeek API. + + Отправляет текст поста в DeepSeek с промптом для оценки + и получает числовой скор релевантности. + + Attributes: + api_key: API ключ DeepSeek + api_url: URL API эндпоинта + model: Название модели + timeout: Таймаут запроса в секундах + """ + + # Промпт для оценки поста + SCORING_PROMPT = """Роль: Ты — строгий и внимательный модератор сообщества в социальной сети, ориентированного на знакомства между людьми. Твоя задача — оценить, можно ли опубликовать пост, основываясь на четких правилах. + +Контекст группы: Это группа для поиска и знакомства с людьми. Пользователи могут искать кого угодно: случайно увиденных на улице, в транспорте, в кафе, старых знакомых, новых друзей или пару. Это главная и единственная цель группы. + +--- + +ПРАВИЛА ЗАПРЕТА (пост НЕ ДОЛЖЕН быть опубликован, если содержит это): + +1. Запрещенные законом тематики: Любые призывы, обсуждение или поиск чего-либо незаконного (наркотики, оружие, мошенничество, насилие и т.д.). +2. Поиск и утеря животных, найденные предметы: Запрещены посты про потерявшихся/найденных кошек, собак, хомяков, а также про потерянные/найденные телефоны, ключи, сумки и т.п. +3. Конкуренция (Дайвинчик): Любое упоминание группы/проекта/чата "Дайвинчик" или любых других групп-конкурентов. Запрещены призывы переходить в другие сообщества. +4. Сбор больших компаний и групп: Запрещены посты с целью собрать большую тусовку, компанию, группу для похода, вечеринки, игры и т.д. (например, "собираем команду для футбола", "кто хочет на квартиру?"). +5. Организация чатов и других сообществ: Запрещено создание или реклама сторонних чатов, каналов, групп в телеграме, дискорде и т.п. + +--- + +ПРАВИЛА РАЗРЕШЕНИЯ (пост МОЖЕТ быть опубликован, если): + +· Цель — найти конкретного человека или познакомиться с кем-то новым. +· Формат: Описание человека, обстоятельств встречи, примет, места и времени. Или прямой призыв к знакомству. +· Примеры ДОПУСТИМЫХ постов (ориентируйся на них): + · "мальчики нефоры/патлатые, гоу знакомиться😻 анон" + · "ищу девочку, ехала на 21 автобусе примерно в 15:20. села на детской поликлинике и вышла в заречье вся в черной одежде и с черным баулом" + · "ищу мальчика ехали на 35 автобусе часов в 7 вечера я была с девочками,у нас с тобой еще куртки одинаковые ,я рядом с тобой сидела,напиши в комментарии если у тебя нету девочки. анон админу любви." + +--- + +ИНСТРУКЦИЯ ПО ОЦЕНКЕ: + +Проанализируй полученный пост и присвой ему итоговый Вес (Score) от 0.0 до 1.0, где: + +· 1.0 — Пост полностью соответствует правилам. Цель — найти/познакомиться с человеком. Ничего из списка запретов не нарушено. Можно публиковать. +· 0.0 — Пост категорически нарушает правила. Содержит явные признаки одного или нескольких пунктов из списка запрета. Публиковать НЕЛЬЗЯ. +· 0.2 - 0.8 — Пост находится в "серой зоне". Присваивай промежуточный вес, оценивая степень риска и соответствия цели группы. + · Ближе к 0.2: Сильно сомнительный пост, есть явные признаки запрещенной темы (например, упоминание "собраться компанией", косвенная реклама другого места). + · 0.5: Нейтральный или неочевидный пост. Нужно проверить, нет ли скрытого смысла, нарушающего правила. + · Ближе к 0.8: В целом допустимый пост, но с небольшими странностями или двусмысленностями, не нарушающими правила напрямую. +--- +{text} +--- + +Ответь ТОЛЬКО числом от 0.0 до 1.0, без дополнительных объяснений. +Пример ответа: 0.75""" + + DEFAULT_API_URL = "https://api.deepseek.com/v1/chat/completions" + DEFAULT_MODEL = "deepseek-chat" + + def __init__( + self, + api_key: Optional[str] = None, + api_url: Optional[str] = None, + model: Optional[str] = None, + timeout: int = 30, + enabled: bool = True, + min_text_length: int = 3, + max_retries: int = 3, + ): + """ + Инициализация DeepSeek сервиса. + + Args: + api_key: API ключ DeepSeek + api_url: URL API эндпоинта + model: Название модели + timeout: Таймаут запроса в секундах + enabled: Включен ли сервис + min_text_length: Минимальная длина текста для обработки + max_retries: Максимальное количество повторных попыток + """ + self.api_key = api_key + self.api_url = api_url or self.DEFAULT_API_URL + self.model = model or self.DEFAULT_MODEL + self.timeout = timeout + self._enabled = enabled and bool(api_key) + self.min_text_length = min_text_length + self.max_retries = max_retries + + # HTTP клиент (создается лениво) + self._client: Optional[httpx.AsyncClient] = None + + if not api_key and enabled: + logger.warning("DeepSeekService: API ключ не указан, сервис отключен") + self._enabled = False + + logger.info( + f"DeepSeekService инициализирован " + f"(model={self.model}, enabled={self._enabled})" + ) + + @property + def source_name(self) -> str: + """Имя источника для результатов.""" + return "deepseek" + + @property + def is_enabled(self) -> bool: + """Проверяет, включен ли сервис.""" + return self._enabled + + async def _get_client(self) -> httpx.AsyncClient: + """Получает или создает HTTP клиент.""" + if self._client is None: + self._client = httpx.AsyncClient( + timeout=httpx.Timeout(self.timeout), + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + }, + ) + return self._client + + async def close(self) -> None: + """Закрывает HTTP клиент.""" + if self._client: + await self._client.aclose() + self._client = None + + def _clean_text(self, text: str) -> str: + """Очищает текст от лишних символов.""" + if not text: + return "" + + # Удаляем лишние пробелы и переносы строк + clean = " ".join(text.split()) + + # Удаляем служебные символы + if clean == "^": + return "" + + return clean.strip() + + def _parse_score_response(self, response_text: str) -> float: + """ + Парсит ответ от DeepSeek и извлекает скор. + + Args: + response_text: Текст ответа от API + + Returns: + Числовой скор от 0.0 до 1.0 + + Raises: + DeepSeekAPIError: Если не удалось распарсить ответ + """ + try: + # Пытаемся найти число в ответе + text = response_text.strip() + + # Убираем возможные обрамления + text = text.strip('"\'`') + + # Пробуем распарсить как число + score = float(text) + + # Ограничиваем диапазон + score = max(0.0, min(1.0, score)) + + return score + + except ValueError: + # Пробуем найти число в тексте + import re + matches = re.findall(r'0\.\d+|1\.0|0|1', text) + if matches: + score = float(matches[0]) + return max(0.0, min(1.0, score)) + + logger.error(f"DeepSeekService: Не удалось распарсить ответ: {response_text}") + raise DeepSeekAPIError(f"Не удалось распарсить скор из ответа: {response_text}") + + @track_time("calculate_score", "deepseek_service") + @track_errors("deepseek_service", "calculate_score") + async def calculate_score(self, text: str) -> ScoringResult: + """ + Рассчитывает скор для текста поста через DeepSeek API. + + Args: + text: Текст поста для оценки + + Returns: + ScoringResult с оценкой + + Raises: + ScoringError: При ошибке расчета + """ + if not self._enabled: + raise ScoringError("DeepSeek сервис отключен") + + # Очищаем текст + clean_text = self._clean_text(text) + + if len(clean_text) < self.min_text_length: + raise TextTooShortError( + f"Текст слишком короткий (минимум {self.min_text_length} символов)" + ) + + # Формируем промпт + prompt = self.SCORING_PROMPT.format(text=clean_text) + + # Выполняем запрос с повторными попытками + last_error = None + for attempt in range(self.max_retries): + try: + score = await self._make_api_request(prompt) + + return ScoringResult( + score=score, + source=self.source_name, + model=self.model, + metadata={ + "text_length": len(clean_text), + "attempt": attempt + 1, + }, + ) + + except DeepSeekAPIError as e: + last_error = e + logger.warning( + f"DeepSeekService: Попытка {attempt + 1}/{self.max_retries} " + f"не удалась: {e}" + ) + if attempt < self.max_retries - 1: + # Экспоненциальная задержка + await asyncio.sleep(2 ** attempt) + + raise ScoringError(f"Все попытки запроса к DeepSeek API не удались: {last_error}") + + async def _make_api_request(self, prompt: str) -> float: + """ + Выполняет запрос к DeepSeek API. + + Args: + prompt: Промпт для отправки + + Returns: + Числовой скор от 0.0 до 1.0 + + Raises: + DeepSeekAPIError: При ошибке API + """ + client = await self._get_client() + + payload = { + "model": self.model, + "messages": [ + { + "role": "user", + "content": prompt, + } + ], + "temperature": 0.1, # Низкая температура для детерминированности + "max_tokens": 10, # Ожидаем только число + } + + try: + response = await client.post(self.api_url, json=payload) + response.raise_for_status() + + data = response.json() + + # Извлекаем ответ + if "choices" not in data or not data["choices"]: + raise DeepSeekAPIError("Пустой ответ от API") + + response_text = data["choices"][0]["message"]["content"] + + # Парсим скор + score = self._parse_score_response(response_text) + + logger.debug(f"DeepSeekService: Получен скор {score} для текста") + return score + + except httpx.HTTPStatusError as e: + error_msg = f"HTTP ошибка {e.response.status_code}" + try: + error_data = e.response.json() + if "error" in error_data: + error_msg = error_data["error"].get("message", error_msg) + except Exception: + pass + raise DeepSeekAPIError(error_msg) + + except httpx.TimeoutException: + raise DeepSeekAPIError(f"Таймаут запроса ({self.timeout}s)") + + except Exception as e: + raise DeepSeekAPIError(f"Ошибка запроса: {e}") + + async def add_positive_example(self, text: str) -> None: + """ + Добавляет текст как положительный пример. + + Для DeepSeek не требуется хранить примеры - оценка выполняется + на основе промпта. Метод существует для совместимости с протоколом. + + Args: + text: Текст опубликованного поста + """ + # DeepSeek не использует примеры для обучения + # Промпт уже содержит критерии оценки + pass + + async def add_negative_example(self, text: str) -> None: + """ + Добавляет текст как отрицательный пример. + + Для DeepSeek не требуется хранить примеры - оценка выполняется + на основе промпта. Метод существует для совместимости с протоколом. + + Args: + text: Текст отклоненного поста + """ + # DeepSeek не использует примеры для обучения + pass + + def get_stats(self) -> dict: + """Возвращает статистику сервиса.""" + return { + "enabled": self._enabled, + "model": self.model, + "api_url": self.api_url, + "timeout": self.timeout, + "max_retries": self.max_retries, + } diff --git a/helper_bot/services/scoring/exceptions.py b/helper_bot/services/scoring/exceptions.py new file mode 100644 index 0000000..8af309c --- /dev/null +++ b/helper_bot/services/scoring/exceptions.py @@ -0,0 +1,33 @@ +""" +Исключения для сервисов скоринга. +""" + + +class ScoringError(Exception): + """Базовое исключение для ошибок скоринга.""" + pass + + +class ModelNotLoadedError(ScoringError): + """Модель не загружена или недоступна.""" + pass + + +class VectorStoreError(ScoringError): + """Ошибка при работе с хранилищем векторов.""" + pass + + +class DeepSeekAPIError(ScoringError): + """Ошибка при обращении к DeepSeek API.""" + pass + + +class InsufficientExamplesError(ScoringError): + """Недостаточно примеров для расчета скора.""" + pass + + +class TextTooShortError(ScoringError): + """Текст слишком короткий для векторизации.""" + pass diff --git a/helper_bot/services/scoring/rag_service.py b/helper_bot/services/scoring/rag_service.py new file mode 100644 index 0000000..0c02272 --- /dev/null +++ b/helper_bot/services/scoring/rag_service.py @@ -0,0 +1,507 @@ +""" +RAG сервис для скоринга постов с использованием ruBERT. + +Использует модель DeepPavlov/rubert-base-cased для создания эмбеддингов +и сравнивает их с эталонными примерами через VectorStore. +""" + +import asyncio +from typing import Optional, List + +import numpy as np + +from logs.custom_logger import logger +from helper_bot.utils.metrics import track_time, track_errors + +from .base import ScoringResult +from .vector_store import VectorStore +from .exceptions import ( + ModelNotLoadedError, + ScoringError, + InsufficientExamplesError, + TextTooShortError, +) + + +class RAGService: + """ + RAG сервис для оценки постов на основе векторного сходства. + + Использует ruBERT для создания эмбеддингов текста и сравнивает + их с эталонными примерами (опубликованные vs отклоненные посты). + + Attributes: + model_name: Название модели HuggingFace + vector_store: Хранилище векторов + min_text_length: Минимальная длина текста для обработки + """ + + # Название модели по умолчанию + DEFAULT_MODEL = "DeepPavlov/rubert-base-cased" + + def __init__( + self, + model_name: Optional[str] = None, + vector_store: Optional[VectorStore] = None, + cache_dir: Optional[str] = None, + enabled: bool = True, + min_text_length: int = 3, + ): + """ + Инициализация RAG сервиса. + + Args: + model_name: Название модели HuggingFace (по умолчанию ruBERT) + vector_store: Хранилище векторов (создается автоматически если не передано) + cache_dir: Директория для кеширования модели + enabled: Включен ли сервис + min_text_length: Минимальная длина текста для обработки + """ + self.model_name = model_name or self.DEFAULT_MODEL + self.cache_dir = cache_dir + self._enabled = enabled + self.min_text_length = min_text_length + + # Модель и токенизатор загружаются лениво + self._model = None + self._tokenizer = None + self._model_loaded = False + + # Хранилище векторов + self.vector_store = vector_store or VectorStore() + + logger.info(f"RAGService инициализирован (model={self.model_name}, enabled={enabled})") + + @property + def source_name(self) -> str: + """Имя источника для результатов.""" + return "rag" + + @property + def is_enabled(self) -> bool: + """Проверяет, включен ли сервис.""" + return self._enabled + + @property + def is_model_loaded(self) -> bool: + """Проверяет, загружена ли модель.""" + return self._model_loaded + + async def load_model(self) -> None: + """ + Загружает модель и токенизатор. + + Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop. + """ + if self._model_loaded: + return + + if not self._enabled: + logger.warning("RAGService: Сервис отключен, модель не загружается") + return + + logger.info(f"RAGService: Загрузка модели {self.model_name}...") + + try: + # Загрузка в отдельном потоке + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._load_model_sync) + + self._model_loaded = True + logger.info(f"RAGService: Модель {self.model_name} успешно загружена") + + except Exception as e: + logger.error(f"RAGService: Ошибка загрузки модели: {e}") + raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}") + + def _load_model_sync(self) -> None: + """Синхронная загрузка модели (вызывается в executor).""" + logger.info("RAGService: Начало _load_model_sync, импорт transformers...") + from transformers import AutoTokenizer, AutoModel + import torch + + # Определяем устройство + self._device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"RAGService: Устройство определено: {self._device}") + + # Загружаем токенизатор + logger.info(f"RAGService: Загрузка токенизатора из {self.model_name}...") + self._tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + cache_dir=self.cache_dir, + ) + logger.info("RAGService: Токенизатор загружен") + + # Загружаем модель + logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...") + self._model = AutoModel.from_pretrained( + self.model_name, + cache_dir=self.cache_dir, + ) + logger.info("RAGService: Модель загружена, перенос на устройство...") + self._model.to(self._device) + self._model.eval() # Режим инференса + + logger.info(f"RAGService: Модель готова на устройстве: {self._device}") + + def _get_embedding_sync(self, text: str) -> np.ndarray: + """ + Получает эмбеддинг текста (синхронно). + + Использует [CLS] токен как представление всего текста. + + Args: + text: Текст для векторизации + + Returns: + Numpy массив с эмбеддингом (768 измерений для ruBERT) + """ + import torch + + # Токенизация с ограничением длины + inputs = self._tokenizer( + text, + return_tensors="pt", + truncation=True, + max_length=512, + padding=True, + ) + inputs = {k: v.to(self._device) for k, v in inputs.items()} + + # Получаем эмбеддинг + with torch.no_grad(): + outputs = self._model(**inputs) + # Используем [CLS] токен (первый токен) + embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy() + + return embedding.flatten() + + def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]: + """ + Получает эмбеддинги для батча текстов (синхронно). + + Обрабатывает тексты пачками для эффективного использования GPU/CPU. + + Args: + texts: Список текстов для векторизации + batch_size: Размер батча (по умолчанию 16) + + Returns: + Список numpy массивов с эмбеддингами + """ + import torch + + all_embeddings = [] + + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i + batch_size] + + # Токенизация батча + inputs = self._tokenizer( + batch_texts, + return_tensors="pt", + truncation=True, + max_length=512, + padding=True, + ) + inputs = {k: v.to(self._device) for k, v in inputs.items()} + + # Получаем эмбеддинги + with torch.no_grad(): + outputs = self._model(**inputs) + # [CLS] токен для каждого текста в батче + batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() + + # Разбиваем на отдельные эмбеддинги + for j in range(len(batch_texts)): + all_embeddings.append(batch_embeddings[j]) + + if i > 0 and i % (batch_size * 10) == 0: + logger.info(f"RAGService: Обработано {i}/{len(texts)} текстов") + + return all_embeddings + + async def get_embeddings_batch(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]: + """ + Получает эмбеддинги для батча текстов (асинхронно). + + Args: + texts: Список текстов для векторизации + batch_size: Размер батча + + Returns: + Список numpy массивов с эмбеддингами + """ + if not self._model_loaded: + await self.load_model() + + if not self._model_loaded: + raise ModelNotLoadedError("Модель не загружена") + + # Очищаем тексты + clean_texts = [self._clean_text(text) for text in texts] + + # Выполняем батч-обработку в thread pool + loop = asyncio.get_event_loop() + embeddings = await loop.run_in_executor( + None, + self._get_embeddings_batch_sync, + clean_texts, + batch_size, + ) + + return embeddings + + async def get_embedding(self, text: str) -> np.ndarray: + """ + Получает эмбеддинг текста (асинхронно). + + Args: + text: Текст для векторизации + + Returns: + Numpy массив с эмбеддингом + + Raises: + ModelNotLoadedError: Если модель не загружена + TextTooShortError: Если текст слишком короткий + """ + if not self._model_loaded: + await self.load_model() + + if not self._model_loaded: + raise ModelNotLoadedError("Модель не загружена") + + # Очищаем текст + clean_text = self._clean_text(text) + + if len(clean_text) < self.min_text_length: + raise TextTooShortError( + f"Текст слишком короткий (минимум {self.min_text_length} символов)" + ) + + # Выполняем в отдельном потоке + loop = asyncio.get_event_loop() + embedding = await loop.run_in_executor( + None, + self._get_embedding_sync, + clean_text + ) + + return embedding + + def _clean_text(self, text: str) -> str: + """Очищает текст от лишних символов.""" + if not text: + return "" + + # Удаляем лишние пробелы и переносы строк + clean = " ".join(text.split()) + + # Удаляем служебные символы (например "^" для helper сообщений) + if clean == "^": + return "" + + return clean.strip() + + @track_time("calculate_score", "rag_service") + @track_errors("rag_service", "calculate_score") + async def calculate_score(self, text: str) -> ScoringResult: + """ + Рассчитывает скор для текста поста. + + Args: + text: Текст поста для оценки + + Returns: + ScoringResult с оценкой + + Raises: + ScoringError: При ошибке расчета + """ + if not self._enabled: + raise ScoringError("RAG сервис отключен") + + try: + # Получаем эмбеддинг текста + embedding = await self.get_embedding(text) + + # Логируем первые элементы вектора для отладки + logger.info( + f"RAGService: embedding[:3]={embedding[:3].tolist()}, " + f"text_preview='{text[:30]}'" + ) + + # Рассчитываем скор через VectorStore + score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding) + + return ScoringResult( + score=score, + source=self.source_name, + model=self.model_name, + confidence=confidence, + metadata={ + "positive_examples": self.vector_store.positive_count, + "negative_examples": self.vector_store.negative_count, + "score_pos_only": score_pos_only, # Для сравнения + }, + ) + + except InsufficientExamplesError: + # Не достаточно примеров - возвращаем нейтральный скор + logger.warning("RAGService: Недостаточно примеров для расчета скора") + raise + + except TextTooShortError: + logger.warning(f"RAGService: Текст слишком короткий для оценки") + raise + + except Exception as e: + logger.error(f"RAGService: Ошибка расчета скора: {e}") + raise ScoringError(f"Ошибка расчета скора: {e}") + + @track_time("add_positive_example", "rag_service") + async def add_positive_example(self, text: str) -> None: + """ + Добавляет текст как положительный пример (опубликованный пост). + + Args: + text: Текст опубликованного поста + """ + if not self._enabled: + return + + try: + clean_text = self._clean_text(text) + if len(clean_text) < self.min_text_length: + logger.debug("RAGService: Текст слишком короткий для примера, пропускаем") + return + + # Получаем эмбеддинг + embedding = await self.get_embedding(clean_text) + + # Вычисляем хеш для дедупликации + text_hash = VectorStore.compute_text_hash(clean_text) + + # Добавляем в хранилище + added = self.vector_store.add_positive(embedding, text_hash) + + if added: + logger.info(f"RAGService: Добавлен положительный пример") + + except Exception as e: + logger.error(f"RAGService: Ошибка добавления положительного примера: {e}") + + @track_time("add_negative_example", "rag_service") + async def add_negative_example(self, text: str) -> None: + """ + Добавляет текст как отрицательный пример (отклоненный пост). + + Args: + text: Текст отклоненного поста + """ + if not self._enabled: + return + + try: + clean_text = self._clean_text(text) + if len(clean_text) < self.min_text_length: + logger.debug("RAGService: Текст слишком короткий для примера, пропускаем") + return + + # Получаем эмбеддинг + embedding = await self.get_embedding(clean_text) + + # Вычисляем хеш для дедупликации + text_hash = VectorStore.compute_text_hash(clean_text) + + # Добавляем в хранилище + added = self.vector_store.add_negative(embedding, text_hash) + + if added: + logger.info(f"RAGService: Добавлен отрицательный пример") + + except Exception as e: + logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}") + + async def load_examples_from_db( + self, + positive_texts: list[str], + negative_texts: list[str], + batch_size: int = 16, + ) -> None: + """ + Загружает примеры из базы данных с батч-обработкой. + + Используется при запуске бота для восстановления VectorStore. + Батч-обработка ускоряет загрузку в 10-20 раз. + + Args: + positive_texts: Список текстов опубликованных постов + negative_texts: Список текстов отклоненных постов + batch_size: Размер батча для обработки (по умолчанию 16) + """ + if not self._enabled: + return + + logger.info( + f"RAGService: Загрузка примеров из БД с батч-обработкой " + f"(positive: {len(positive_texts)}, negative: {len(negative_texts)}, batch_size: {batch_size})" + ) + + # Убеждаемся что модель загружена + await self.load_model() + + import time + start_time = time.time() + + # Фильтруем и очищаем положительные тексты + if positive_texts: + clean_positive = [] + positive_hashes = [] + for text in positive_texts: + clean_text = self._clean_text(text) + if len(clean_text) >= self.min_text_length: + clean_positive.append(clean_text) + positive_hashes.append(VectorStore.compute_text_hash(clean_text)) + + if clean_positive: + logger.info(f"RAGService: Обработка {len(clean_positive)} положительных примеров батчами...") + positive_embeddings = await self.get_embeddings_batch(clean_positive, batch_size) + self.vector_store.add_positive_batch(positive_embeddings, positive_hashes) + + # Фильтруем и очищаем отрицательные тексты + if negative_texts: + clean_negative = [] + negative_hashes = [] + for text in negative_texts: + clean_text = self._clean_text(text) + if len(clean_text) >= self.min_text_length: + clean_negative.append(clean_text) + negative_hashes.append(VectorStore.compute_text_hash(clean_text)) + + if clean_negative: + logger.info(f"RAGService: Обработка {len(clean_negative)} отрицательных примеров батчами...") + negative_embeddings = await self.get_embeddings_batch(clean_negative, batch_size) + self.vector_store.add_negative_batch(negative_embeddings, negative_hashes) + + elapsed = time.time() - start_time + logger.info( + f"RAGService: Загрузка завершена за {elapsed:.1f} сек " + f"(positive: {self.vector_store.positive_count}, " + f"negative: {self.vector_store.negative_count})" + ) + + def save_vectors(self) -> None: + """Сохраняет векторы на диск.""" + if self.vector_store.storage_path: + self.vector_store.save_to_disk() + + def get_stats(self) -> dict: + """Возвращает статистику сервиса.""" + return { + "enabled": self._enabled, + "model_name": self.model_name, + "model_loaded": self._model_loaded, + "vector_store": self.vector_store.get_stats(), + } diff --git a/helper_bot/services/scoring/scoring_manager.py b/helper_bot/services/scoring/scoring_manager.py new file mode 100644 index 0000000..1a9b7b3 --- /dev/null +++ b/helper_bot/services/scoring/scoring_manager.py @@ -0,0 +1,242 @@ +""" +Менеджер для объединения всех сервисов скоринга. + +Координирует работу RAGService и DeepSeekService, +выполняет параллельные запросы и агрегирует результаты. +""" + +import asyncio +from typing import Optional, List + +from logs.custom_logger import logger +from helper_bot.utils.metrics import track_time, track_errors + +from .base import CombinedScore, ScoringResult +from .rag_service import RAGService +from .deepseek_service import DeepSeekService +from .vector_store import VectorStore +from .exceptions import ScoringError, InsufficientExamplesError, TextTooShortError + + +class ScoringManager: + """ + Менеджер для управления всеми сервисами скоринга. + + Объединяет RAGService и DeepSeekService, выполняет параллельные + запросы и агрегирует результаты в единый CombinedScore. + + Attributes: + rag_service: Сервис RAG с ruBERT + deepseek_service: Сервис DeepSeek API + """ + + def __init__( + self, + rag_service: Optional[RAGService] = None, + deepseek_service: Optional[DeepSeekService] = None, + ): + """ + Инициализация менеджера. + + Args: + rag_service: Сервис RAG (создается автоматически если не передан) + deepseek_service: Сервис DeepSeek (создается автоматически если не передан) + """ + self.rag_service = rag_service + self.deepseek_service = deepseek_service + + logger.info( + f"ScoringManager инициализирован " + f"(rag={rag_service is not None and rag_service.is_enabled}, " + f"deepseek={deepseek_service is not None and deepseek_service.is_enabled})" + ) + + @property + def is_any_enabled(self) -> bool: + """Проверяет, включен ли хотя бы один сервис.""" + rag_enabled = self.rag_service is not None and self.rag_service.is_enabled + deepseek_enabled = self.deepseek_service is not None and self.deepseek_service.is_enabled + return rag_enabled or deepseek_enabled + + @track_time("score_post", "scoring_manager") + @track_errors("scoring_manager", "score_post") + async def score_post(self, text: str) -> CombinedScore: + """ + Рассчитывает скоры для текста поста от всех сервисов. + + Выполняет запросы параллельно для минимизации задержки. + + Args: + text: Текст поста для оценки + + Returns: + CombinedScore с результатами от всех сервисов + """ + result = CombinedScore() + + if not text or not text.strip(): + logger.debug("ScoringManager: Пустой текст, пропускаем скоринг") + return result + + # Собираем задачи для параллельного выполнения + tasks = [] + task_names = [] + + # RAG сервис + if self.rag_service and self.rag_service.is_enabled: + tasks.append(self._get_rag_score(text)) + task_names.append("rag") + + # DeepSeek сервис + if self.deepseek_service and self.deepseek_service.is_enabled: + tasks.append(self._get_deepseek_score(text)) + task_names.append("deepseek") + + if not tasks: + logger.debug("ScoringManager: Нет активных сервисов для скоринга") + return result + + # Выполняем параллельно + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Обрабатываем результаты + for name, res in zip(task_names, results): + if isinstance(res, Exception): + error_msg = str(res) + result.errors[name] = error_msg + logger.warning(f"ScoringManager: Ошибка от {name}: {error_msg}") + elif res is not None: + if name == "rag": + result.rag = res + elif name == "deepseek": + result.deepseek = res + + logger.info( + f"ScoringManager: Скоринг завершен " + f"(rag={result.rag_score}, deepseek={result.deepseek_score})" + ) + + return result + + async def _get_rag_score(self, text: str) -> Optional[ScoringResult]: + """Получает скор от RAG сервиса.""" + try: + return await self.rag_service.calculate_score(text) + except InsufficientExamplesError: + # Недостаточно примеров - это не ошибка, просто нет данных + logger.info("ScoringManager: RAG - недостаточно примеров") + return None + except TextTooShortError: + # Текст слишком короткий - пропускаем + logger.debug("ScoringManager: RAG - текст слишком короткий") + return None + except Exception as e: + logger.error(f"ScoringManager: RAG ошибка: {e}") + raise + + async def _get_deepseek_score(self, text: str) -> Optional[ScoringResult]: + """Получает скор от DeepSeek сервиса.""" + try: + return await self.deepseek_service.calculate_score(text) + except TextTooShortError: + # Текст слишком короткий - пропускаем + logger.debug("ScoringManager: DeepSeek - текст слишком короткий") + return None + except Exception as e: + logger.error(f"ScoringManager: DeepSeek ошибка: {e}") + raise + + @track_time("on_post_published", "scoring_manager") + async def on_post_published(self, text: str) -> None: + """ + Вызывается при публикации поста. + + Добавляет текст как положительный пример для обучения RAG. + + Args: + text: Текст опубликованного поста + """ + if not text or not text.strip(): + return + + tasks = [] + + if self.rag_service and self.rag_service.is_enabled: + tasks.append(self.rag_service.add_positive_example(text)) + + if self.deepseek_service and self.deepseek_service.is_enabled: + tasks.append(self.deepseek_service.add_positive_example(text)) + + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + logger.info("ScoringManager: Добавлен положительный пример") + + @track_time("on_post_declined", "scoring_manager") + async def on_post_declined(self, text: str) -> None: + """ + Вызывается при отклонении поста. + + Добавляет текст как отрицательный пример для обучения RAG. + + Args: + text: Текст отклоненного поста + """ + if not text or not text.strip(): + return + + tasks = [] + + if self.rag_service and self.rag_service.is_enabled: + tasks.append(self.rag_service.add_negative_example(text)) + + if self.deepseek_service and self.deepseek_service.is_enabled: + tasks.append(self.deepseek_service.add_negative_example(text)) + + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + logger.info("ScoringManager: Добавлен отрицательный пример") + + async def load_examples_from_db( + self, + positive_texts: List[str], + negative_texts: List[str], + ) -> None: + """ + Загружает примеры из базы данных при запуске бота. + + Args: + positive_texts: Список текстов опубликованных постов + negative_texts: Список текстов отклоненных постов + """ + if self.rag_service and self.rag_service.is_enabled: + await self.rag_service.load_examples_from_db( + positive_texts, + negative_texts + ) + + def save_vectors(self) -> None: + """Сохраняет векторы RAG на диск.""" + if self.rag_service: + self.rag_service.save_vectors() + + async def close(self) -> None: + """Закрывает ресурсы всех сервисов.""" + if self.deepseek_service: + await self.deepseek_service.close() + + # Сохраняем векторы перед закрытием + self.save_vectors() + + def get_stats(self) -> dict: + """Возвращает статистику всех сервисов.""" + stats = { + "any_enabled": self.is_any_enabled, + } + + if self.rag_service: + stats["rag"] = self.rag_service.get_stats() + + if self.deepseek_service: + stats["deepseek"] = self.deepseek_service.get_stats() + + return stats diff --git a/helper_bot/services/scoring/vector_store.py b/helper_bot/services/scoring/vector_store.py new file mode 100644 index 0000000..a0381b3 --- /dev/null +++ b/helper_bot/services/scoring/vector_store.py @@ -0,0 +1,399 @@ +""" +In-memory хранилище векторов на numpy. + +Хранит векторные представления постов для быстрого сравнения. +Поддерживает персистентность через сохранение/загрузку с диска. +""" + +import hashlib +import os +from pathlib import Path +from typing import Optional, Tuple, List +import threading + +import numpy as np + +from logs.custom_logger import logger +from .exceptions import VectorStoreError, InsufficientExamplesError + + +class VectorStore: + """ + In-memory хранилище векторов для RAG. + + Хранит отдельно положительные (опубликованные) и отрицательные (отклоненные) + примеры. Использует косинусное сходство для расчета скора. + + Attributes: + vector_dim: Размерность векторов (768 для ruBERT) + max_examples: Максимальное количество примеров каждого типа + """ + + def __init__( + self, + vector_dim: int = 768, + max_examples: int = 10000, + storage_path: Optional[str] = None, + score_multiplier: float = 5.0, + ): + """ + Инициализация хранилища. + + Args: + vector_dim: Размерность векторов + max_examples: Максимальное количество примеров каждого типа + storage_path: Путь для сохранения/загрузки векторов (опционально) + score_multiplier: Множитель для усиления разницы в скорах + """ + self.vector_dim = vector_dim + self.max_examples = max_examples + self.storage_path = storage_path + self.score_multiplier = score_multiplier + + # Инициализируем пустые массивы + # Используем список для динамического добавления, потом конвертируем в numpy + self._positive_vectors: list = [] + self._negative_vectors: list = [] + self._positive_hashes: list = [] # Хеши текстов для дедупликации + self._negative_hashes: list = [] + + # Lock для потокобезопасности + self._lock = threading.Lock() + + # Пытаемся загрузить сохраненные векторы + if storage_path and os.path.exists(storage_path): + self._load_from_disk() + + @property + def positive_count(self) -> int: + """Количество положительных примеров.""" + return len(self._positive_vectors) + + @property + def negative_count(self) -> int: + """Количество отрицательных примеров.""" + return len(self._negative_vectors) + + @property + def total_count(self) -> int: + """Общее количество примеров.""" + return self.positive_count + self.negative_count + + @staticmethod + def compute_text_hash(text: str) -> str: + """Вычисляет хеш текста для дедупликации.""" + return hashlib.md5(text.encode('utf-8')).hexdigest() + + def _normalize_vector(self, vector: np.ndarray) -> np.ndarray: + """Нормализует вектор для косинусного сходства.""" + norm = np.linalg.norm(vector) + if norm == 0: + return vector + return vector / norm + + def add_positive(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool: + """ + Добавляет положительный пример (опубликованный пост). + + Args: + vector: Векторное представление текста + text_hash: Хеш текста для дедупликации (опционально) + + Returns: + True если добавлен, False если дубликат или превышен лимит + """ + with self._lock: + # Проверяем дубликат по хешу + if text_hash and text_hash in self._positive_hashes: + logger.debug(f"VectorStore: Пропуск дубликата положительного примера") + return False + + # Проверяем лимит + if len(self._positive_vectors) >= self.max_examples: + # Удаляем самый старый пример (FIFO) + self._positive_vectors.pop(0) + self._positive_hashes.pop(0) + logger.debug("VectorStore: Удален старый положительный пример (лимит)") + + # Нормализуем и добавляем + normalized = self._normalize_vector(vector) + self._positive_vectors.append(normalized) + if text_hash: + self._positive_hashes.append(text_hash) + + logger.info(f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})") + return True + + def add_positive_batch( + self, + vectors: List[np.ndarray], + text_hashes: Optional[List[str]] = None + ) -> int: + """ + Добавляет батч положительных примеров. + + Args: + vectors: Список векторов + text_hashes: Список хешей текстов для дедупликации + + Returns: + Количество добавленных примеров + """ + if text_hashes is None: + text_hashes = [None] * len(vectors) + + added = 0 + with self._lock: + for vector, text_hash in zip(vectors, text_hashes): + # Проверяем дубликат по хешу + if text_hash and text_hash in self._positive_hashes: + continue + + # Проверяем лимит + if len(self._positive_vectors) >= self.max_examples: + self._positive_vectors.pop(0) + self._positive_hashes.pop(0) + + # Нормализуем и добавляем + normalized = self._normalize_vector(vector) + self._positive_vectors.append(normalized) + if text_hash: + self._positive_hashes.append(text_hash) + added += 1 + + logger.info(f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})") + return added + + def add_negative(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool: + """ + Добавляет отрицательный пример (отклоненный пост). + + Args: + vector: Векторное представление текста + text_hash: Хеш текста для дедупликации (опционально) + + Returns: + True если добавлен, False если дубликат или превышен лимит + """ + with self._lock: + # Проверяем дубликат по хешу + if text_hash and text_hash in self._negative_hashes: + logger.debug(f"VectorStore: Пропуск дубликата отрицательного примера") + return False + + # Проверяем лимит + if len(self._negative_vectors) >= self.max_examples: + # Удаляем самый старый пример (FIFO) + self._negative_vectors.pop(0) + self._negative_hashes.pop(0) + logger.debug("VectorStore: Удален старый отрицательный пример (лимит)") + + # Нормализуем и добавляем + normalized = self._normalize_vector(vector) + self._negative_vectors.append(normalized) + if text_hash: + self._negative_hashes.append(text_hash) + + logger.info(f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})") + return True + + def add_negative_batch( + self, + vectors: List[np.ndarray], + text_hashes: Optional[List[str]] = None + ) -> int: + """ + Добавляет батч отрицательных примеров. + + Args: + vectors: Список векторов + text_hashes: Список хешей текстов для дедупликации + + Returns: + Количество добавленных примеров + """ + if text_hashes is None: + text_hashes = [None] * len(vectors) + + added = 0 + with self._lock: + for vector, text_hash in zip(vectors, text_hashes): + # Проверяем дубликат по хешу + if text_hash and text_hash in self._negative_hashes: + continue + + # Проверяем лимит + if len(self._negative_vectors) >= self.max_examples: + self._negative_vectors.pop(0) + self._negative_hashes.pop(0) + + # Нормализуем и добавляем + normalized = self._normalize_vector(vector) + self._negative_vectors.append(normalized) + if text_hash: + self._negative_hashes.append(text_hash) + added += 1 + + logger.info(f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})") + return added + + def calculate_similarity_score(self, vector: np.ndarray) -> Tuple[float, float]: + """ + Рассчитывает скор на основе сходства с примерами. + + Алгоритм: + 1. Вычисляем среднее косинусное сходство с положительными примерами + 2. Вычисляем среднее косинусное сходство с отрицательными примерами + 3. Финальный скор = pos_sim / (pos_sim + neg_sim + eps) + + Args: + vector: Векторное представление нового поста + + Returns: + Tuple (score, confidence): + - score: Оценка от 0.0 до 1.0 + - confidence: Уверенность (зависит от количества примеров) + + Raises: + InsufficientExamplesError: Если недостаточно примеров + """ + with self._lock: + if self.positive_count == 0: + raise InsufficientExamplesError( + "Нет положительных примеров для сравнения" + ) + + # Нормализуем входной вектор + normalized = self._normalize_vector(vector) + + # Конвертируем в numpy массивы для быстрых вычислений + pos_matrix = np.array(self._positive_vectors) + + # Косинусное сходство с положительными примерами + # Для нормализованных векторов это просто скалярное произведение + pos_similarities = np.dot(pos_matrix, normalized) + pos_sim = float(np.mean(pos_similarities)) + + # Косинусное сходство с отрицательными примерами + if self.negative_count > 0: + neg_matrix = np.array(self._negative_vectors) + neg_similarities = np.dot(neg_matrix, normalized) + neg_sim = float(np.mean(neg_similarities)) + else: + # Если нет отрицательных примеров, используем нейтральное значение + neg_sim = pos_sim # Нейтральный скор = 0.5 + + # === Вариант 1: neg/pos (разница между положительными и отрицательными) === + diff = pos_sim - neg_sim + score_neg_pos = 0.5 + (diff * self.score_multiplier) + score_neg_pos = max(0.0, min(1.0, score_neg_pos)) + + # === Вариант 2: pos only (только положительные, топ-k ближайших) === + # Берём топ-5 ближайших положительных примеров + top_k = min(5, len(pos_similarities)) + top_k_sim = float(np.mean(np.sort(pos_similarities)[-top_k:])) + # Нормализуем: 0.85 -> 0.0, 0.95 -> 1.0 (типичный диапазон для BERT) + score_pos_only = (top_k_sim - 0.85) / 0.10 + score_pos_only = max(0.0, min(1.0, score_pos_only)) + + # Основной скор — neg/pos (можно будет переключить позже) + score = score_neg_pos + + # Confidence зависит от количества примеров (100% при 1000 примерах) + total_examples = self.positive_count + self.negative_count + confidence = min(1.0, total_examples / 1000) + + logger.info( + f"VectorStore: pos_sim={pos_sim:.4f}, neg_sim={neg_sim:.4f}, " + f"top_k_sim={top_k_sim:.4f}, score_neg_pos={score_neg_pos:.4f}, " + f"score_pos_only={score_pos_only:.4f}" + ) + + return score, confidence, score_pos_only + + def save_to_disk(self, path: Optional[str] = None) -> None: + """ + Сохраняет векторы на диск. + + Args: + path: Путь для сохранения (если не указан, используется storage_path) + """ + save_path = path or self.storage_path + if not save_path: + raise VectorStoreError("Путь для сохранения не указан") + + with self._lock: + # Создаем директорию если нужно + Path(save_path).parent.mkdir(parents=True, exist_ok=True) + + # Сохраняем в npz формате + np.savez_compressed( + save_path, + positive_vectors=np.array(self._positive_vectors) if self._positive_vectors else np.array([]), + negative_vectors=np.array(self._negative_vectors) if self._negative_vectors else np.array([]), + positive_hashes=np.array(self._positive_hashes, dtype=object), + negative_hashes=np.array(self._negative_hashes, dtype=object), + vector_dim=self.vector_dim, + max_examples=self.max_examples, + ) + + logger.info( + f"VectorStore: Сохранено на диск ({self.positive_count} pos, " + f"{self.negative_count} neg): {save_path}" + ) + + def _load_from_disk(self) -> None: + """Загружает векторы с диска.""" + if not self.storage_path or not os.path.exists(self.storage_path): + return + + try: + with self._lock: + data = np.load(self.storage_path, allow_pickle=True) + + # Загружаем векторы + pos_vectors = data.get('positive_vectors', np.array([])) + neg_vectors = data.get('negative_vectors', np.array([])) + + if pos_vectors.size > 0: + self._positive_vectors = list(pos_vectors) + if neg_vectors.size > 0: + self._negative_vectors = list(neg_vectors) + + # Загружаем хеши + pos_hashes = data.get('positive_hashes', np.array([])) + neg_hashes = data.get('negative_hashes', np.array([])) + + if pos_hashes.size > 0: + self._positive_hashes = list(pos_hashes) + if neg_hashes.size > 0: + self._negative_hashes = list(neg_hashes) + + logger.info( + f"VectorStore: Загружено с диска ({self.positive_count} pos, " + f"{self.negative_count} neg): {self.storage_path}" + ) + + except Exception as e: + logger.error(f"VectorStore: Ошибка загрузки с диска: {e}") + # Продолжаем с пустым хранилищем + + def clear(self) -> None: + """Очищает все векторы.""" + with self._lock: + self._positive_vectors.clear() + self._negative_vectors.clear() + self._positive_hashes.clear() + self._negative_hashes.clear() + logger.info("VectorStore: Хранилище очищено") + + def get_stats(self) -> dict: + """Возвращает статистику хранилища.""" + return { + "positive_count": self.positive_count, + "negative_count": self.negative_count, + "total_count": self.total_count, + "vector_dim": self.vector_dim, + "max_examples": self.max_examples, + "storage_path": self.storage_path, + } diff --git a/helper_bot/utils/base_dependency_factory.py b/helper_bot/utils/base_dependency_factory.py index fb2681b..82a0660 100644 --- a/helper_bot/utils/base_dependency_factory.py +++ b/helper_bot/utils/base_dependency_factory.py @@ -5,6 +5,7 @@ from typing import Optional from database.async_db import AsyncBotDB from dotenv import load_dotenv from helper_bot.utils.s3_storage import S3StorageService +from logs.custom_logger import logger class BaseDependencyFactory: @@ -15,6 +16,7 @@ class BaseDependencyFactory: load_dotenv(env_path) self.settings = {} + self._project_dir = project_dir database_path = os.getenv('DATABASE_PATH', 'database/tg-bot-database.db') if not os.path.isabs(database_path): @@ -24,6 +26,9 @@ class BaseDependencyFactory: self._load_settings_from_env() self._init_s3_storage() + + # ScoringManager инициализируется лениво + self._scoring_manager = None def _load_settings_from_env(self): """Загружает настройки из переменных окружения.""" @@ -59,6 +64,23 @@ class BaseDependencyFactory: 'bucket_name': os.getenv('S3_BUCKET_NAME', ''), 'region': os.getenv('S3_REGION', 'us-east-1') } + + # Настройки ML-скоринга + self.settings['Scoring'] = { + # RAG (ruBERT) + 'rag_enabled': self._parse_bool(os.getenv('RAG_ENABLED', 'false')), + 'rag_model': os.getenv('RAG_MODEL', 'DeepPavlov/rubert-base-cased'), + 'rag_cache_dir': os.getenv('RAG_CACHE_DIR', 'data/models'), + 'rag_vectors_path': os.getenv('RAG_VECTORS_PATH', 'data/vectors.npz'), + 'rag_max_examples': self._parse_int(os.getenv('RAG_MAX_EXAMPLES', '10000')), + 'rag_score_multiplier': self._parse_float(os.getenv('RAG_SCORE_MULTIPLIER', '5.0')), + # DeepSeek + 'deepseek_enabled': self._parse_bool(os.getenv('DEEPSEEK_ENABLED', 'false')), + 'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY', ''), + 'deepseek_api_url': os.getenv('DEEPSEEK_API_URL', 'https://api.deepseek.com/v1/chat/completions'), + 'deepseek_model': os.getenv('DEEPSEEK_MODEL', 'deepseek-chat'), + 'deepseek_timeout': self._parse_int(os.getenv('DEEPSEEK_TIMEOUT', '30')), + } def _init_s3_storage(self): """Инициализирует S3StorageService если S3 включен.""" @@ -84,6 +106,13 @@ class BaseDependencyFactory: return int(value) except (ValueError, TypeError): return 0 + + def _parse_float(self, value: str) -> float: + """Парсит строковое значение в float.""" + try: + return float(value) + except (ValueError, TypeError): + return 0.0 def get_settings(self): return self.settings @@ -95,6 +124,100 @@ class BaseDependencyFactory: def get_s3_storage(self) -> Optional[S3StorageService]: """Возвращает S3StorageService если S3 включен, иначе None.""" return self.s3_storage + + def _init_scoring_manager(self): + """ + Инициализирует ScoringManager с RAG и DeepSeek сервисами. + + Вызывается лениво при первом обращении к get_scoring_manager(). + """ + from helper_bot.services.scoring import ( + ScoringManager, + RAGService, + DeepSeekService, + VectorStore, + ) + + scoring_config = self.settings['Scoring'] + + # Инициализация RAG сервиса + rag_service = None + if scoring_config['rag_enabled']: + # Путь к векторам + vectors_path = scoring_config['rag_vectors_path'] + if not os.path.isabs(vectors_path): + vectors_path = os.path.join(self._project_dir, vectors_path) + + # Путь к кешу моделей + cache_dir = scoring_config['rag_cache_dir'] + if not os.path.isabs(cache_dir): + cache_dir = os.path.join(self._project_dir, cache_dir) + + # Создаем директории если нужно + os.makedirs(os.path.dirname(vectors_path), exist_ok=True) + os.makedirs(cache_dir, exist_ok=True) + + # Создаем VectorStore + vector_store = VectorStore( + vector_dim=768, # ruBERT dimension + max_examples=scoring_config['rag_max_examples'], + storage_path=vectors_path, + score_multiplier=scoring_config['rag_score_multiplier'], + ) + + # Создаем RAGService + rag_service = RAGService( + model_name=scoring_config['rag_model'], + vector_store=vector_store, + cache_dir=cache_dir, + enabled=True, + ) + + logger.info(f"RAGService инициализирован: {scoring_config['rag_model']}") + + # Инициализация DeepSeek сервиса + deepseek_service = None + if scoring_config['deepseek_enabled'] and scoring_config['deepseek_api_key']: + deepseek_service = DeepSeekService( + api_key=scoring_config['deepseek_api_key'], + api_url=scoring_config['deepseek_api_url'], + model=scoring_config['deepseek_model'], + timeout=scoring_config['deepseek_timeout'], + enabled=True, + ) + logger.info(f"DeepSeekService инициализирован: {scoring_config['deepseek_model']}") + + # Создаем менеджер + self._scoring_manager = ScoringManager( + rag_service=rag_service, + deepseek_service=deepseek_service, + ) + + return self._scoring_manager + + def get_scoring_manager(self): + """ + Возвращает ScoringManager для ML-скоринга постов. + + Инициализируется лениво при первом вызове. + + Returns: + ScoringManager или None если скоринг полностью отключен + """ + if self._scoring_manager is None: + scoring_config = self.settings.get('Scoring', {}) + + # Проверяем, включен ли хотя бы один сервис + rag_enabled = scoring_config.get('rag_enabled', False) + deepseek_enabled = scoring_config.get('deepseek_enabled', False) + + if not rag_enabled and not deepseek_enabled: + logger.info("Scoring полностью отключен (RAG и DeepSeek disabled)") + return None + + self._init_scoring_manager() + + return self._scoring_manager _global_instance = None diff --git a/helper_bot/utils/helper_func.py b/helper_bot/utils/helper_func.py index 4412cef..2350a47 100644 --- a/helper_bot/utils/helper_func.py +++ b/helper_bot/utils/helper_func.py @@ -111,7 +111,16 @@ def determine_anonymity(post_text: str) -> bool: return False -def get_text_message(post_text: str, first_name: str, username: str = None, is_anonymous: Optional[bool] = None): +def get_text_message( + post_text: str, + first_name: str, + username: str = None, + is_anonymous: Optional[bool] = None, + deepseek_score: Optional[float] = None, + rag_score: Optional[float] = None, + rag_confidence: Optional[float] = None, + rag_score_pos_only: Optional[float] = None, +): """ Форматирует текст сообщения для публикации в зависимости от наличия ключевых слов "анон" и "неанон" или переданного параметра is_anonymous. @@ -121,6 +130,10 @@ def get_text_message(post_text: str, first_name: str, username: str = None, is_a first_name: Имя автора поста username: Юзернейм автора поста (может быть None) is_anonymous: Флаг анонимности (True - анонимно, False - не анонимно, None - legacy, определяется по тексту) + deepseek_score: Скор от DeepSeek API (0.0-1.0, опционально) + rag_score: Скор от RAG/ruBERT neg/pos (0.0-1.0, опционально) + rag_confidence: Уверенность RAG модели (0.0-1.0, зависит от количества примеров) + rag_score_pos_only: Скор RAG только по положительным примерам (0.0-1.0, опционально) Returns: str: - Сформированный текст сообщения. @@ -137,21 +150,37 @@ def get_text_message(post_text: str, first_name: str, username: str = None, is_a else: author_info = f"{first_name} (Ник не указан)" + # Формируем базовый текст # Если передан is_anonymous, используем его, иначе определяем по тексту (legacy) - # TODO: Уверен можно укоротить if is_anonymous is not None: if is_anonymous: - return f'{safe_post_text}\n\nПост опубликован анонимно' + final_text = f'{safe_post_text}\n\nПост опубликован анонимно' else: - return f'{safe_post_text}\n\nАвтор поста: {author_info}' + final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}' else: # Legacy: определяем по тексту if "неанон" in post_text or "не анон" in post_text: - return f'{safe_post_text}\n\nАвтор поста: {author_info}' + final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}' elif "анон" in post_text: - return f'{safe_post_text}\n\nПост опубликован анонимно' + final_text = f'{safe_post_text}\n\nПост опубликован анонимно' else: - return f'{safe_post_text}\n\nАвтор поста: {author_info}' + final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}' + + # Добавляем блок со скорами если есть + if deepseek_score is not None or rag_score is not None or rag_score_pos_only is not None: + scores_lines = ["\n📊 Уверенность в одобрении:"] + if deepseek_score is not None: + scores_lines.append(f"DeepSeek: {deepseek_score:.2f}") + if rag_score is not None: + rag_line = f"RAG neg/pos: {rag_score:.2f}" + if rag_confidence is not None: + rag_line += f" (уверенность: {rag_confidence:.0%})" + scores_lines.append(rag_line) + if rag_score_pos_only is not None: + scores_lines.append(f"RAG pos only: {rag_score_pos_only:.2f}") + final_text += "\n" + "\n".join(scores_lines) + + return final_text @track_time("download_file", "helper_func") @track_errors("helper_func", "download_file") diff --git a/requirements.txt b/requirements.txt index 4efef4f..968c3f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,10 @@ typing_extensions~=4.12.2 emoji~=2.8.0 # S3 Storage (для хранения медиафайлов опубликованных постов) -aioboto3>=12.0.0 \ No newline at end of file +aioboto3>=12.0.0 + +# ML Scoring (для оценки вероятности публикации постов) +numpy>=1.24.0 +transformers>=4.30.0 +torch>=2.0.0 +httpx>=0.24.0 \ No newline at end of file diff --git a/scripts/add_ml_scores_columns.py b/scripts/add_ml_scores_columns.py new file mode 100644 index 0000000..a7c23ff --- /dev/null +++ b/scripts/add_ml_scores_columns.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +""" +Миграция: Добавление колонок для ML-скоринга постов. + +Добавляет: +- ml_scores (TEXT/JSON) - JSON с результатами оценки от разных моделей +- vector_hash (TEXT) - хеш текста для кеширования векторов + +Структура ml_scores: +{ + "deepseek": {"score": 0.75, "model": "deepseek-chat", "ts": 1706198400}, + "rag": {"score": 0.90, "model": "rubert-base-cased", "ts": 1706198400} +} +""" +import argparse +import asyncio +import os +import sys +from pathlib import Path + +# Добавляем корень проекта в путь +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) + +import aiosqlite + +# Пытаемся импортировать logger, если не получается - используем стандартный +try: + from logs.custom_logger import logger +except ImportError: + import logging + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + logger = logging.getLogger(__name__) + +DEFAULT_DB_PATH = "database/tg-bot-database.db" + + +async def column_exists(conn: aiosqlite.Connection, table: str, column: str) -> bool: + """Проверяет существование колонки в таблице.""" + cursor = await conn.execute(f"PRAGMA table_info({table})") + columns = await cursor.fetchall() + return any(col[1] == column for col in columns) + + +async def main(db_path: str) -> None: + """ + Основная функция миграции. + + Добавляет колонки ml_scores и vector_hash в таблицу post_from_telegram_suggest. + Миграция идемпотентна - можно запускать повторно без ошибок. + """ + db_path = os.path.abspath(db_path) + + if not os.path.exists(db_path): + logger.error(f"База данных не найдена: {db_path}") + return + + async with aiosqlite.connect(db_path) as conn: + await conn.execute("PRAGMA foreign_keys = ON") + + # Проверяем и добавляем колонку ml_scores + if not await column_exists(conn, "post_from_telegram_suggest", "ml_scores"): + await conn.execute( + "ALTER TABLE post_from_telegram_suggest ADD COLUMN ml_scores TEXT" + ) + logger.info("Колонка ml_scores добавлена в post_from_telegram_suggest") + else: + logger.info("Колонка ml_scores уже существует") + + # Проверяем и добавляем колонку vector_hash + if not await column_exists(conn, "post_from_telegram_suggest", "vector_hash"): + await conn.execute( + "ALTER TABLE post_from_telegram_suggest ADD COLUMN vector_hash TEXT" + ) + logger.info("Колонка vector_hash добавлена в post_from_telegram_suggest") + else: + logger.info("Колонка vector_hash уже существует") + + await conn.commit() + logger.info("Миграция add_ml_scores_columns завершена успешно") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Добавление колонок ml_scores и vector_hash для ML-скоринга" + ) + parser.add_argument( + "--db", + default=os.environ.get("DATABASE_PATH", DEFAULT_DB_PATH), + help="Путь к БД", + ) + args = parser.parse_args() + asyncio.run(main(args.db)) diff --git a/tests/test_scoring_services.py b/tests/test_scoring_services.py new file mode 100644 index 0000000..048796b --- /dev/null +++ b/tests/test_scoring_services.py @@ -0,0 +1,390 @@ +""" +Тесты для сервисов ML-скоринга постов. +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +# Импорты для тестирования базовых классов +from helper_bot.services.scoring.base import ScoringResult, CombinedScore +from helper_bot.services.scoring.exceptions import ( + ScoringError, + InsufficientExamplesError, + TextTooShortError, +) + + +class TestScoringResult: + """Тесты для ScoringResult.""" + + def test_create_valid_score(self): + """Тест создания валидного результата.""" + result = ScoringResult( + score=0.75, + source="rag", + model="test-model", + ) + assert result.score == 0.75 + assert result.source == "rag" + assert result.model == "test-model" + + def test_score_validation_lower_bound(self): + """Тест валидации нижней границы скора.""" + with pytest.raises(ValueError): + ScoringResult(score=-0.1, source="test", model="test") + + def test_score_validation_upper_bound(self): + """Тест валидации верхней границы скора.""" + with pytest.raises(ValueError): + ScoringResult(score=1.1, source="test", model="test") + + def test_to_dict(self): + """Тест преобразования в словарь.""" + result = ScoringResult( + score=0.7534, + source="rag", + model="test-model", + confidence=0.85, + timestamp=1234567890, + ) + d = result.to_dict() + + assert d["score"] == 0.7534 # Округлено до 4 знаков + assert d["model"] == "test-model" + assert d["ts"] == 1234567890 + assert d["confidence"] == 0.85 + + def test_from_dict(self): + """Тест создания из словаря.""" + data = { + "score": 0.75, + "model": "test-model", + "ts": 1234567890, + "confidence": 0.9, + } + result = ScoringResult.from_dict("rag", data) + + assert result.score == 0.75 + assert result.source == "rag" + assert result.model == "test-model" + assert result.timestamp == 1234567890 + assert result.confidence == 0.9 + + +class TestCombinedScore: + """Тесты для CombinedScore.""" + + def test_empty_combined_score(self): + """Тест пустого объединенного скора.""" + score = CombinedScore() + + assert score.deepseek is None + assert score.rag is None + assert score.deepseek_score is None + assert score.rag_score is None + assert not score.has_any_score() + + def test_combined_score_with_rag(self): + """Тест объединенного скора с RAG.""" + rag_result = ScoringResult(score=0.8, source="rag", model="rubert") + score = CombinedScore(rag=rag_result) + + assert score.rag_score == 0.8 + assert score.deepseek_score is None + assert score.has_any_score() + + def test_combined_score_with_both(self): + """Тест объединенного скора с обоими сервисами.""" + rag_result = ScoringResult(score=0.8, source="rag", model="rubert") + deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat") + score = CombinedScore(rag=rag_result, deepseek=deepseek_result) + + assert score.rag_score == 0.8 + assert score.deepseek_score == 0.7 + assert score.has_any_score() + + def test_to_json_dict(self): + """Тест преобразования в JSON словарь.""" + rag_result = ScoringResult(score=0.8, source="rag", model="rubert", timestamp=123) + deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat", timestamp=456) + score = CombinedScore(rag=rag_result, deepseek=deepseek_result) + + d = score.to_json_dict() + + assert "rag" in d + assert "deepseek" in d + assert d["rag"]["score"] == 0.8 + assert d["deepseek"]["score"] == 0.7 + + # Проверяем что это валидный JSON + json_str = json.dumps(d) + assert json_str + + +class TestVectorStore: + """Тесты для VectorStore (требует numpy).""" + + @pytest.fixture + def vector_store(self): + """Создает VectorStore для тестов.""" + try: + import numpy as np + from helper_bot.services.scoring.vector_store import VectorStore + return VectorStore(vector_dim=768, max_examples=100) + except ImportError: + pytest.skip("numpy не установлен") + + def test_add_positive_example(self, vector_store): + """Тест добавления положительного примера.""" + import numpy as np + + vector = np.random.randn(768).astype(np.float32) + result = vector_store.add_positive(vector, "hash1") + + assert result is True + assert vector_store.positive_count == 1 + + def test_add_duplicate_example(self, vector_store): + """Тест добавления дубликата.""" + import numpy as np + + vector = np.random.randn(768).astype(np.float32) + vector_store.add_positive(vector, "hash1") + result = vector_store.add_positive(vector, "hash1") # Дубликат + + assert result is False + assert vector_store.positive_count == 1 + + def test_max_examples_limit(self, vector_store): + """Тест ограничения максимального количества примеров.""" + import numpy as np + + # Добавляем больше чем max_examples + for i in range(150): + vector = np.random.randn(768).astype(np.float32) + vector_store.add_positive(vector, f"hash_{i}") + + assert vector_store.positive_count == 100 # max_examples + + def test_calculate_similarity_no_examples(self, vector_store): + """Тест расчета скора без примеров.""" + import numpy as np + + vector = np.random.randn(768).astype(np.float32) + + with pytest.raises(InsufficientExamplesError): + vector_store.calculate_similarity_score(vector) + + def test_calculate_similarity_with_examples(self, vector_store): + """Тест расчета скора с примерами.""" + import numpy as np + + # Добавляем положительные примеры + for i in range(10): + vector = np.random.randn(768).astype(np.float32) + vector_store.add_positive(vector, f"pos_{i}") + + # Добавляем отрицательные примеры + for i in range(10): + vector = np.random.randn(768).astype(np.float32) + vector_store.add_negative(vector, f"neg_{i}") + + # Рассчитываем скор для нового вектора + test_vector = np.random.randn(768).astype(np.float32) + score, confidence = vector_store.calculate_similarity_score(test_vector) + + assert 0.0 <= score <= 1.0 + assert 0.0 <= confidence <= 1.0 + + def test_compute_text_hash(self, vector_store): + """Тест вычисления хеша текста.""" + from helper_bot.services.scoring.vector_store import VectorStore + + hash1 = VectorStore.compute_text_hash("Привет мир") + hash2 = VectorStore.compute_text_hash("Привет мир") + hash3 = VectorStore.compute_text_hash("Другой текст") + + assert hash1 == hash2 + assert hash1 != hash3 + + +class TestDeepSeekService: + """Тесты для DeepSeekService.""" + + @pytest.fixture + def deepseek_service(self): + """Создает DeepSeekService для тестов.""" + from helper_bot.services.scoring.deepseek_service import DeepSeekService + return DeepSeekService( + api_key="test_key", + enabled=True, + timeout=5, + ) + + def test_service_disabled_without_key(self): + """Тест отключения сервиса без API ключа.""" + from helper_bot.services.scoring.deepseek_service import DeepSeekService + service = DeepSeekService(api_key=None, enabled=True) + + assert service.is_enabled is False + + def test_parse_score_response_valid(self, deepseek_service): + """Тест парсинга валидного ответа.""" + assert deepseek_service._parse_score_response("0.75") == 0.75 + assert deepseek_service._parse_score_response("0.5") == 0.5 + assert deepseek_service._parse_score_response("1.0") == 1.0 + assert deepseek_service._parse_score_response("0") == 0.0 + + def test_parse_score_response_with_quotes(self, deepseek_service): + """Тест парсинга ответа с кавычками.""" + assert deepseek_service._parse_score_response('"0.75"') == 0.75 + assert deepseek_service._parse_score_response("'0.8'") == 0.8 + + def test_parse_score_response_with_text(self, deepseek_service): + """Тест парсинга ответа с текстом.""" + # Сервис должен найти число в тексте + assert deepseek_service._parse_score_response("Score: 0.75") == 0.75 + + def test_clean_text(self, deepseek_service): + """Тест очистки текста.""" + assert deepseek_service._clean_text(" hello world ") == "hello world" + assert deepseek_service._clean_text("^") == "" + assert deepseek_service._clean_text("") == "" + + @pytest.mark.asyncio + async def test_calculate_score_disabled(self): + """Тест расчета скора при отключенном сервисе.""" + from helper_bot.services.scoring.deepseek_service import DeepSeekService + service = DeepSeekService(api_key=None, enabled=False) + + with pytest.raises(ScoringError): + await service.calculate_score("Test text") + + @pytest.mark.asyncio + async def test_calculate_score_short_text(self, deepseek_service): + """Тест расчета скора для короткого текста.""" + with pytest.raises(TextTooShortError): + await deepseek_service.calculate_score("ab") + + +class TestScoringManager: + """Тесты для ScoringManager.""" + + @pytest.fixture + def mock_rag_service(self): + """Создает мок RAG сервиса.""" + mock = AsyncMock() + mock.is_enabled = True + mock.calculate_score = AsyncMock(return_value=ScoringResult( + score=0.8, + source="rag", + model="rubert", + )) + return mock + + @pytest.fixture + def mock_deepseek_service(self): + """Создает мок DeepSeek сервиса.""" + mock = AsyncMock() + mock.is_enabled = True + mock.calculate_score = AsyncMock(return_value=ScoringResult( + score=0.7, + source="deepseek", + model="deepseek-chat", + )) + return mock + + @pytest.mark.asyncio + async def test_score_post_both_services(self, mock_rag_service, mock_deepseek_service): + """Тест скоринга с обоими сервисами.""" + from helper_bot.services.scoring.scoring_manager import ScoringManager + + manager = ScoringManager( + rag_service=mock_rag_service, + deepseek_service=mock_deepseek_service, + ) + + result = await manager.score_post("Тестовый пост") + + assert result.rag_score == 0.8 + assert result.deepseek_score == 0.7 + assert result.has_any_score() + + @pytest.mark.asyncio + async def test_score_post_rag_only(self, mock_rag_service): + """Тест скоринга только с RAG.""" + from helper_bot.services.scoring.scoring_manager import ScoringManager + + manager = ScoringManager( + rag_service=mock_rag_service, + deepseek_service=None, + ) + + result = await manager.score_post("Тестовый пост") + + assert result.rag_score == 0.8 + assert result.deepseek_score is None + + @pytest.mark.asyncio + async def test_score_post_empty_text(self, mock_rag_service): + """Тест скоринга пустого текста.""" + from helper_bot.services.scoring.scoring_manager import ScoringManager + + manager = ScoringManager(rag_service=mock_rag_service) + + result = await manager.score_post("") + + assert not result.has_any_score() + mock_rag_service.calculate_score.assert_not_called() + + @pytest.mark.asyncio + async def test_score_post_service_error(self, mock_rag_service, mock_deepseek_service): + """Тест обработки ошибки сервиса.""" + from helper_bot.services.scoring.scoring_manager import ScoringManager + + # RAG выбрасывает ошибку + mock_rag_service.calculate_score = AsyncMock(side_effect=Exception("Test error")) + + manager = ScoringManager( + rag_service=mock_rag_service, + deepseek_service=mock_deepseek_service, + ) + + result = await manager.score_post("Тестовый пост") + + # DeepSeek должен вернуть результат + assert result.deepseek_score == 0.7 + # RAG должен быть None с ошибкой + assert result.rag_score is None + assert "rag" in result.errors + + @pytest.mark.asyncio + async def test_on_post_published(self, mock_rag_service, mock_deepseek_service): + """Тест обучения на опубликованном посте.""" + from helper_bot.services.scoring.scoring_manager import ScoringManager + + manager = ScoringManager( + rag_service=mock_rag_service, + deepseek_service=mock_deepseek_service, + ) + + await manager.on_post_published("Опубликованный пост") + + mock_rag_service.add_positive_example.assert_called_once_with("Опубликованный пост") + mock_deepseek_service.add_positive_example.assert_called_once_with("Опубликованный пост") + + @pytest.mark.asyncio + async def test_on_post_declined(self, mock_rag_service, mock_deepseek_service): + """Тест обучения на отклоненном посте.""" + from helper_bot.services.scoring.scoring_manager import ScoringManager + + manager = ScoringManager( + rag_service=mock_rag_service, + deepseek_service=mock_deepseek_service, + ) + + await manager.on_post_declined("Отклоненный пост") + + mock_rag_service.add_negative_example.assert_called_once_with("Отклоненный пост") + mock_deepseek_service.add_negative_example.assert_called_once_with("Отклоненный пост")