feat: интеграция ML-скоринга с использованием RAG и DeepSeek

- Обновлен Dockerfile для установки необходимых зависимостей.
- Добавлены новые переменные окружения для настройки ML-скоринга в env.example.
- Реализованы методы для получения и обновления ML-скоров в AsyncBotDB и PostRepository.
- Обновлены обработчики публикации постов для интеграции ML-скоринга.
- Добавлен новый обработчик для получения статистики ML-скоринга в админ-панели.
- Обновлены функции для форматирования сообщений с учетом ML-скоров.
This commit is contained in:
2026-01-26 18:40:38 +03:00
parent e2b1353408
commit 7f6f0f028c
25 changed files with 2833 additions and 52 deletions

View File

@@ -165,9 +165,8 @@ jobs:
📦 Repository: telegram-helper-bot 📦 Repository: telegram-helper-bot
🌿 Branch: main 🌿 Branch: main
📝 Commit: ${{ github.event.pull_request.merge_commit_sha || github.sha }} 📝 Commit: ${{ github.sha }}
👤 Author: ${{ github.event.pull_request.user.login || github.actor }} 👤 Author: ${{ github.actor }}
${{ github.event.pull_request.number && format('🔀 PR: #{0}', github.event.pull_request.number) || '' }}
${{ job.status == 'success' && '✅ Deployment successful! Container restarted with migrations applied.' || '❌ Deployment failed! Check logs for details.' }} ${{ job.status == 'success' && '✅ Deployment successful! Container restarted with migrations applied.' || '❌ Deployment failed! Check logs for details.' }}

View File

@@ -1,15 +1,14 @@
########################################### ###########################################
# Этап 1: Сборщик (Builder) # Этап 1: Сборщик (Builder)
########################################### ###########################################
FROM python:3.11.9-alpine as builder FROM python:3.11.9-slim as builder
# Устанавливаем инструменты для компиляции + linux-headers для psutil # Устанавливаем инструменты для компиляции
RUN apk add --no-cache \ RUN apt-get update && apt-get install --no-install-recommends -y \
gcc \ gcc \
g++ \ g++ \
musl-dev \
python3-dev \ python3-dev \
linux-headers # ← ЭТО КРИТИЧЕСКИ ВАЖНО ДЛЯ psutil && rm -rf /var/lib/apt/lists/*
WORKDIR /app WORKDIR /app
COPY requirements.txt . COPY requirements.txt .
@@ -21,29 +20,34 @@ RUN pip install --no-cache-dir --target /install -r requirements.txt
########################################### ###########################################
# Этап 2: Финальный образ (Runtime) # Этап 2: Финальный образ (Runtime)
########################################### ###########################################
FROM python:3.11.9-alpine as runtime FROM python:3.11.9-slim as runtime
# Минимальные рантайм-зависимости # Минимальные рантайм-зависимости
RUN apk add --no-cache \ RUN apt-get update && apt-get install --no-install-recommends -y \
libstdc++ \ libgomp1 \
sqlite-libs && rm -rf /var/lib/apt/lists/*
# Создаем пользователя # Создаем пользователя
RUN addgroup -g 1001 deploy && adduser -D -u 1001 -G deploy deploy RUN groupadd -g 1001 deploy && useradd -r -u 1001 -g deploy deploy
WORKDIR /app WORKDIR /app
# Копируем зависимости # Копируем зависимости
COPY --from=builder --chown=1001:1001 /install /usr/local/lib/python3.11/site-packages COPY --from=builder --chown=deploy:deploy /install /usr/local/lib/python3.11/site-packages
# Создаем структуру папок # Создаем структуру папок (включая директории для ML моделей)
RUN mkdir -p database logs voice_users && \ RUN mkdir -p database logs voice_users data/models && \
chown -R 1001:1001 /app chown -R deploy:deploy /app
# Устанавливаем переменные для HuggingFace (кеш моделей внутри /app)
ENV HF_HOME=/app/data/models
ENV TRANSFORMERS_CACHE=/app/data/models
ENV HF_HUB_CACHE=/app/data/models
# Копируем исходный код # Копируем исходный код
COPY --chown=1001:1001 . . COPY --chown=deploy:deploy . .
USER 1001 USER deploy
# Healthcheck # Healthcheck
HEALTHCHECK --interval=30s --timeout=15s --start-period=10s --retries=5 \ HEALTHCHECK --interval=30s --timeout=15s --start-period=10s --retries=5 \

View File

@@ -210,6 +210,23 @@ class AsyncBotDB:
return await self.factory.posts.update_status_for_media_group_by_helper_id( return await self.factory.posts.update_status_for_media_group_by_helper_id(
helper_message_id, status helper_message_id, status
) )
# Методы для ML Scoring
async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]:
"""Получает текст поста по message_id."""
return await self.factory.posts.get_post_text_by_message_id(message_id)
async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool:
"""Обновляет ML-скоры для поста."""
return await self.factory.posts.update_ml_scores(message_id, ml_scores_json)
async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]:
"""Получает тексты одобренных постов для обучения RAG."""
return await self.factory.posts.get_approved_posts_texts(limit)
async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]:
"""Получает тексты отклоненных постов для обучения RAG."""
return await self.factory.posts.get_declined_posts_texts(limit)
# Методы для работы с черным списком # Методы для работы с черным списком
async def set_user_blacklist( async def set_user_blacklist(

View File

@@ -357,3 +357,126 @@ class PostRepository(DatabaseConnection):
post_content = await self._execute_query_with_result(query, (published_message_id,)) post_content = await self._execute_query_with_result(query, (published_message_id,))
self.logger.info(f"Получен контент опубликованного поста: {len(post_content)} элементов для published_message_id={published_message_id}") self.logger.info(f"Получен контент опубликованного поста: {len(post_content)} элементов для published_message_id={published_message_id}")
return post_content return post_content
# ============================================
# Методы для работы с ML-скорингом
# ============================================
async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool:
"""
Обновляет ML-скоры для поста.
Args:
message_id: ID сообщения в группе модерации
ml_scores_json: JSON строка со скорами
Returns:
True если обновлено успешно
"""
try:
query = "UPDATE post_from_telegram_suggest SET ml_scores = ? WHERE message_id = ?"
await self._execute_query(query, (ml_scores_json, message_id))
self.logger.info(f"ML-скоры обновлены для message_id={message_id}")
return True
except Exception as e:
self.logger.error(f"Ошибка обновления ML-скоров для message_id={message_id}: {e}")
return False
async def get_ml_scores_by_message_id(self, message_id: int) -> Optional[str]:
"""
Получает ML-скоры для поста.
Args:
message_id: ID сообщения
Returns:
JSON строка со скорами или None
"""
query = "SELECT ml_scores FROM post_from_telegram_suggest WHERE message_id = ?"
rows = await self._execute_query_with_result(query, (message_id,))
if rows and rows[0][0]:
return rows[0][0]
return None
async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]:
"""
Получает текст поста по message_id.
Args:
message_id: ID сообщения
Returns:
Текст поста или None
"""
query = "SELECT text FROM post_from_telegram_suggest WHERE message_id = ?"
rows = await self._execute_query_with_result(query, (message_id,))
if rows and rows[0][0]:
return rows[0][0]
return None
async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]:
"""
Получает тексты опубликованных постов для обучения RAG.
Args:
limit: Максимальное количество постов
Returns:
Список текстов
"""
query = """
SELECT text FROM post_from_telegram_suggest
WHERE status = 'approved'
AND text IS NOT NULL
AND text != ''
AND text != '^'
ORDER BY created_at DESC
LIMIT ?
"""
rows = await self._execute_query_with_result(query, (limit,))
texts = [row[0] for row in rows if row[0]]
self.logger.info(f"Получено {len(texts)} опубликованных постов для обучения")
return texts
async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]:
"""
Получает тексты отклоненных постов для обучения RAG.
Args:
limit: Максимальное количество постов
Returns:
Список текстов
"""
query = """
SELECT text FROM post_from_telegram_suggest
WHERE status = 'declined'
AND text IS NOT NULL
AND text != ''
AND text != '^'
ORDER BY created_at DESC
LIMIT ?
"""
rows = await self._execute_query_with_result(query, (limit,))
texts = [row[0] for row in rows if row[0]]
self.logger.info(f"Получено {len(texts)} отклоненных постов для обучения")
return texts
async def update_vector_hash(self, message_id: int, vector_hash: str) -> bool:
"""
Обновляет хеш вектора для поста (для кеширования).
Args:
message_id: ID сообщения
vector_hash: Хеш вектора
Returns:
True если обновлено успешно
"""
try:
query = "UPDATE post_from_telegram_suggest SET vector_hash = ? WHERE message_id = ?"
await self._execute_query(query, (vector_hash, message_id))
return True
except Exception as e:
self.logger.error(f"Ошибка обновления vector_hash для message_id={message_id}: {e}")
return False

View File

@@ -35,3 +35,20 @@ METRICS_PORT=8080
# Logging # Logging
LOG_LEVEL=INFO LOG_LEVEL=INFO
LOG_RETENTION_DAYS=30 LOG_RETENTION_DAYS=30
# ML Scoring - RAG (ruBERT)
# Включает локальное векторное сравнение с использованием ruBERT
RAG_ENABLED=false
RAG_MODEL=DeepPavlov/rubert-base-cased
RAG_CACHE_DIR=data/models
RAG_VECTORS_PATH=data/vectors.npz
RAG_MAX_EXAMPLES=10000
RAG_SCORE_MULTIPLIER=5
# ML Scoring - DeepSeek API
# Включает оценку постов через DeepSeek API
DEEPSEEK_ENABLED=false
DEEPSEEK_API_KEY=your_deepseek_api_key_here
DEEPSEEK_API_URL=https://api.deepseek.com/v1/chat/completions
DEEPSEEK_MODEL=deepseek-chat
DEEPSEEK_TIMEOUT=30

View File

@@ -16,6 +16,7 @@ from helper_bot.keyboards.keyboards import (create_keyboard_for_approve_ban,
create_keyboard_for_ban_reason, create_keyboard_for_ban_reason,
create_keyboard_with_pagination, create_keyboard_with_pagination,
get_reply_keyboard_admin) get_reply_keyboard_admin)
from helper_bot.utils.base_dependency_factory import get_global_instance
# Local imports - metrics # Local imports - metrics
from helper_bot.utils.metrics import db_query_time, track_errors, track_time from helper_bot.utils.metrics import db_query_time, track_errors, track_time
from logs.custom_logger import logger from logs.custom_logger import logger
@@ -137,6 +138,69 @@ async def get_banned_users(
await handle_admin_error(message, e, state, "get_banned_users") await handle_admin_error(message, e, state, "get_banned_users")
@admin_router.message(
ChatTypeFilter(chat_type=["private"]),
StateFilter("ADMIN"),
F.text == '📊 ML Статистика'
)
@track_time("get_ml_stats", "admin_handlers")
@track_errors("admin_handlers", "get_ml_stats")
async def get_ml_stats(
message: types.Message,
state: FSMContext,
**kwargs
):
"""Получение статистики ML-скоринга"""
try:
logger.info(f"Запрос ML статистики от пользователя: {message.from_user.full_name}")
bdf = get_global_instance()
scoring_manager = bdf.get_scoring_manager()
if not scoring_manager:
await message.answer("📊 ML Scoring отключен\n\nДля включения установите RAG_ENABLED=true или DEEPSEEK_ENABLED=true в .env")
return
stats = scoring_manager.get_stats()
# Формируем текст статистики
lines = ["📊 <b>ML Scoring Статистика</b>\n"]
# RAG статистика
if "rag" in stats:
rag = stats["rag"]
lines.append("🤖 <b>RAG (ruBERT):</b>")
lines.append(f" • Статус: {'✅ Включен' if rag.get('enabled') else '❌ Отключен'}")
lines.append(f" • Модель: {rag.get('model_name', 'N/A')}")
lines.append(f" • Модель загружена: {'' if rag.get('model_loaded') else ''}")
vs = rag.get("vector_store", {})
lines.append(f" • Положительных примеров: {vs.get('positive_count', 0)}")
lines.append(f" • Отрицательных примеров: {vs.get('negative_count', 0)}")
lines.append(f"Всего примеров: {vs.get('total_count', 0)}")
lines.append(f" • Макс. примеров: {vs.get('max_examples', 'N/A')}")
lines.append("")
# DeepSeek статистика
if "deepseek" in stats:
ds = stats["deepseek"]
lines.append("🔮 <b>DeepSeek API:</b>")
lines.append(f" • Статус: {'✅ Включен' if ds.get('enabled') else '❌ Отключен'}")
lines.append(f" • Модель: {ds.get('model', 'N/A')}")
lines.append(f" • Таймаут: {ds.get('timeout', 'N/A')}с")
lines.append("")
# Если ничего не включено
if "rag" not in stats and "deepseek" not in stats:
lines.append("⚠️ Ни один сервис не настроен")
await message.answer("\n".join(lines), parse_mode="HTML")
except Exception as e:
logger.error(f"Ошибка получения ML статистики: {e}")
await message.answer(f"❌ Ошибка получения статистики: {str(e)}")
# ============================================================================ # ============================================================================
# ХЕНДЛЕРЫ ПРОЦЕССА БАНА # ХЕНДЛЕРЫ ПРОЦЕССА БАНА
# ============================================================================ # ============================================================================

View File

@@ -15,7 +15,8 @@ def get_post_publish_service() -> PostPublishService:
db = bdf.get_db() db = bdf.get_db()
settings = bdf.settings settings = bdf.settings
s3_storage = bdf.get_s3_storage() s3_storage = bdf.get_s3_storage()
return PostPublishService(None, db, settings, s3_storage) scoring_manager = bdf.get_scoring_manager()
return PostPublishService(None, db, settings, s3_storage, scoring_manager)
def get_ban_service() -> BanService: def get_ban_service() -> BanService:

View File

@@ -29,12 +29,13 @@ from .exceptions import (BanError, PostNotFoundError, PublishError,
class PostPublishService: class PostPublishService:
def __init__(self, bot: Bot, db, settings: Dict[str, Any], s3_storage=None): def __init__(self, bot: Bot, db, settings: Dict[str, Any], s3_storage=None, scoring_manager=None):
# bot может быть None - в этом случае используем бота из контекста сообщения # bot может быть None - в этом случае используем бота из контекста сообщения
self.bot = bot self.bot = bot
self.db = db self.db = db
self.settings = settings self.settings = settings
self.s3_storage = s3_storage self.s3_storage = s3_storage
self.scoring_manager = scoring_manager
self.group_for_posts = settings['Telegram']['group_for_posts'] self.group_for_posts = settings['Telegram']['group_for_posts']
self.main_public = settings['Telegram']['main_public'] self.main_public = settings['Telegram']['main_public']
self.important_logs = settings['Telegram']['important_logs'] self.important_logs = settings['Telegram']['important_logs']
@@ -392,6 +393,9 @@ class PostPublishService:
async def _decline_single_post(self, call: CallbackQuery) -> None: async def _decline_single_post(self, call: CallbackQuery) -> None:
"""Отклонение одиночного поста""" """Отклонение одиночного поста"""
author_id = await self._get_author_id(call.message.message_id) author_id = await self._get_author_id(call.message.message_id)
# Обучаем RAG на отклоненном посте перед удалением
await self._train_on_declined(call.message.message_id)
updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "declined") updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "declined")
if updated_rows == 0: if updated_rows == 0:
@@ -485,6 +489,9 @@ class PostPublishService:
@track_errors("post_publish_service", "_delete_post_and_notify_author") @track_errors("post_publish_service", "_delete_post_and_notify_author")
async def _delete_post_and_notify_author(self, call: CallbackQuery, author_id: int) -> None: async def _delete_post_and_notify_author(self, call: CallbackQuery, author_id: int) -> None:
"""Удаление поста и уведомление автора""" """Удаление поста и уведомление автора"""
# Получаем текст поста для обучения RAG перед удалением
await self._train_on_published(call.message.message_id)
await self._get_bot(call.message).delete_message(chat_id=self.group_for_posts, message_id=call.message.message_id) await self._get_bot(call.message).delete_message(chat_id=self.group_for_posts, message_id=call.message.message_id)
try: try:
@@ -493,6 +500,32 @@ class PostPublishService:
if str(e) == ERROR_BOT_BLOCKED: if str(e) == ERROR_BOT_BLOCKED:
raise UserBlockedBotError("Пользователь заблокировал бота") raise UserBlockedBotError("Пользователь заблокировал бота")
raise raise
async def _train_on_published(self, message_id: int) -> None:
"""Обучает RAG на опубликованном посте."""
if not self.scoring_manager:
return
try:
text = await self.db.get_post_text_by_message_id(message_id)
if text and text.strip() and text != "^":
await self.scoring_manager.on_post_published(text)
logger.debug(f"RAG обучен на опубликованном посте: {message_id}")
except Exception as e:
logger.error(f"Ошибка обучения RAG на опубликованном посте {message_id}: {e}")
async def _train_on_declined(self, message_id: int) -> None:
"""Обучает RAG на отклоненном посте."""
if not self.scoring_manager:
return
try:
text = await self.db.get_post_text_by_message_id(message_id)
if text and text.strip() and text != "^":
await self.scoring_manager.on_post_declined(text)
logger.debug(f"RAG обучен на отклоненном посте: {message_id}")
except Exception as e:
logger.error(f"Ошибка обучения RAG на отклоненном посте {message_id}: {e}")
@track_time("_delete_media_group_and_notify_author", "post_publish_service") @track_time("_delete_media_group_and_notify_author", "post_publish_service")
@track_errors("post_publish_service", "_delete_media_group_and_notify_author") @track_errors("post_publish_service", "_delete_media_group_and_notify_author")

View File

@@ -35,11 +35,11 @@ sleep = asyncio.sleep
class PrivateHandlers: class PrivateHandlers:
"""Main handler class for private messages""" """Main handler class for private messages"""
def __init__(self, db: AsyncBotDB, settings: BotSettings, s3_storage=None): def __init__(self, db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None):
self.db = db self.db = db
self.settings = settings self.settings = settings
self.user_service = UserService(db, settings) self.user_service = UserService(db, settings)
self.post_service = PostService(db, settings, s3_storage) self.post_service = PostService(db, settings, s3_storage, scoring_manager)
self.sticker_service = StickerService(settings) self.sticker_service = StickerService(settings)
self.router = Router() self.router = Router()
@@ -240,18 +240,24 @@ class PrivateHandlers:
# Factory function to create handlers with dependencies # Factory function to create handlers with dependencies
def create_private_handlers(db: AsyncBotDB, settings: BotSettings, s3_storage=None) -> PrivateHandlers: def create_private_handlers(db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None) -> PrivateHandlers:
"""Create private handlers instance with dependencies""" """Create private handlers instance with dependencies"""
return PrivateHandlers(db, settings, s3_storage) return PrivateHandlers(db, settings, s3_storage, scoring_manager)
# Legacy router for backward compatibility # Legacy router for backward compatibility
private_router = Router() private_router = Router()
# Флаг инициализации для защиты от повторного вызова
_legacy_router_initialized = False
# Initialize with global dependencies (for backward compatibility) # Initialize with global dependencies (for backward compatibility)
def init_legacy_router(): def init_legacy_router():
"""Initialize legacy router with global dependencies""" """Initialize legacy router with global dependencies"""
global private_router global private_router, _legacy_router_initialized
if _legacy_router_initialized:
return
from helper_bot.utils.base_dependency_factory import get_global_instance from helper_bot.utils.base_dependency_factory import get_global_instance
@@ -269,11 +275,13 @@ def init_legacy_router():
db = bdf.get_db() db = bdf.get_db()
s3_storage = bdf.get_s3_storage() s3_storage = bdf.get_s3_storage()
handlers = create_private_handlers(db, settings, s3_storage) scoring_manager = bdf.get_scoring_manager()
handlers = create_private_handlers(db, settings, s3_storage, scoring_manager)
# Instead of trying to copy handlers, we'll use the new router directly # Instead of trying to copy handlers, we'll use the new router directly
# This maintains backward compatibility while using the new architecture # This maintains backward compatibility while using the new architecture
private_router = handlers.router private_router = handlers.router
_legacy_router_initialized = True
# Initialize legacy router # Initialize legacy router
init_legacy_router() init_legacy_router()

View File

@@ -128,10 +128,11 @@ class UserService:
class PostService: class PostService:
"""Service for post-related operations""" """Service for post-related operations"""
def __init__(self, db: DatabaseProtocol, settings: BotSettings, s3_storage=None) -> None: def __init__(self, db: DatabaseProtocol, settings: BotSettings, s3_storage=None, scoring_manager=None) -> None:
self.db = db self.db = db
self.settings = settings self.settings = settings
self.s3_storage = s3_storage self.s3_storage = s3_storage
self.scoring_manager = scoring_manager
async def _save_media_background(self, sent_message: types.Message, bot_db: Any, s3_storage) -> None: async def _save_media_background(self, sent_message: types.Message, bot_db: Any, s3_storage) -> None:
"""Сохраняет медиа в фоне, чтобы не блокировать ответ пользователю""" """Сохраняет медиа в фоне, чтобы не блокировать ответ пользователю"""
@@ -142,18 +143,65 @@ class PostService:
except Exception as e: except Exception as e:
logger.error(f"_save_media_background: Ошибка при сохранении медиа для поста {sent_message.message_id}: {e}") logger.error(f"_save_media_background: Ошибка при сохранении медиа для поста {sent_message.message_id}: {e}")
async def _get_scores(self, text: str) -> tuple:
"""
Получает скоры для текста поста.
Returns:
Tuple (deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json)
"""
if not self.scoring_manager or not text or not text.strip():
return None, None, None, None, None
try:
scores = await self.scoring_manager.score_post(text)
# Формируем JSON для сохранения в БД
import json
ml_scores_json = json.dumps(scores.to_json_dict()) if scores.has_any_score() else None
# Получаем данные от RAG
rag_confidence = scores.rag.confidence if scores.rag else None
rag_score_pos_only = scores.rag.metadata.get("score_pos_only") if scores.rag else None
return scores.deepseek_score, scores.rag_score, rag_confidence, rag_score_pos_only, ml_scores_json
except Exception as e:
logger.error(f"PostService: Ошибка получения скоров: {e}")
return None, None, None, None, None
async def _save_scores_background(self, message_id: int, ml_scores_json: str) -> None:
"""Сохраняет скоры в БД в фоне."""
if ml_scores_json:
try:
await self.db.update_ml_scores(message_id, ml_scores_json)
except Exception as e:
logger.error(f"PostService: Ошибка сохранения скоров для {message_id}: {e}")
@track_time("handle_text_post", "post_service") @track_time("handle_text_post", "post_service")
@track_errors("post_service", "handle_text_post") @track_errors("post_service", "handle_text_post")
@db_query_time("handle_text_post", "posts", "insert") @db_query_time("handle_text_post", "posts", "insert")
async def handle_text_post(self, message: types.Message, first_name: str) -> None: async def handle_text_post(self, message: types.Message, first_name: str) -> None:
"""Handle text post submission""" """Handle text post submission"""
post_text = get_text_message(message.text.lower(), first_name, message.from_user.username) raw_text = message.text or ""
# Получаем скоры для текста
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_text)
# Формируем текст с учетом скоров
post_text = get_text_message(
message.text.lower(),
first_name,
message.from_user.username,
deepseek_score=deepseek_score,
rag_score=rag_score,
rag_confidence=rag_confidence,
rag_score_pos_only=rag_score_pos_only,
)
markup = get_reply_keyboard_for_post() markup = get_reply_keyboard_for_post()
sent_message = await send_text_message(self.settings.group_for_posts, message, post_text, markup) sent_message = await send_text_message(self.settings.group_for_posts, message, post_text, markup)
# Сохраняем сырой текст и определяем анонимность # Определяем анонимность
raw_text = message.text or ""
is_anonymous = determine_anonymity(raw_text) is_anonymous = determine_anonymity(raw_text)
post = TelegramPost( post = TelegramPost(
@@ -164,23 +212,39 @@ class PostService:
is_anonymous=is_anonymous is_anonymous=is_anonymous
) )
await self.db.add_post(post) await self.db.add_post(post)
# Сохраняем скоры в фоне
if ml_scores_json:
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
@track_time("handle_photo_post", "post_service") @track_time("handle_photo_post", "post_service")
@track_errors("post_service", "handle_photo_post") @track_errors("post_service", "handle_photo_post")
@db_query_time("handle_photo_post", "posts", "insert") @db_query_time("handle_photo_post", "posts", "insert")
async def handle_photo_post(self, message: types.Message, first_name: str) -> None: async def handle_photo_post(self, message: types.Message, first_name: str) -> None:
"""Handle photo post submission""" """Handle photo post submission"""
raw_caption = message.caption or ""
# Получаем скоры для текста
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
post_caption = "" post_caption = ""
if message.caption: if message.caption:
post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username) post_caption = get_text_message(
message.caption.lower(),
first_name,
message.from_user.username,
deepseek_score=deepseek_score,
rag_score=rag_score,
rag_confidence=rag_confidence,
rag_score_pos_only=rag_score_pos_only,
)
markup = get_reply_keyboard_for_post() markup = get_reply_keyboard_for_post()
sent_message = await send_photo_message( sent_message = await send_photo_message(
self.settings.group_for_posts, message, message.photo[-1].file_id, post_caption, markup self.settings.group_for_posts, message, message.photo[-1].file_id, post_caption, markup
) )
# Сохраняем сырой caption и определяем анонимность # Определяем анонимность
raw_caption = message.caption or ""
is_anonymous = determine_anonymity(raw_caption) is_anonymous = determine_anonymity(raw_caption)
post = TelegramPost( post = TelegramPost(
@@ -191,25 +255,40 @@ class PostService:
is_anonymous=is_anonymous is_anonymous=is_anonymous
) )
await self.db.add_post(post) await self.db.add_post(post)
# Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю
# Сохраняем медиа и скоры в фоне
asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage))
if ml_scores_json:
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
@track_time("handle_video_post", "post_service") @track_time("handle_video_post", "post_service")
@track_errors("post_service", "handle_video_post") @track_errors("post_service", "handle_video_post")
@db_query_time("handle_video_post", "posts", "insert") @db_query_time("handle_video_post", "posts", "insert")
async def handle_video_post(self, message: types.Message, first_name: str) -> None: async def handle_video_post(self, message: types.Message, first_name: str) -> None:
"""Handle video post submission""" """Handle video post submission"""
raw_caption = message.caption or ""
# Получаем скоры для текста
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
post_caption = "" post_caption = ""
if message.caption: if message.caption:
post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username) post_caption = get_text_message(
message.caption.lower(),
first_name,
message.from_user.username,
deepseek_score=deepseek_score,
rag_score=rag_score,
rag_confidence=rag_confidence,
rag_score_pos_only=rag_score_pos_only,
)
markup = get_reply_keyboard_for_post() markup = get_reply_keyboard_for_post()
sent_message = await send_video_message( sent_message = await send_video_message(
self.settings.group_for_posts, message, message.video.file_id, post_caption, markup self.settings.group_for_posts, message, message.video.file_id, post_caption, markup
) )
# Сохраняем сырой caption и определяем анонимность # Определяем анонимность
raw_caption = message.caption or ""
is_anonymous = determine_anonymity(raw_caption) is_anonymous = determine_anonymity(raw_caption)
post = TelegramPost( post = TelegramPost(
@@ -220,8 +299,11 @@ class PostService:
is_anonymous=is_anonymous is_anonymous=is_anonymous
) )
await self.db.add_post(post) await self.db.add_post(post)
# Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю
# Сохраняем медиа и скоры в фоне
asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage))
if ml_scores_json:
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
@track_time("handle_video_note_post", "post_service") @track_time("handle_video_note_post", "post_service")
@track_errors("post_service", "handle_video_note_post") @track_errors("post_service", "handle_video_note_post")
@@ -253,17 +335,29 @@ class PostService:
@db_query_time("handle_audio_post", "posts", "insert") @db_query_time("handle_audio_post", "posts", "insert")
async def handle_audio_post(self, message: types.Message, first_name: str) -> None: async def handle_audio_post(self, message: types.Message, first_name: str) -> None:
"""Handle audio post submission""" """Handle audio post submission"""
raw_caption = message.caption or ""
# Получаем скоры для текста
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
post_caption = "" post_caption = ""
if message.caption: if message.caption:
post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username) post_caption = get_text_message(
message.caption.lower(),
first_name,
message.from_user.username,
deepseek_score=deepseek_score,
rag_score=rag_score,
rag_confidence=rag_confidence,
rag_score_pos_only=rag_score_pos_only,
)
markup = get_reply_keyboard_for_post() markup = get_reply_keyboard_for_post()
sent_message = await send_audio_message( sent_message = await send_audio_message(
self.settings.group_for_posts, message, message.audio.file_id, post_caption, markup self.settings.group_for_posts, message, message.audio.file_id, post_caption, markup
) )
# Сохраняем сырой caption и определяем анонимность # Определяем анонимность
raw_caption = message.caption or ""
is_anonymous = determine_anonymity(raw_caption) is_anonymous = determine_anonymity(raw_caption)
post = TelegramPost( post = TelegramPost(
@@ -274,8 +368,11 @@ class PostService:
is_anonymous=is_anonymous is_anonymous=is_anonymous
) )
await self.db.add_post(post) await self.db.add_post(post)
# Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю
# Сохраняем медиа и скоры в фоне
asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage)) asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage))
if ml_scores_json:
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
@track_time("handle_voice_post", "post_service") @track_time("handle_voice_post", "post_service")
@track_errors("post_service", "handle_voice_post") @track_errors("post_service", "handle_voice_post")
@@ -310,10 +407,23 @@ class PostService:
"""Handle media group post submission""" """Handle media group post submission"""
post_caption = " " post_caption = " "
raw_caption = "" raw_caption = ""
ml_scores_json = None
if album and album[0].caption: if album and album[0].caption:
raw_caption = album[0].caption or "" raw_caption = album[0].caption or ""
post_caption = get_text_message(album[0].caption.lower(), first_name, message.from_user.username)
# Получаем скоры для текста
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
post_caption = get_text_message(
album[0].caption.lower(),
first_name,
message.from_user.username,
deepseek_score=deepseek_score,
rag_score=rag_score,
rag_confidence=rag_confidence,
rag_score_pos_only=rag_score_pos_only,
)
is_anonymous = determine_anonymity(raw_caption) is_anonymous = determine_anonymity(raw_caption)
media_group = await prepare_media_group_from_middlewares(album, post_caption) media_group = await prepare_media_group_from_middlewares(album, post_caption)
@@ -333,6 +443,10 @@ class PostService:
) )
await self.db.add_post(main_post) await self.db.add_post(main_post)
# Сохраняем скоры в фоне
if ml_scores_json:
asyncio.create_task(self._save_scores_background(main_post_id, ml_scores_json))
for msg_id in media_group_message_ids: for msg_id in media_group_message_ids:
await self.db.add_message_link(main_post_id, msg_id) await self.db.add_message_link(main_post_id, msg_id)

View File

@@ -47,6 +47,9 @@ def get_reply_keyboard_admin():
) )
builder.row( builder.row(
types.KeyboardButton(text="Разбан (список)"), types.KeyboardButton(text="Разбан (список)"),
types.KeyboardButton(text="📊 ML Статистика")
)
builder.row(
types.KeyboardButton(text="Вернуться в бота") types.KeyboardButton(text="Вернуться в бота")
) )
markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True) markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True)

View File

@@ -78,6 +78,22 @@ async def start_bot(bdf):
await bot.delete_webhook(drop_pending_updates=True) await bot.delete_webhook(drop_pending_updates=True)
# Загружаем примеры для RAG из базы данных
scoring_manager = bdf.get_scoring_manager()
if scoring_manager and scoring_manager.rag_service and scoring_manager.rag_service.is_enabled:
try:
db = bdf.get_db()
positive_texts = await db.get_approved_posts_texts(limit=5000)
negative_texts = await db.get_declined_posts_texts(limit=5000)
if positive_texts or negative_texts:
await scoring_manager.load_examples_from_db(positive_texts, negative_texts)
logging.info(f"RAG: Загружено {len(positive_texts)} положительных и {len(negative_texts)} отрицательных примеров")
else:
logging.warning("RAG: Нет примеров в базе данных для загрузки")
except Exception as e:
logging.error(f"Ошибка загрузки примеров для RAG: {e}")
# Запускаем HTTP сервер для метрик параллельно с ботом # Запускаем HTTP сервер для метрик параллельно с ботом
metrics_host = bdf.settings.get('Metrics', {}).get('host', '0.0.0.0') metrics_host = bdf.settings.get('Metrics', {}).get('host', '0.0.0.0')
metrics_port = bdf.settings.get('Metrics', {}).get('port', 8080) metrics_port = bdf.settings.get('Metrics', {}).get('port', 8080)

View File

@@ -0,0 +1,5 @@
"""
Сервисы приложения.
Содержит бизнес-логику, не связанную напрямую с handlers.
"""

View File

@@ -0,0 +1,42 @@
"""
Сервисы для ML-скоринга постов.
Включает:
- RAGService - локальное векторное сравнение с ruBERT
- DeepSeekService - интеграция с DeepSeek API
- ScoringManager - объединение всех сервисов скоринга
- VectorStore - in-memory хранилище векторов
"""
from .base import ScoringResult, ScoringServiceProtocol, CombinedScore
from .exceptions import (
ScoringError,
ModelNotLoadedError,
VectorStoreError,
DeepSeekAPIError,
InsufficientExamplesError,
TextTooShortError,
)
from .vector_store import VectorStore
from .rag_service import RAGService
from .deepseek_service import DeepSeekService
from .scoring_manager import ScoringManager
__all__ = [
# Базовые классы
"ScoringResult",
"ScoringServiceProtocol",
"CombinedScore",
# Исключения
"ScoringError",
"ModelNotLoadedError",
"VectorStoreError",
"DeepSeekAPIError",
"InsufficientExamplesError",
"TextTooShortError",
# Сервисы
"VectorStore",
"RAGService",
"DeepSeekService",
"ScoringManager",
]

View File

@@ -0,0 +1,155 @@
"""
Базовые классы и протоколы для сервисов скоринга.
"""
from dataclasses import dataclass, field
from typing import Optional, Protocol, Dict, Any
from datetime import datetime
@dataclass
class ScoringResult:
"""
Результат оценки поста от одного сервиса.
Attributes:
score: Оценка от 0.0 до 1.0 (вероятность публикации)
source: Источник оценки ("deepseek", "rag", etc.)
model: Название используемой модели
confidence: Уверенность в оценке (опционально)
timestamp: Время получения оценки
metadata: Дополнительные данные
"""
score: float
source: str
model: str
confidence: Optional[float] = None
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
metadata: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
"""Валидация score в диапазоне [0.0, 1.0]."""
if not 0.0 <= self.score <= 1.0:
raise ValueError(f"Score должен быть в диапазоне [0.0, 1.0], получено: {self.score}")
def to_dict(self) -> Dict[str, Any]:
"""Преобразует результат в словарь для сохранения в JSON."""
result = {
"score": round(self.score, 4),
"model": self.model,
"ts": self.timestamp,
}
if self.confidence is not None:
result["confidence"] = round(self.confidence, 4)
if self.metadata:
result["metadata"] = self.metadata
return result
@classmethod
def from_dict(cls, source: str, data: Dict[str, Any]) -> "ScoringResult":
"""Создает ScoringResult из словаря."""
return cls(
score=data["score"],
source=source,
model=data.get("model", "unknown"),
confidence=data.get("confidence"),
timestamp=data.get("ts", int(datetime.now().timestamp())),
metadata=data.get("metadata", {}),
)
@dataclass
class CombinedScore:
"""
Объединенный результат от всех сервисов скоринга.
Attributes:
deepseek: Результат от DeepSeek API (None если отключен/ошибка)
rag: Результат от RAG сервиса (None если отключен/ошибка)
errors: Словарь с ошибками по источникам
"""
deepseek: Optional[ScoringResult] = None
rag: Optional[ScoringResult] = None
errors: Dict[str, str] = field(default_factory=dict)
@property
def deepseek_score(self) -> Optional[float]:
"""Возвращает только числовой скор от DeepSeek."""
return self.deepseek.score if self.deepseek else None
@property
def rag_score(self) -> Optional[float]:
"""Возвращает только числовой скор от RAG."""
return self.rag.score if self.rag else None
def to_json_dict(self) -> Dict[str, Any]:
"""
Преобразует в словарь для сохранения в ml_scores колонку.
Формат:
{
"deepseek": {"score": 0.75, "model": "...", "ts": ...},
"rag": {"score": 0.90, "model": "...", "ts": ...}
}
"""
result = {}
if self.deepseek:
result["deepseek"] = self.deepseek.to_dict()
if self.rag:
result["rag"] = self.rag.to_dict()
return result
def has_any_score(self) -> bool:
"""Проверяет, есть ли хотя бы один успешный скор."""
return self.deepseek is not None or self.rag is not None
class ScoringServiceProtocol(Protocol):
"""
Протокол для сервисов скоринга.
Любой сервис скоринга должен реализовывать эти методы.
"""
@property
def source_name(self) -> str:
"""Возвращает имя источника ("deepseek", "rag", etc.)."""
...
@property
def is_enabled(self) -> bool:
"""Проверяет, включен ли сервис."""
...
async def calculate_score(self, text: str) -> ScoringResult:
"""
Рассчитывает скор для текста поста.
Args:
text: Текст поста для оценки
Returns:
ScoringResult с оценкой
Raises:
ScoringError: При ошибке расчета
"""
...
async def add_positive_example(self, text: str) -> None:
"""
Добавляет текст как положительный пример (опубликованный пост).
Args:
text: Текст опубликованного поста
"""
...
async def add_negative_example(self, text: str) -> None:
"""
Добавляет текст как отрицательный пример (отклоненный пост).
Args:
text: Текст отклоненного поста
"""
...

View File

@@ -0,0 +1,358 @@
"""
DeepSeek API сервис для скоринга постов.
Использует DeepSeek API для семантической оценки релевантности поста.
"""
import asyncio
import json
from typing import Optional, List
import httpx
from logs.custom_logger import logger
from helper_bot.utils.metrics import track_time, track_errors
from .base import ScoringResult
from .exceptions import DeepSeekAPIError, ScoringError, TextTooShortError
class DeepSeekService:
"""
Сервис для оценки постов через DeepSeek API.
Отправляет текст поста в DeepSeek с промптом для оценки
и получает числовой скор релевантности.
Attributes:
api_key: API ключ DeepSeek
api_url: URL API эндпоинта
model: Название модели
timeout: Таймаут запроса в секундах
"""
# Промпт для оценки поста
SCORING_PROMPT = """Роль: Ты — строгий и внимательный модератор сообщества в социальной сети, ориентированного на знакомства между людьми. Твоя задача — оценить, можно ли опубликовать пост, основываясь на четких правилах.
Контекст группы: Это группа для поиска и знакомства с людьми. Пользователи могут искать кого угодно: случайно увиденных на улице, в транспорте, в кафе, старых знакомых, новых друзей или пару. Это главная и единственная цель группы.
---
ПРАВИЛА ЗАПРЕТА (пост НЕ ДОЛЖЕН быть опубликован, если содержит это):
1. Запрещенные законом тематики: Любые призывы, обсуждение или поиск чего-либо незаконного (наркотики, оружие, мошенничество, насилие и т.д.).
2. Поиск и утеря животных, найденные предметы: Запрещены посты про потерявшихся/найденных кошек, собак, хомяков, а также про потерянные/найденные телефоны, ключи, сумки и т.п.
3. Конкуренция (Дайвинчик): Любое упоминание группы/проекта/чата "Дайвинчик" или любых других групп-конкурентов. Запрещены призывы переходить в другие сообщества.
4. Сбор больших компаний и групп: Запрещены посты с целью собрать большую тусовку, компанию, группу для похода, вечеринки, игры и т.д. (например, "собираем команду для футбола", "кто хочет на квартиру?").
5. Организация чатов и других сообществ: Запрещено создание или реклама сторонних чатов, каналов, групп в телеграме, дискорде и т.п.
---
ПРАВИЛА РАЗРЕШЕНИЯ (пост МОЖЕТ быть опубликован, если):
· Цель — найти конкретного человека или познакомиться с кем-то новым.
· Формат: Описание человека, обстоятельств встречи, примет, места и времени. Или прямой призыв к знакомству.
· Примеры ДОПУСТИМЫХ постов (ориентируйся на них):
· "мальчики нефоры/патлатые, гоу знакомиться😻 анон"
· "ищу девочку, ехала на 21 автобусе примерно в 15:20. села на детской поликлинике и вышла в заречье вся в черной одежде и с черным баулом"
· "ищу мальчика ехали на 35 автобусе часов в 7 вечера я была с девочками,у нас с тобой еще куртки одинаковые ,я рядом с тобой сидела,напиши в комментарии если у тебя нету девочки. анон админу любви."
---
ИНСТРУКЦИЯ ПО ОЦЕНКЕ:
Проанализируй полученный пост и присвой ему итоговый Вес (Score) от 0.0 до 1.0, где:
· 1.0 — Пост полностью соответствует правилам. Цель — найти/познакомиться с человеком. Ничего из списка запретов не нарушено. Можно публиковать.
· 0.0 — Пост категорически нарушает правила. Содержит явные признаки одного или нескольких пунктов из списка запрета. Публиковать НЕЛЬЗЯ.
· 0.2 - 0.8 — Пост находится в "серой зоне". Присваивай промежуточный вес, оценивая степень риска и соответствия цели группы.
· Ближе к 0.2: Сильно сомнительный пост, есть явные признаки запрещенной темы (например, упоминание "собраться компанией", косвенная реклама другого места).
· 0.5: Нейтральный или неочевидный пост. Нужно проверить, нет ли скрытого смысла, нарушающего правила.
· Ближе к 0.8: В целом допустимый пост, но с небольшими странностями или двусмысленностями, не нарушающими правила напрямую.
---
{text}
---
Ответь ТОЛЬКО числом от 0.0 до 1.0, без дополнительных объяснений.
Пример ответа: 0.75"""
DEFAULT_API_URL = "https://api.deepseek.com/v1/chat/completions"
DEFAULT_MODEL = "deepseek-chat"
def __init__(
self,
api_key: Optional[str] = None,
api_url: Optional[str] = None,
model: Optional[str] = None,
timeout: int = 30,
enabled: bool = True,
min_text_length: int = 3,
max_retries: int = 3,
):
"""
Инициализация DeepSeek сервиса.
Args:
api_key: API ключ DeepSeek
api_url: URL API эндпоинта
model: Название модели
timeout: Таймаут запроса в секундах
enabled: Включен ли сервис
min_text_length: Минимальная длина текста для обработки
max_retries: Максимальное количество повторных попыток
"""
self.api_key = api_key
self.api_url = api_url or self.DEFAULT_API_URL
self.model = model or self.DEFAULT_MODEL
self.timeout = timeout
self._enabled = enabled and bool(api_key)
self.min_text_length = min_text_length
self.max_retries = max_retries
# HTTP клиент (создается лениво)
self._client: Optional[httpx.AsyncClient] = None
if not api_key and enabled:
logger.warning("DeepSeekService: API ключ не указан, сервис отключен")
self._enabled = False
logger.info(
f"DeepSeekService инициализирован "
f"(model={self.model}, enabled={self._enabled})"
)
@property
def source_name(self) -> str:
"""Имя источника для результатов."""
return "deepseek"
@property
def is_enabled(self) -> bool:
"""Проверяет, включен ли сервис."""
return self._enabled
async def _get_client(self) -> httpx.AsyncClient:
"""Получает или создает HTTP клиент."""
if self._client is None:
self._client = httpx.AsyncClient(
timeout=httpx.Timeout(self.timeout),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
return self._client
async def close(self) -> None:
"""Закрывает HTTP клиент."""
if self._client:
await self._client.aclose()
self._client = None
def _clean_text(self, text: str) -> str:
"""Очищает текст от лишних символов."""
if not text:
return ""
# Удаляем лишние пробелы и переносы строк
clean = " ".join(text.split())
# Удаляем служебные символы
if clean == "^":
return ""
return clean.strip()
def _parse_score_response(self, response_text: str) -> float:
"""
Парсит ответ от DeepSeek и извлекает скор.
Args:
response_text: Текст ответа от API
Returns:
Числовой скор от 0.0 до 1.0
Raises:
DeepSeekAPIError: Если не удалось распарсить ответ
"""
try:
# Пытаемся найти число в ответе
text = response_text.strip()
# Убираем возможные обрамления
text = text.strip('"\'`')
# Пробуем распарсить как число
score = float(text)
# Ограничиваем диапазон
score = max(0.0, min(1.0, score))
return score
except ValueError:
# Пробуем найти число в тексте
import re
matches = re.findall(r'0\.\d+|1\.0|0|1', text)
if matches:
score = float(matches[0])
return max(0.0, min(1.0, score))
logger.error(f"DeepSeekService: Не удалось распарсить ответ: {response_text}")
raise DeepSeekAPIError(f"Не удалось распарсить скор из ответа: {response_text}")
@track_time("calculate_score", "deepseek_service")
@track_errors("deepseek_service", "calculate_score")
async def calculate_score(self, text: str) -> ScoringResult:
"""
Рассчитывает скор для текста поста через DeepSeek API.
Args:
text: Текст поста для оценки
Returns:
ScoringResult с оценкой
Raises:
ScoringError: При ошибке расчета
"""
if not self._enabled:
raise ScoringError("DeepSeek сервис отключен")
# Очищаем текст
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
raise TextTooShortError(
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
)
# Формируем промпт
prompt = self.SCORING_PROMPT.format(text=clean_text)
# Выполняем запрос с повторными попытками
last_error = None
for attempt in range(self.max_retries):
try:
score = await self._make_api_request(prompt)
return ScoringResult(
score=score,
source=self.source_name,
model=self.model,
metadata={
"text_length": len(clean_text),
"attempt": attempt + 1,
},
)
except DeepSeekAPIError as e:
last_error = e
logger.warning(
f"DeepSeekService: Попытка {attempt + 1}/{self.max_retries} "
f"не удалась: {e}"
)
if attempt < self.max_retries - 1:
# Экспоненциальная задержка
await asyncio.sleep(2 ** attempt)
raise ScoringError(f"Все попытки запроса к DeepSeek API не удались: {last_error}")
async def _make_api_request(self, prompt: str) -> float:
"""
Выполняет запрос к DeepSeek API.
Args:
prompt: Промпт для отправки
Returns:
Числовой скор от 0.0 до 1.0
Raises:
DeepSeekAPIError: При ошибке API
"""
client = await self._get_client()
payload = {
"model": self.model,
"messages": [
{
"role": "user",
"content": prompt,
}
],
"temperature": 0.1, # Низкая температура для детерминированности
"max_tokens": 10, # Ожидаем только число
}
try:
response = await client.post(self.api_url, json=payload)
response.raise_for_status()
data = response.json()
# Извлекаем ответ
if "choices" not in data or not data["choices"]:
raise DeepSeekAPIError("Пустой ответ от API")
response_text = data["choices"][0]["message"]["content"]
# Парсим скор
score = self._parse_score_response(response_text)
logger.debug(f"DeepSeekService: Получен скор {score} для текста")
return score
except httpx.HTTPStatusError as e:
error_msg = f"HTTP ошибка {e.response.status_code}"
try:
error_data = e.response.json()
if "error" in error_data:
error_msg = error_data["error"].get("message", error_msg)
except Exception:
pass
raise DeepSeekAPIError(error_msg)
except httpx.TimeoutException:
raise DeepSeekAPIError(f"Таймаут запроса ({self.timeout}s)")
except Exception as e:
raise DeepSeekAPIError(f"Ошибка запроса: {e}")
async def add_positive_example(self, text: str) -> None:
"""
Добавляет текст как положительный пример.
Для DeepSeek не требуется хранить примеры - оценка выполняется
на основе промпта. Метод существует для совместимости с протоколом.
Args:
text: Текст опубликованного поста
"""
# DeepSeek не использует примеры для обучения
# Промпт уже содержит критерии оценки
pass
async def add_negative_example(self, text: str) -> None:
"""
Добавляет текст как отрицательный пример.
Для DeepSeek не требуется хранить примеры - оценка выполняется
на основе промпта. Метод существует для совместимости с протоколом.
Args:
text: Текст отклоненного поста
"""
# DeepSeek не использует примеры для обучения
pass
def get_stats(self) -> dict:
"""Возвращает статистику сервиса."""
return {
"enabled": self._enabled,
"model": self.model,
"api_url": self.api_url,
"timeout": self.timeout,
"max_retries": self.max_retries,
}

View File

@@ -0,0 +1,33 @@
"""
Исключения для сервисов скоринга.
"""
class ScoringError(Exception):
"""Базовое исключение для ошибок скоринга."""
pass
class ModelNotLoadedError(ScoringError):
"""Модель не загружена или недоступна."""
pass
class VectorStoreError(ScoringError):
"""Ошибка при работе с хранилищем векторов."""
pass
class DeepSeekAPIError(ScoringError):
"""Ошибка при обращении к DeepSeek API."""
pass
class InsufficientExamplesError(ScoringError):
"""Недостаточно примеров для расчета скора."""
pass
class TextTooShortError(ScoringError):
"""Текст слишком короткий для векторизации."""
pass

View File

@@ -0,0 +1,507 @@
"""
RAG сервис для скоринга постов с использованием ruBERT.
Использует модель DeepPavlov/rubert-base-cased для создания эмбеддингов
и сравнивает их с эталонными примерами через VectorStore.
"""
import asyncio
from typing import Optional, List
import numpy as np
from logs.custom_logger import logger
from helper_bot.utils.metrics import track_time, track_errors
from .base import ScoringResult
from .vector_store import VectorStore
from .exceptions import (
ModelNotLoadedError,
ScoringError,
InsufficientExamplesError,
TextTooShortError,
)
class RAGService:
"""
RAG сервис для оценки постов на основе векторного сходства.
Использует ruBERT для создания эмбеддингов текста и сравнивает
их с эталонными примерами (опубликованные vs отклоненные посты).
Attributes:
model_name: Название модели HuggingFace
vector_store: Хранилище векторов
min_text_length: Минимальная длина текста для обработки
"""
# Название модели по умолчанию
DEFAULT_MODEL = "DeepPavlov/rubert-base-cased"
def __init__(
self,
model_name: Optional[str] = None,
vector_store: Optional[VectorStore] = None,
cache_dir: Optional[str] = None,
enabled: bool = True,
min_text_length: int = 3,
):
"""
Инициализация RAG сервиса.
Args:
model_name: Название модели HuggingFace (по умолчанию ruBERT)
vector_store: Хранилище векторов (создается автоматически если не передано)
cache_dir: Директория для кеширования модели
enabled: Включен ли сервис
min_text_length: Минимальная длина текста для обработки
"""
self.model_name = model_name or self.DEFAULT_MODEL
self.cache_dir = cache_dir
self._enabled = enabled
self.min_text_length = min_text_length
# Модель и токенизатор загружаются лениво
self._model = None
self._tokenizer = None
self._model_loaded = False
# Хранилище векторов
self.vector_store = vector_store or VectorStore()
logger.info(f"RAGService инициализирован (model={self.model_name}, enabled={enabled})")
@property
def source_name(self) -> str:
"""Имя источника для результатов."""
return "rag"
@property
def is_enabled(self) -> bool:
"""Проверяет, включен ли сервис."""
return self._enabled
@property
def is_model_loaded(self) -> bool:
"""Проверяет, загружена ли модель."""
return self._model_loaded
async def load_model(self) -> None:
"""
Загружает модель и токенизатор.
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
"""
if self._model_loaded:
return
if not self._enabled:
logger.warning("RAGService: Сервис отключен, модель не загружается")
return
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
try:
# Загрузка в отдельном потоке
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._load_model_sync)
self._model_loaded = True
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
except Exception as e:
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
def _load_model_sync(self) -> None:
"""Синхронная загрузка модели (вызывается в executor)."""
logger.info("RAGService: Начало _load_model_sync, импорт transformers...")
from transformers import AutoTokenizer, AutoModel
import torch
# Определяем устройство
self._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"RAGService: Устройство определено: {self._device}")
# Загружаем токенизатор
logger.info(f"RAGService: Загрузка токенизатора из {self.model_name}...")
self._tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
cache_dir=self.cache_dir,
)
logger.info("RAGService: Токенизатор загружен")
# Загружаем модель
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
self._model = AutoModel.from_pretrained(
self.model_name,
cache_dir=self.cache_dir,
)
logger.info("RAGService: Модель загружена, перенос на устройство...")
self._model.to(self._device)
self._model.eval() # Режим инференса
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
def _get_embedding_sync(self, text: str) -> np.ndarray:
"""
Получает эмбеддинг текста (синхронно).
Использует [CLS] токен как представление всего текста.
Args:
text: Текст для векторизации
Returns:
Numpy массив с эмбеддингом (768 измерений для ruBERT)
"""
import torch
# Токенизация с ограничением длины
inputs = self._tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True,
)
inputs = {k: v.to(self._device) for k, v in inputs.items()}
# Получаем эмбеддинг
with torch.no_grad():
outputs = self._model(**inputs)
# Используем [CLS] токен (первый токен)
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
return embedding.flatten()
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (синхронно).
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
Args:
texts: Список текстов для векторизации
batch_size: Размер батча (по умолчанию 16)
Returns:
Список numpy массивов с эмбеддингами
"""
import torch
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
# Токенизация батча
inputs = self._tokenizer(
batch_texts,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True,
)
inputs = {k: v.to(self._device) for k, v in inputs.items()}
# Получаем эмбеддинги
with torch.no_grad():
outputs = self._model(**inputs)
# [CLS] токен для каждого текста в батче
batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
# Разбиваем на отдельные эмбеддинги
for j in range(len(batch_texts)):
all_embeddings.append(batch_embeddings[j])
if i > 0 and i % (batch_size * 10) == 0:
logger.info(f"RAGService: Обработано {i}/{len(texts)} текстов")
return all_embeddings
async def get_embeddings_batch(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (асинхронно).
Args:
texts: Список текстов для векторизации
batch_size: Размер батча
Returns:
Список numpy массивов с эмбеддингами
"""
if not self._model_loaded:
await self.load_model()
if not self._model_loaded:
raise ModelNotLoadedError("Модель не загружена")
# Очищаем тексты
clean_texts = [self._clean_text(text) for text in texts]
# Выполняем батч-обработку в thread pool
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None,
self._get_embeddings_batch_sync,
clean_texts,
batch_size,
)
return embeddings
async def get_embedding(self, text: str) -> np.ndarray:
"""
Получает эмбеддинг текста (асинхронно).
Args:
text: Текст для векторизации
Returns:
Numpy массив с эмбеддингом
Raises:
ModelNotLoadedError: Если модель не загружена
TextTooShortError: Если текст слишком короткий
"""
if not self._model_loaded:
await self.load_model()
if not self._model_loaded:
raise ModelNotLoadedError("Модель не загружена")
# Очищаем текст
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
raise TextTooShortError(
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
)
# Выполняем в отдельном потоке
loop = asyncio.get_event_loop()
embedding = await loop.run_in_executor(
None,
self._get_embedding_sync,
clean_text
)
return embedding
def _clean_text(self, text: str) -> str:
"""Очищает текст от лишних символов."""
if not text:
return ""
# Удаляем лишние пробелы и переносы строк
clean = " ".join(text.split())
# Удаляем служебные символы (например "^" для helper сообщений)
if clean == "^":
return ""
return clean.strip()
@track_time("calculate_score", "rag_service")
@track_errors("rag_service", "calculate_score")
async def calculate_score(self, text: str) -> ScoringResult:
"""
Рассчитывает скор для текста поста.
Args:
text: Текст поста для оценки
Returns:
ScoringResult с оценкой
Raises:
ScoringError: При ошибке расчета
"""
if not self._enabled:
raise ScoringError("RAG сервис отключен")
try:
# Получаем эмбеддинг текста
embedding = await self.get_embedding(text)
# Логируем первые элементы вектора для отладки
logger.info(
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
f"text_preview='{text[:30]}'"
)
# Рассчитываем скор через VectorStore
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
return ScoringResult(
score=score,
source=self.source_name,
model=self.model_name,
confidence=confidence,
metadata={
"positive_examples": self.vector_store.positive_count,
"negative_examples": self.vector_store.negative_count,
"score_pos_only": score_pos_only, # Для сравнения
},
)
except InsufficientExamplesError:
# Не достаточно примеров - возвращаем нейтральный скор
logger.warning("RAGService: Недостаточно примеров для расчета скора")
raise
except TextTooShortError:
logger.warning(f"RAGService: Текст слишком короткий для оценки")
raise
except Exception as e:
logger.error(f"RAGService: Ошибка расчета скора: {e}")
raise ScoringError(f"Ошибка расчета скора: {e}")
@track_time("add_positive_example", "rag_service")
async def add_positive_example(self, text: str) -> None:
"""
Добавляет текст как положительный пример (опубликованный пост).
Args:
text: Текст опубликованного поста
"""
if not self._enabled:
return
try:
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
return
# Получаем эмбеддинг
embedding = await self.get_embedding(clean_text)
# Вычисляем хеш для дедупликации
text_hash = VectorStore.compute_text_hash(clean_text)
# Добавляем в хранилище
added = self.vector_store.add_positive(embedding, text_hash)
if added:
logger.info(f"RAGService: Добавлен положительный пример")
except Exception as e:
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
@track_time("add_negative_example", "rag_service")
async def add_negative_example(self, text: str) -> None:
"""
Добавляет текст как отрицательный пример (отклоненный пост).
Args:
text: Текст отклоненного поста
"""
if not self._enabled:
return
try:
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
return
# Получаем эмбеддинг
embedding = await self.get_embedding(clean_text)
# Вычисляем хеш для дедупликации
text_hash = VectorStore.compute_text_hash(clean_text)
# Добавляем в хранилище
added = self.vector_store.add_negative(embedding, text_hash)
if added:
logger.info(f"RAGService: Добавлен отрицательный пример")
except Exception as e:
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
async def load_examples_from_db(
self,
positive_texts: list[str],
negative_texts: list[str],
batch_size: int = 16,
) -> None:
"""
Загружает примеры из базы данных с батч-обработкой.
Используется при запуске бота для восстановления VectorStore.
Батч-обработка ускоряет загрузку в 10-20 раз.
Args:
positive_texts: Список текстов опубликованных постов
negative_texts: Список текстов отклоненных постов
batch_size: Размер батча для обработки (по умолчанию 16)
"""
if not self._enabled:
return
logger.info(
f"RAGService: Загрузка примеров из БД с батч-обработкой "
f"(positive: {len(positive_texts)}, negative: {len(negative_texts)}, batch_size: {batch_size})"
)
# Убеждаемся что модель загружена
await self.load_model()
import time
start_time = time.time()
# Фильтруем и очищаем положительные тексты
if positive_texts:
clean_positive = []
positive_hashes = []
for text in positive_texts:
clean_text = self._clean_text(text)
if len(clean_text) >= self.min_text_length:
clean_positive.append(clean_text)
positive_hashes.append(VectorStore.compute_text_hash(clean_text))
if clean_positive:
logger.info(f"RAGService: Обработка {len(clean_positive)} положительных примеров батчами...")
positive_embeddings = await self.get_embeddings_batch(clean_positive, batch_size)
self.vector_store.add_positive_batch(positive_embeddings, positive_hashes)
# Фильтруем и очищаем отрицательные тексты
if negative_texts:
clean_negative = []
negative_hashes = []
for text in negative_texts:
clean_text = self._clean_text(text)
if len(clean_text) >= self.min_text_length:
clean_negative.append(clean_text)
negative_hashes.append(VectorStore.compute_text_hash(clean_text))
if clean_negative:
logger.info(f"RAGService: Обработка {len(clean_negative)} отрицательных примеров батчами...")
negative_embeddings = await self.get_embeddings_batch(clean_negative, batch_size)
self.vector_store.add_negative_batch(negative_embeddings, negative_hashes)
elapsed = time.time() - start_time
logger.info(
f"RAGService: Загрузка завершена за {elapsed:.1f} сек "
f"(positive: {self.vector_store.positive_count}, "
f"negative: {self.vector_store.negative_count})"
)
def save_vectors(self) -> None:
"""Сохраняет векторы на диск."""
if self.vector_store.storage_path:
self.vector_store.save_to_disk()
def get_stats(self) -> dict:
"""Возвращает статистику сервиса."""
return {
"enabled": self._enabled,
"model_name": self.model_name,
"model_loaded": self._model_loaded,
"vector_store": self.vector_store.get_stats(),
}

View File

@@ -0,0 +1,242 @@
"""
Менеджер для объединения всех сервисов скоринга.
Координирует работу RAGService и DeepSeekService,
выполняет параллельные запросы и агрегирует результаты.
"""
import asyncio
from typing import Optional, List
from logs.custom_logger import logger
from helper_bot.utils.metrics import track_time, track_errors
from .base import CombinedScore, ScoringResult
from .rag_service import RAGService
from .deepseek_service import DeepSeekService
from .vector_store import VectorStore
from .exceptions import ScoringError, InsufficientExamplesError, TextTooShortError
class ScoringManager:
"""
Менеджер для управления всеми сервисами скоринга.
Объединяет RAGService и DeepSeekService, выполняет параллельные
запросы и агрегирует результаты в единый CombinedScore.
Attributes:
rag_service: Сервис RAG с ruBERT
deepseek_service: Сервис DeepSeek API
"""
def __init__(
self,
rag_service: Optional[RAGService] = None,
deepseek_service: Optional[DeepSeekService] = None,
):
"""
Инициализация менеджера.
Args:
rag_service: Сервис RAG (создается автоматически если не передан)
deepseek_service: Сервис DeepSeek (создается автоматически если не передан)
"""
self.rag_service = rag_service
self.deepseek_service = deepseek_service
logger.info(
f"ScoringManager инициализирован "
f"(rag={rag_service is not None and rag_service.is_enabled}, "
f"deepseek={deepseek_service is not None and deepseek_service.is_enabled})"
)
@property
def is_any_enabled(self) -> bool:
"""Проверяет, включен ли хотя бы один сервис."""
rag_enabled = self.rag_service is not None and self.rag_service.is_enabled
deepseek_enabled = self.deepseek_service is not None and self.deepseek_service.is_enabled
return rag_enabled or deepseek_enabled
@track_time("score_post", "scoring_manager")
@track_errors("scoring_manager", "score_post")
async def score_post(self, text: str) -> CombinedScore:
"""
Рассчитывает скоры для текста поста от всех сервисов.
Выполняет запросы параллельно для минимизации задержки.
Args:
text: Текст поста для оценки
Returns:
CombinedScore с результатами от всех сервисов
"""
result = CombinedScore()
if not text or not text.strip():
logger.debug("ScoringManager: Пустой текст, пропускаем скоринг")
return result
# Собираем задачи для параллельного выполнения
tasks = []
task_names = []
# RAG сервис
if self.rag_service and self.rag_service.is_enabled:
tasks.append(self._get_rag_score(text))
task_names.append("rag")
# DeepSeek сервис
if self.deepseek_service and self.deepseek_service.is_enabled:
tasks.append(self._get_deepseek_score(text))
task_names.append("deepseek")
if not tasks:
logger.debug("ScoringManager: Нет активных сервисов для скоринга")
return result
# Выполняем параллельно
results = await asyncio.gather(*tasks, return_exceptions=True)
# Обрабатываем результаты
for name, res in zip(task_names, results):
if isinstance(res, Exception):
error_msg = str(res)
result.errors[name] = error_msg
logger.warning(f"ScoringManager: Ошибка от {name}: {error_msg}")
elif res is not None:
if name == "rag":
result.rag = res
elif name == "deepseek":
result.deepseek = res
logger.info(
f"ScoringManager: Скоринг завершен "
f"(rag={result.rag_score}, deepseek={result.deepseek_score})"
)
return result
async def _get_rag_score(self, text: str) -> Optional[ScoringResult]:
"""Получает скор от RAG сервиса."""
try:
return await self.rag_service.calculate_score(text)
except InsufficientExamplesError:
# Недостаточно примеров - это не ошибка, просто нет данных
logger.info("ScoringManager: RAG - недостаточно примеров")
return None
except TextTooShortError:
# Текст слишком короткий - пропускаем
logger.debug("ScoringManager: RAG - текст слишком короткий")
return None
except Exception as e:
logger.error(f"ScoringManager: RAG ошибка: {e}")
raise
async def _get_deepseek_score(self, text: str) -> Optional[ScoringResult]:
"""Получает скор от DeepSeek сервиса."""
try:
return await self.deepseek_service.calculate_score(text)
except TextTooShortError:
# Текст слишком короткий - пропускаем
logger.debug("ScoringManager: DeepSeek - текст слишком короткий")
return None
except Exception as e:
logger.error(f"ScoringManager: DeepSeek ошибка: {e}")
raise
@track_time("on_post_published", "scoring_manager")
async def on_post_published(self, text: str) -> None:
"""
Вызывается при публикации поста.
Добавляет текст как положительный пример для обучения RAG.
Args:
text: Текст опубликованного поста
"""
if not text or not text.strip():
return
tasks = []
if self.rag_service and self.rag_service.is_enabled:
tasks.append(self.rag_service.add_positive_example(text))
if self.deepseek_service and self.deepseek_service.is_enabled:
tasks.append(self.deepseek_service.add_positive_example(text))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("ScoringManager: Добавлен положительный пример")
@track_time("on_post_declined", "scoring_manager")
async def on_post_declined(self, text: str) -> None:
"""
Вызывается при отклонении поста.
Добавляет текст как отрицательный пример для обучения RAG.
Args:
text: Текст отклоненного поста
"""
if not text or not text.strip():
return
tasks = []
if self.rag_service and self.rag_service.is_enabled:
tasks.append(self.rag_service.add_negative_example(text))
if self.deepseek_service and self.deepseek_service.is_enabled:
tasks.append(self.deepseek_service.add_negative_example(text))
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("ScoringManager: Добавлен отрицательный пример")
async def load_examples_from_db(
self,
positive_texts: List[str],
negative_texts: List[str],
) -> None:
"""
Загружает примеры из базы данных при запуске бота.
Args:
positive_texts: Список текстов опубликованных постов
negative_texts: Список текстов отклоненных постов
"""
if self.rag_service and self.rag_service.is_enabled:
await self.rag_service.load_examples_from_db(
positive_texts,
negative_texts
)
def save_vectors(self) -> None:
"""Сохраняет векторы RAG на диск."""
if self.rag_service:
self.rag_service.save_vectors()
async def close(self) -> None:
"""Закрывает ресурсы всех сервисов."""
if self.deepseek_service:
await self.deepseek_service.close()
# Сохраняем векторы перед закрытием
self.save_vectors()
def get_stats(self) -> dict:
"""Возвращает статистику всех сервисов."""
stats = {
"any_enabled": self.is_any_enabled,
}
if self.rag_service:
stats["rag"] = self.rag_service.get_stats()
if self.deepseek_service:
stats["deepseek"] = self.deepseek_service.get_stats()
return stats

View File

@@ -0,0 +1,399 @@
"""
In-memory хранилище векторов на numpy.
Хранит векторные представления постов для быстрого сравнения.
Поддерживает персистентность через сохранение/загрузку с диска.
"""
import hashlib
import os
from pathlib import Path
from typing import Optional, Tuple, List
import threading
import numpy as np
from logs.custom_logger import logger
from .exceptions import VectorStoreError, InsufficientExamplesError
class VectorStore:
"""
In-memory хранилище векторов для RAG.
Хранит отдельно положительные (опубликованные) и отрицательные (отклоненные)
примеры. Использует косинусное сходство для расчета скора.
Attributes:
vector_dim: Размерность векторов (768 для ruBERT)
max_examples: Максимальное количество примеров каждого типа
"""
def __init__(
self,
vector_dim: int = 768,
max_examples: int = 10000,
storage_path: Optional[str] = None,
score_multiplier: float = 5.0,
):
"""
Инициализация хранилища.
Args:
vector_dim: Размерность векторов
max_examples: Максимальное количество примеров каждого типа
storage_path: Путь для сохранения/загрузки векторов (опционально)
score_multiplier: Множитель для усиления разницы в скорах
"""
self.vector_dim = vector_dim
self.max_examples = max_examples
self.storage_path = storage_path
self.score_multiplier = score_multiplier
# Инициализируем пустые массивы
# Используем список для динамического добавления, потом конвертируем в numpy
self._positive_vectors: list = []
self._negative_vectors: list = []
self._positive_hashes: list = [] # Хеши текстов для дедупликации
self._negative_hashes: list = []
# Lock для потокобезопасности
self._lock = threading.Lock()
# Пытаемся загрузить сохраненные векторы
if storage_path and os.path.exists(storage_path):
self._load_from_disk()
@property
def positive_count(self) -> int:
"""Количество положительных примеров."""
return len(self._positive_vectors)
@property
def negative_count(self) -> int:
"""Количество отрицательных примеров."""
return len(self._negative_vectors)
@property
def total_count(self) -> int:
"""Общее количество примеров."""
return self.positive_count + self.negative_count
@staticmethod
def compute_text_hash(text: str) -> str:
"""Вычисляет хеш текста для дедупликации."""
return hashlib.md5(text.encode('utf-8')).hexdigest()
def _normalize_vector(self, vector: np.ndarray) -> np.ndarray:
"""Нормализует вектор для косинусного сходства."""
norm = np.linalg.norm(vector)
if norm == 0:
return vector
return vector / norm
def add_positive(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
"""
Добавляет положительный пример (опубликованный пост).
Args:
vector: Векторное представление текста
text_hash: Хеш текста для дедупликации (опционально)
Returns:
True если добавлен, False если дубликат или превышен лимит
"""
with self._lock:
# Проверяем дубликат по хешу
if text_hash and text_hash in self._positive_hashes:
logger.debug(f"VectorStore: Пропуск дубликата положительного примера")
return False
# Проверяем лимит
if len(self._positive_vectors) >= self.max_examples:
# Удаляем самый старый пример (FIFO)
self._positive_vectors.pop(0)
self._positive_hashes.pop(0)
logger.debug("VectorStore: Удален старый положительный пример (лимит)")
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._positive_vectors.append(normalized)
if text_hash:
self._positive_hashes.append(text_hash)
logger.info(f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})")
return True
def add_positive_batch(
self,
vectors: List[np.ndarray],
text_hashes: Optional[List[str]] = None
) -> int:
"""
Добавляет батч положительных примеров.
Args:
vectors: Список векторов
text_hashes: Список хешей текстов для дедупликации
Returns:
Количество добавленных примеров
"""
if text_hashes is None:
text_hashes = [None] * len(vectors)
added = 0
with self._lock:
for vector, text_hash in zip(vectors, text_hashes):
# Проверяем дубликат по хешу
if text_hash and text_hash in self._positive_hashes:
continue
# Проверяем лимит
if len(self._positive_vectors) >= self.max_examples:
self._positive_vectors.pop(0)
self._positive_hashes.pop(0)
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._positive_vectors.append(normalized)
if text_hash:
self._positive_hashes.append(text_hash)
added += 1
logger.info(f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})")
return added
def add_negative(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
"""
Добавляет отрицательный пример (отклоненный пост).
Args:
vector: Векторное представление текста
text_hash: Хеш текста для дедупликации (опционально)
Returns:
True если добавлен, False если дубликат или превышен лимит
"""
with self._lock:
# Проверяем дубликат по хешу
if text_hash and text_hash in self._negative_hashes:
logger.debug(f"VectorStore: Пропуск дубликата отрицательного примера")
return False
# Проверяем лимит
if len(self._negative_vectors) >= self.max_examples:
# Удаляем самый старый пример (FIFO)
self._negative_vectors.pop(0)
self._negative_hashes.pop(0)
logger.debug("VectorStore: Удален старый отрицательный пример (лимит)")
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._negative_vectors.append(normalized)
if text_hash:
self._negative_hashes.append(text_hash)
logger.info(f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})")
return True
def add_negative_batch(
self,
vectors: List[np.ndarray],
text_hashes: Optional[List[str]] = None
) -> int:
"""
Добавляет батч отрицательных примеров.
Args:
vectors: Список векторов
text_hashes: Список хешей текстов для дедупликации
Returns:
Количество добавленных примеров
"""
if text_hashes is None:
text_hashes = [None] * len(vectors)
added = 0
with self._lock:
for vector, text_hash in zip(vectors, text_hashes):
# Проверяем дубликат по хешу
if text_hash and text_hash in self._negative_hashes:
continue
# Проверяем лимит
if len(self._negative_vectors) >= self.max_examples:
self._negative_vectors.pop(0)
self._negative_hashes.pop(0)
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._negative_vectors.append(normalized)
if text_hash:
self._negative_hashes.append(text_hash)
added += 1
logger.info(f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})")
return added
def calculate_similarity_score(self, vector: np.ndarray) -> Tuple[float, float]:
"""
Рассчитывает скор на основе сходства с примерами.
Алгоритм:
1. Вычисляем среднее косинусное сходство с положительными примерами
2. Вычисляем среднее косинусное сходство с отрицательными примерами
3. Финальный скор = pos_sim / (pos_sim + neg_sim + eps)
Args:
vector: Векторное представление нового поста
Returns:
Tuple (score, confidence):
- score: Оценка от 0.0 до 1.0
- confidence: Уверенность (зависит от количества примеров)
Raises:
InsufficientExamplesError: Если недостаточно примеров
"""
with self._lock:
if self.positive_count == 0:
raise InsufficientExamplesError(
"Нет положительных примеров для сравнения"
)
# Нормализуем входной вектор
normalized = self._normalize_vector(vector)
# Конвертируем в numpy массивы для быстрых вычислений
pos_matrix = np.array(self._positive_vectors)
# Косинусное сходство с положительными примерами
# Для нормализованных векторов это просто скалярное произведение
pos_similarities = np.dot(pos_matrix, normalized)
pos_sim = float(np.mean(pos_similarities))
# Косинусное сходство с отрицательными примерами
if self.negative_count > 0:
neg_matrix = np.array(self._negative_vectors)
neg_similarities = np.dot(neg_matrix, normalized)
neg_sim = float(np.mean(neg_similarities))
else:
# Если нет отрицательных примеров, используем нейтральное значение
neg_sim = pos_sim # Нейтральный скор = 0.5
# === Вариант 1: neg/pos (разница между положительными и отрицательными) ===
diff = pos_sim - neg_sim
score_neg_pos = 0.5 + (diff * self.score_multiplier)
score_neg_pos = max(0.0, min(1.0, score_neg_pos))
# === Вариант 2: pos only (только положительные, топ-k ближайших) ===
# Берём топ-5 ближайших положительных примеров
top_k = min(5, len(pos_similarities))
top_k_sim = float(np.mean(np.sort(pos_similarities)[-top_k:]))
# Нормализуем: 0.85 -> 0.0, 0.95 -> 1.0 (типичный диапазон для BERT)
score_pos_only = (top_k_sim - 0.85) / 0.10
score_pos_only = max(0.0, min(1.0, score_pos_only))
# Основной скор — neg/pos (можно будет переключить позже)
score = score_neg_pos
# Confidence зависит от количества примеров (100% при 1000 примерах)
total_examples = self.positive_count + self.negative_count
confidence = min(1.0, total_examples / 1000)
logger.info(
f"VectorStore: pos_sim={pos_sim:.4f}, neg_sim={neg_sim:.4f}, "
f"top_k_sim={top_k_sim:.4f}, score_neg_pos={score_neg_pos:.4f}, "
f"score_pos_only={score_pos_only:.4f}"
)
return score, confidence, score_pos_only
def save_to_disk(self, path: Optional[str] = None) -> None:
"""
Сохраняет векторы на диск.
Args:
path: Путь для сохранения (если не указан, используется storage_path)
"""
save_path = path or self.storage_path
if not save_path:
raise VectorStoreError("Путь для сохранения не указан")
with self._lock:
# Создаем директорию если нужно
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# Сохраняем в npz формате
np.savez_compressed(
save_path,
positive_vectors=np.array(self._positive_vectors) if self._positive_vectors else np.array([]),
negative_vectors=np.array(self._negative_vectors) if self._negative_vectors else np.array([]),
positive_hashes=np.array(self._positive_hashes, dtype=object),
negative_hashes=np.array(self._negative_hashes, dtype=object),
vector_dim=self.vector_dim,
max_examples=self.max_examples,
)
logger.info(
f"VectorStore: Сохранено на диск ({self.positive_count} pos, "
f"{self.negative_count} neg): {save_path}"
)
def _load_from_disk(self) -> None:
"""Загружает векторы с диска."""
if not self.storage_path or not os.path.exists(self.storage_path):
return
try:
with self._lock:
data = np.load(self.storage_path, allow_pickle=True)
# Загружаем векторы
pos_vectors = data.get('positive_vectors', np.array([]))
neg_vectors = data.get('negative_vectors', np.array([]))
if pos_vectors.size > 0:
self._positive_vectors = list(pos_vectors)
if neg_vectors.size > 0:
self._negative_vectors = list(neg_vectors)
# Загружаем хеши
pos_hashes = data.get('positive_hashes', np.array([]))
neg_hashes = data.get('negative_hashes', np.array([]))
if pos_hashes.size > 0:
self._positive_hashes = list(pos_hashes)
if neg_hashes.size > 0:
self._negative_hashes = list(neg_hashes)
logger.info(
f"VectorStore: Загружено с диска ({self.positive_count} pos, "
f"{self.negative_count} neg): {self.storage_path}"
)
except Exception as e:
logger.error(f"VectorStore: Ошибка загрузки с диска: {e}")
# Продолжаем с пустым хранилищем
def clear(self) -> None:
"""Очищает все векторы."""
with self._lock:
self._positive_vectors.clear()
self._negative_vectors.clear()
self._positive_hashes.clear()
self._negative_hashes.clear()
logger.info("VectorStore: Хранилище очищено")
def get_stats(self) -> dict:
"""Возвращает статистику хранилища."""
return {
"positive_count": self.positive_count,
"negative_count": self.negative_count,
"total_count": self.total_count,
"vector_dim": self.vector_dim,
"max_examples": self.max_examples,
"storage_path": self.storage_path,
}

View File

@@ -5,6 +5,7 @@ from typing import Optional
from database.async_db import AsyncBotDB from database.async_db import AsyncBotDB
from dotenv import load_dotenv from dotenv import load_dotenv
from helper_bot.utils.s3_storage import S3StorageService from helper_bot.utils.s3_storage import S3StorageService
from logs.custom_logger import logger
class BaseDependencyFactory: class BaseDependencyFactory:
@@ -15,6 +16,7 @@ class BaseDependencyFactory:
load_dotenv(env_path) load_dotenv(env_path)
self.settings = {} self.settings = {}
self._project_dir = project_dir
database_path = os.getenv('DATABASE_PATH', 'database/tg-bot-database.db') database_path = os.getenv('DATABASE_PATH', 'database/tg-bot-database.db')
if not os.path.isabs(database_path): if not os.path.isabs(database_path):
@@ -24,6 +26,9 @@ class BaseDependencyFactory:
self._load_settings_from_env() self._load_settings_from_env()
self._init_s3_storage() self._init_s3_storage()
# ScoringManager инициализируется лениво
self._scoring_manager = None
def _load_settings_from_env(self): def _load_settings_from_env(self):
"""Загружает настройки из переменных окружения.""" """Загружает настройки из переменных окружения."""
@@ -59,6 +64,23 @@ class BaseDependencyFactory:
'bucket_name': os.getenv('S3_BUCKET_NAME', ''), 'bucket_name': os.getenv('S3_BUCKET_NAME', ''),
'region': os.getenv('S3_REGION', 'us-east-1') 'region': os.getenv('S3_REGION', 'us-east-1')
} }
# Настройки ML-скоринга
self.settings['Scoring'] = {
# RAG (ruBERT)
'rag_enabled': self._parse_bool(os.getenv('RAG_ENABLED', 'false')),
'rag_model': os.getenv('RAG_MODEL', 'DeepPavlov/rubert-base-cased'),
'rag_cache_dir': os.getenv('RAG_CACHE_DIR', 'data/models'),
'rag_vectors_path': os.getenv('RAG_VECTORS_PATH', 'data/vectors.npz'),
'rag_max_examples': self._parse_int(os.getenv('RAG_MAX_EXAMPLES', '10000')),
'rag_score_multiplier': self._parse_float(os.getenv('RAG_SCORE_MULTIPLIER', '5.0')),
# DeepSeek
'deepseek_enabled': self._parse_bool(os.getenv('DEEPSEEK_ENABLED', 'false')),
'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY', ''),
'deepseek_api_url': os.getenv('DEEPSEEK_API_URL', 'https://api.deepseek.com/v1/chat/completions'),
'deepseek_model': os.getenv('DEEPSEEK_MODEL', 'deepseek-chat'),
'deepseek_timeout': self._parse_int(os.getenv('DEEPSEEK_TIMEOUT', '30')),
}
def _init_s3_storage(self): def _init_s3_storage(self):
"""Инициализирует S3StorageService если S3 включен.""" """Инициализирует S3StorageService если S3 включен."""
@@ -84,6 +106,13 @@ class BaseDependencyFactory:
return int(value) return int(value)
except (ValueError, TypeError): except (ValueError, TypeError):
return 0 return 0
def _parse_float(self, value: str) -> float:
"""Парсит строковое значение в float."""
try:
return float(value)
except (ValueError, TypeError):
return 0.0
def get_settings(self): def get_settings(self):
return self.settings return self.settings
@@ -95,6 +124,100 @@ class BaseDependencyFactory:
def get_s3_storage(self) -> Optional[S3StorageService]: def get_s3_storage(self) -> Optional[S3StorageService]:
"""Возвращает S3StorageService если S3 включен, иначе None.""" """Возвращает S3StorageService если S3 включен, иначе None."""
return self.s3_storage return self.s3_storage
def _init_scoring_manager(self):
"""
Инициализирует ScoringManager с RAG и DeepSeek сервисами.
Вызывается лениво при первом обращении к get_scoring_manager().
"""
from helper_bot.services.scoring import (
ScoringManager,
RAGService,
DeepSeekService,
VectorStore,
)
scoring_config = self.settings['Scoring']
# Инициализация RAG сервиса
rag_service = None
if scoring_config['rag_enabled']:
# Путь к векторам
vectors_path = scoring_config['rag_vectors_path']
if not os.path.isabs(vectors_path):
vectors_path = os.path.join(self._project_dir, vectors_path)
# Путь к кешу моделей
cache_dir = scoring_config['rag_cache_dir']
if not os.path.isabs(cache_dir):
cache_dir = os.path.join(self._project_dir, cache_dir)
# Создаем директории если нужно
os.makedirs(os.path.dirname(vectors_path), exist_ok=True)
os.makedirs(cache_dir, exist_ok=True)
# Создаем VectorStore
vector_store = VectorStore(
vector_dim=768, # ruBERT dimension
max_examples=scoring_config['rag_max_examples'],
storage_path=vectors_path,
score_multiplier=scoring_config['rag_score_multiplier'],
)
# Создаем RAGService
rag_service = RAGService(
model_name=scoring_config['rag_model'],
vector_store=vector_store,
cache_dir=cache_dir,
enabled=True,
)
logger.info(f"RAGService инициализирован: {scoring_config['rag_model']}")
# Инициализация DeepSeek сервиса
deepseek_service = None
if scoring_config['deepseek_enabled'] and scoring_config['deepseek_api_key']:
deepseek_service = DeepSeekService(
api_key=scoring_config['deepseek_api_key'],
api_url=scoring_config['deepseek_api_url'],
model=scoring_config['deepseek_model'],
timeout=scoring_config['deepseek_timeout'],
enabled=True,
)
logger.info(f"DeepSeekService инициализирован: {scoring_config['deepseek_model']}")
# Создаем менеджер
self._scoring_manager = ScoringManager(
rag_service=rag_service,
deepseek_service=deepseek_service,
)
return self._scoring_manager
def get_scoring_manager(self):
"""
Возвращает ScoringManager для ML-скоринга постов.
Инициализируется лениво при первом вызове.
Returns:
ScoringManager или None если скоринг полностью отключен
"""
if self._scoring_manager is None:
scoring_config = self.settings.get('Scoring', {})
# Проверяем, включен ли хотя бы один сервис
rag_enabled = scoring_config.get('rag_enabled', False)
deepseek_enabled = scoring_config.get('deepseek_enabled', False)
if not rag_enabled and not deepseek_enabled:
logger.info("Scoring полностью отключен (RAG и DeepSeek disabled)")
return None
self._init_scoring_manager()
return self._scoring_manager
_global_instance = None _global_instance = None

View File

@@ -111,7 +111,16 @@ def determine_anonymity(post_text: str) -> bool:
return False return False
def get_text_message(post_text: str, first_name: str, username: str = None, is_anonymous: Optional[bool] = None): def get_text_message(
post_text: str,
first_name: str,
username: str = None,
is_anonymous: Optional[bool] = None,
deepseek_score: Optional[float] = None,
rag_score: Optional[float] = None,
rag_confidence: Optional[float] = None,
rag_score_pos_only: Optional[float] = None,
):
""" """
Форматирует текст сообщения для публикации в зависимости от наличия ключевых слов "анон" и "неанон" Форматирует текст сообщения для публикации в зависимости от наличия ключевых слов "анон" и "неанон"
или переданного параметра is_anonymous. или переданного параметра is_anonymous.
@@ -121,6 +130,10 @@ def get_text_message(post_text: str, first_name: str, username: str = None, is_a
first_name: Имя автора поста first_name: Имя автора поста
username: Юзернейм автора поста (может быть None) username: Юзернейм автора поста (может быть None)
is_anonymous: Флаг анонимности (True - анонимно, False - не анонимно, None - legacy, определяется по тексту) is_anonymous: Флаг анонимности (True - анонимно, False - не анонимно, None - legacy, определяется по тексту)
deepseek_score: Скор от DeepSeek API (0.0-1.0, опционально)
rag_score: Скор от RAG/ruBERT neg/pos (0.0-1.0, опционально)
rag_confidence: Уверенность RAG модели (0.0-1.0, зависит от количества примеров)
rag_score_pos_only: Скор RAG только по положительным примерам (0.0-1.0, опционально)
Returns: Returns:
str: - Сформированный текст сообщения. str: - Сформированный текст сообщения.
@@ -137,21 +150,37 @@ def get_text_message(post_text: str, first_name: str, username: str = None, is_a
else: else:
author_info = f"{first_name} (Ник не указан)" author_info = f"{first_name} (Ник не указан)"
# Формируем базовый текст
# Если передан is_anonymous, используем его, иначе определяем по тексту (legacy) # Если передан is_anonymous, используем его, иначе определяем по тексту (legacy)
# TODO: Уверен можно укоротить
if is_anonymous is not None: if is_anonymous is not None:
if is_anonymous: if is_anonymous:
return f'{safe_post_text}\n\nПост опубликован анонимно' final_text = f'{safe_post_text}\n\nПост опубликован анонимно'
else: else:
return f'{safe_post_text}\n\nАвтор поста: {author_info}' final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}'
else: else:
# Legacy: определяем по тексту # Legacy: определяем по тексту
if "неанон" in post_text or "не анон" in post_text: if "неанон" in post_text or "не анон" in post_text:
return f'{safe_post_text}\n\nАвтор поста: {author_info}' final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}'
elif "анон" in post_text: elif "анон" in post_text:
return f'{safe_post_text}\n\nПост опубликован анонимно' final_text = f'{safe_post_text}\n\nПост опубликован анонимно'
else: else:
return f'{safe_post_text}\n\nАвтор поста: {author_info}' final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}'
# Добавляем блок со скорами если есть
if deepseek_score is not None or rag_score is not None or rag_score_pos_only is not None:
scores_lines = ["\n📊 Уверенность в одобрении:"]
if deepseek_score is not None:
scores_lines.append(f"DeepSeek: {deepseek_score:.2f}")
if rag_score is not None:
rag_line = f"RAG neg/pos: {rag_score:.2f}"
if rag_confidence is not None:
rag_line += f" (уверенность: {rag_confidence:.0%})"
scores_lines.append(rag_line)
if rag_score_pos_only is not None:
scores_lines.append(f"RAG pos only: {rag_score_pos_only:.2f}")
final_text += "\n" + "\n".join(scores_lines)
return final_text
@track_time("download_file", "helper_func") @track_time("download_file", "helper_func")
@track_errors("helper_func", "download_file") @track_errors("helper_func", "download_file")

View File

@@ -30,4 +30,10 @@ typing_extensions~=4.12.2
emoji~=2.8.0 emoji~=2.8.0
# S3 Storage (для хранения медиафайлов опубликованных постов) # S3 Storage (для хранения медиафайлов опубликованных постов)
aioboto3>=12.0.0 aioboto3>=12.0.0
# ML Scoring (для оценки вероятности публикации постов)
numpy>=1.24.0
transformers>=4.30.0
torch>=2.0.0
httpx>=0.24.0

View File

@@ -0,0 +1,93 @@
#!/usr/bin/env python3
"""
Миграция: Добавление колонок для ML-скоринга постов.
Добавляет:
- ml_scores (TEXT/JSON) - JSON с результатами оценки от разных моделей
- vector_hash (TEXT) - хеш текста для кеширования векторов
Структура ml_scores:
{
"deepseek": {"score": 0.75, "model": "deepseek-chat", "ts": 1706198400},
"rag": {"score": 0.90, "model": "rubert-base-cased", "ts": 1706198400}
}
"""
import argparse
import asyncio
import os
import sys
from pathlib import Path
# Добавляем корень проекта в путь
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
import aiosqlite
# Пытаемся импортировать logger, если не получается - используем стандартный
try:
from logs.custom_logger import logger
except ImportError:
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
DEFAULT_DB_PATH = "database/tg-bot-database.db"
async def column_exists(conn: aiosqlite.Connection, table: str, column: str) -> bool:
"""Проверяет существование колонки в таблице."""
cursor = await conn.execute(f"PRAGMA table_info({table})")
columns = await cursor.fetchall()
return any(col[1] == column for col in columns)
async def main(db_path: str) -> None:
"""
Основная функция миграции.
Добавляет колонки ml_scores и vector_hash в таблицу post_from_telegram_suggest.
Миграция идемпотентна - можно запускать повторно без ошибок.
"""
db_path = os.path.abspath(db_path)
if not os.path.exists(db_path):
logger.error(f"База данных не найдена: {db_path}")
return
async with aiosqlite.connect(db_path) as conn:
await conn.execute("PRAGMA foreign_keys = ON")
# Проверяем и добавляем колонку ml_scores
if not await column_exists(conn, "post_from_telegram_suggest", "ml_scores"):
await conn.execute(
"ALTER TABLE post_from_telegram_suggest ADD COLUMN ml_scores TEXT"
)
logger.info("Колонка ml_scores добавлена в post_from_telegram_suggest")
else:
logger.info("Колонка ml_scores уже существует")
# Проверяем и добавляем колонку vector_hash
if not await column_exists(conn, "post_from_telegram_suggest", "vector_hash"):
await conn.execute(
"ALTER TABLE post_from_telegram_suggest ADD COLUMN vector_hash TEXT"
)
logger.info("Колонка vector_hash добавлена в post_from_telegram_suggest")
else:
logger.info("Колонка vector_hash уже существует")
await conn.commit()
logger.info("Миграция add_ml_scores_columns завершена успешно")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Добавление колонок ml_scores и vector_hash для ML-скоринга"
)
parser.add_argument(
"--db",
default=os.environ.get("DATABASE_PATH", DEFAULT_DB_PATH),
help="Путь к БД",
)
args = parser.parse_args()
asyncio.run(main(args.db))

View File

@@ -0,0 +1,390 @@
"""
Тесты для сервисов ML-скоринга постов.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
# Импорты для тестирования базовых классов
from helper_bot.services.scoring.base import ScoringResult, CombinedScore
from helper_bot.services.scoring.exceptions import (
ScoringError,
InsufficientExamplesError,
TextTooShortError,
)
class TestScoringResult:
"""Тесты для ScoringResult."""
def test_create_valid_score(self):
"""Тест создания валидного результата."""
result = ScoringResult(
score=0.75,
source="rag",
model="test-model",
)
assert result.score == 0.75
assert result.source == "rag"
assert result.model == "test-model"
def test_score_validation_lower_bound(self):
"""Тест валидации нижней границы скора."""
with pytest.raises(ValueError):
ScoringResult(score=-0.1, source="test", model="test")
def test_score_validation_upper_bound(self):
"""Тест валидации верхней границы скора."""
with pytest.raises(ValueError):
ScoringResult(score=1.1, source="test", model="test")
def test_to_dict(self):
"""Тест преобразования в словарь."""
result = ScoringResult(
score=0.7534,
source="rag",
model="test-model",
confidence=0.85,
timestamp=1234567890,
)
d = result.to_dict()
assert d["score"] == 0.7534 # Округлено до 4 знаков
assert d["model"] == "test-model"
assert d["ts"] == 1234567890
assert d["confidence"] == 0.85
def test_from_dict(self):
"""Тест создания из словаря."""
data = {
"score": 0.75,
"model": "test-model",
"ts": 1234567890,
"confidence": 0.9,
}
result = ScoringResult.from_dict("rag", data)
assert result.score == 0.75
assert result.source == "rag"
assert result.model == "test-model"
assert result.timestamp == 1234567890
assert result.confidence == 0.9
class TestCombinedScore:
"""Тесты для CombinedScore."""
def test_empty_combined_score(self):
"""Тест пустого объединенного скора."""
score = CombinedScore()
assert score.deepseek is None
assert score.rag is None
assert score.deepseek_score is None
assert score.rag_score is None
assert not score.has_any_score()
def test_combined_score_with_rag(self):
"""Тест объединенного скора с RAG."""
rag_result = ScoringResult(score=0.8, source="rag", model="rubert")
score = CombinedScore(rag=rag_result)
assert score.rag_score == 0.8
assert score.deepseek_score is None
assert score.has_any_score()
def test_combined_score_with_both(self):
"""Тест объединенного скора с обоими сервисами."""
rag_result = ScoringResult(score=0.8, source="rag", model="rubert")
deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat")
score = CombinedScore(rag=rag_result, deepseek=deepseek_result)
assert score.rag_score == 0.8
assert score.deepseek_score == 0.7
assert score.has_any_score()
def test_to_json_dict(self):
"""Тест преобразования в JSON словарь."""
rag_result = ScoringResult(score=0.8, source="rag", model="rubert", timestamp=123)
deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat", timestamp=456)
score = CombinedScore(rag=rag_result, deepseek=deepseek_result)
d = score.to_json_dict()
assert "rag" in d
assert "deepseek" in d
assert d["rag"]["score"] == 0.8
assert d["deepseek"]["score"] == 0.7
# Проверяем что это валидный JSON
json_str = json.dumps(d)
assert json_str
class TestVectorStore:
"""Тесты для VectorStore (требует numpy)."""
@pytest.fixture
def vector_store(self):
"""Создает VectorStore для тестов."""
try:
import numpy as np
from helper_bot.services.scoring.vector_store import VectorStore
return VectorStore(vector_dim=768, max_examples=100)
except ImportError:
pytest.skip("numpy не установлен")
def test_add_positive_example(self, vector_store):
"""Тест добавления положительного примера."""
import numpy as np
vector = np.random.randn(768).astype(np.float32)
result = vector_store.add_positive(vector, "hash1")
assert result is True
assert vector_store.positive_count == 1
def test_add_duplicate_example(self, vector_store):
"""Тест добавления дубликата."""
import numpy as np
vector = np.random.randn(768).astype(np.float32)
vector_store.add_positive(vector, "hash1")
result = vector_store.add_positive(vector, "hash1") # Дубликат
assert result is False
assert vector_store.positive_count == 1
def test_max_examples_limit(self, vector_store):
"""Тест ограничения максимального количества примеров."""
import numpy as np
# Добавляем больше чем max_examples
for i in range(150):
vector = np.random.randn(768).astype(np.float32)
vector_store.add_positive(vector, f"hash_{i}")
assert vector_store.positive_count == 100 # max_examples
def test_calculate_similarity_no_examples(self, vector_store):
"""Тест расчета скора без примеров."""
import numpy as np
vector = np.random.randn(768).astype(np.float32)
with pytest.raises(InsufficientExamplesError):
vector_store.calculate_similarity_score(vector)
def test_calculate_similarity_with_examples(self, vector_store):
"""Тест расчета скора с примерами."""
import numpy as np
# Добавляем положительные примеры
for i in range(10):
vector = np.random.randn(768).astype(np.float32)
vector_store.add_positive(vector, f"pos_{i}")
# Добавляем отрицательные примеры
for i in range(10):
vector = np.random.randn(768).astype(np.float32)
vector_store.add_negative(vector, f"neg_{i}")
# Рассчитываем скор для нового вектора
test_vector = np.random.randn(768).astype(np.float32)
score, confidence = vector_store.calculate_similarity_score(test_vector)
assert 0.0 <= score <= 1.0
assert 0.0 <= confidence <= 1.0
def test_compute_text_hash(self, vector_store):
"""Тест вычисления хеша текста."""
from helper_bot.services.scoring.vector_store import VectorStore
hash1 = VectorStore.compute_text_hash("Привет мир")
hash2 = VectorStore.compute_text_hash("Привет мир")
hash3 = VectorStore.compute_text_hash("Другой текст")
assert hash1 == hash2
assert hash1 != hash3
class TestDeepSeekService:
"""Тесты для DeepSeekService."""
@pytest.fixture
def deepseek_service(self):
"""Создает DeepSeekService для тестов."""
from helper_bot.services.scoring.deepseek_service import DeepSeekService
return DeepSeekService(
api_key="test_key",
enabled=True,
timeout=5,
)
def test_service_disabled_without_key(self):
"""Тест отключения сервиса без API ключа."""
from helper_bot.services.scoring.deepseek_service import DeepSeekService
service = DeepSeekService(api_key=None, enabled=True)
assert service.is_enabled is False
def test_parse_score_response_valid(self, deepseek_service):
"""Тест парсинга валидного ответа."""
assert deepseek_service._parse_score_response("0.75") == 0.75
assert deepseek_service._parse_score_response("0.5") == 0.5
assert deepseek_service._parse_score_response("1.0") == 1.0
assert deepseek_service._parse_score_response("0") == 0.0
def test_parse_score_response_with_quotes(self, deepseek_service):
"""Тест парсинга ответа с кавычками."""
assert deepseek_service._parse_score_response('"0.75"') == 0.75
assert deepseek_service._parse_score_response("'0.8'") == 0.8
def test_parse_score_response_with_text(self, deepseek_service):
"""Тест парсинга ответа с текстом."""
# Сервис должен найти число в тексте
assert deepseek_service._parse_score_response("Score: 0.75") == 0.75
def test_clean_text(self, deepseek_service):
"""Тест очистки текста."""
assert deepseek_service._clean_text(" hello world ") == "hello world"
assert deepseek_service._clean_text("^") == ""
assert deepseek_service._clean_text("") == ""
@pytest.mark.asyncio
async def test_calculate_score_disabled(self):
"""Тест расчета скора при отключенном сервисе."""
from helper_bot.services.scoring.deepseek_service import DeepSeekService
service = DeepSeekService(api_key=None, enabled=False)
with pytest.raises(ScoringError):
await service.calculate_score("Test text")
@pytest.mark.asyncio
async def test_calculate_score_short_text(self, deepseek_service):
"""Тест расчета скора для короткого текста."""
with pytest.raises(TextTooShortError):
await deepseek_service.calculate_score("ab")
class TestScoringManager:
"""Тесты для ScoringManager."""
@pytest.fixture
def mock_rag_service(self):
"""Создает мок RAG сервиса."""
mock = AsyncMock()
mock.is_enabled = True
mock.calculate_score = AsyncMock(return_value=ScoringResult(
score=0.8,
source="rag",
model="rubert",
))
return mock
@pytest.fixture
def mock_deepseek_service(self):
"""Создает мок DeepSeek сервиса."""
mock = AsyncMock()
mock.is_enabled = True
mock.calculate_score = AsyncMock(return_value=ScoringResult(
score=0.7,
source="deepseek",
model="deepseek-chat",
))
return mock
@pytest.mark.asyncio
async def test_score_post_both_services(self, mock_rag_service, mock_deepseek_service):
"""Тест скоринга с обоими сервисами."""
from helper_bot.services.scoring.scoring_manager import ScoringManager
manager = ScoringManager(
rag_service=mock_rag_service,
deepseek_service=mock_deepseek_service,
)
result = await manager.score_post("Тестовый пост")
assert result.rag_score == 0.8
assert result.deepseek_score == 0.7
assert result.has_any_score()
@pytest.mark.asyncio
async def test_score_post_rag_only(self, mock_rag_service):
"""Тест скоринга только с RAG."""
from helper_bot.services.scoring.scoring_manager import ScoringManager
manager = ScoringManager(
rag_service=mock_rag_service,
deepseek_service=None,
)
result = await manager.score_post("Тестовый пост")
assert result.rag_score == 0.8
assert result.deepseek_score is None
@pytest.mark.asyncio
async def test_score_post_empty_text(self, mock_rag_service):
"""Тест скоринга пустого текста."""
from helper_bot.services.scoring.scoring_manager import ScoringManager
manager = ScoringManager(rag_service=mock_rag_service)
result = await manager.score_post("")
assert not result.has_any_score()
mock_rag_service.calculate_score.assert_not_called()
@pytest.mark.asyncio
async def test_score_post_service_error(self, mock_rag_service, mock_deepseek_service):
"""Тест обработки ошибки сервиса."""
from helper_bot.services.scoring.scoring_manager import ScoringManager
# RAG выбрасывает ошибку
mock_rag_service.calculate_score = AsyncMock(side_effect=Exception("Test error"))
manager = ScoringManager(
rag_service=mock_rag_service,
deepseek_service=mock_deepseek_service,
)
result = await manager.score_post("Тестовый пост")
# DeepSeek должен вернуть результат
assert result.deepseek_score == 0.7
# RAG должен быть None с ошибкой
assert result.rag_score is None
assert "rag" in result.errors
@pytest.mark.asyncio
async def test_on_post_published(self, mock_rag_service, mock_deepseek_service):
"""Тест обучения на опубликованном посте."""
from helper_bot.services.scoring.scoring_manager import ScoringManager
manager = ScoringManager(
rag_service=mock_rag_service,
deepseek_service=mock_deepseek_service,
)
await manager.on_post_published("Опубликованный пост")
mock_rag_service.add_positive_example.assert_called_once_with("Опубликованный пост")
mock_deepseek_service.add_positive_example.assert_called_once_with("Опубликованный пост")
@pytest.mark.asyncio
async def test_on_post_declined(self, mock_rag_service, mock_deepseek_service):
"""Тест обучения на отклоненном посте."""
from helper_bot.services.scoring.scoring_manager import ScoringManager
manager = ScoringManager(
rag_service=mock_rag_service,
deepseek_service=mock_deepseek_service,
)
await manager.on_post_declined("Отклоненный пост")
mock_rag_service.add_negative_example.assert_called_once_with("Отклоненный пост")
mock_deepseek_service.add_negative_example.assert_called_once_with("Отклоненный пост")