feat: интеграция ML-скоринга с использованием RAG и DeepSeek
- Обновлен Dockerfile для установки необходимых зависимостей. - Добавлены новые переменные окружения для настройки ML-скоринга в env.example. - Реализованы методы для получения и обновления ML-скоров в AsyncBotDB и PostRepository. - Обновлены обработчики публикации постов для интеграции ML-скоринга. - Добавлен новый обработчик для получения статистики ML-скоринга в админ-панели. - Обновлены функции для форматирования сообщений с учетом ML-скоров.
This commit is contained in:
5
.github/workflows/deploy.yml
vendored
5
.github/workflows/deploy.yml
vendored
@@ -165,9 +165,8 @@ jobs:
|
||||
|
||||
📦 Repository: telegram-helper-bot
|
||||
🌿 Branch: main
|
||||
📝 Commit: ${{ github.event.pull_request.merge_commit_sha || github.sha }}
|
||||
👤 Author: ${{ github.event.pull_request.user.login || github.actor }}
|
||||
${{ github.event.pull_request.number && format('🔀 PR: #{0}', github.event.pull_request.number) || '' }}
|
||||
📝 Commit: ${{ github.sha }}
|
||||
👤 Author: ${{ github.actor }}
|
||||
|
||||
${{ job.status == 'success' && '✅ Deployment successful! Container restarted with migrations applied.' || '❌ Deployment failed! Check logs for details.' }}
|
||||
|
||||
|
||||
36
Dockerfile
36
Dockerfile
@@ -1,15 +1,14 @@
|
||||
###########################################
|
||||
# Этап 1: Сборщик (Builder)
|
||||
###########################################
|
||||
FROM python:3.11.9-alpine as builder
|
||||
FROM python:3.11.9-slim as builder
|
||||
|
||||
# Устанавливаем инструменты для компиляции + linux-headers для psutil
|
||||
RUN apk add --no-cache \
|
||||
# Устанавливаем инструменты для компиляции
|
||||
RUN apt-get update && apt-get install --no-install-recommends -y \
|
||||
gcc \
|
||||
g++ \
|
||||
musl-dev \
|
||||
python3-dev \
|
||||
linux-headers # ← ЭТО КРИТИЧЕСКИ ВАЖНО ДЛЯ psutil
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
COPY requirements.txt .
|
||||
@@ -21,29 +20,34 @@ RUN pip install --no-cache-dir --target /install -r requirements.txt
|
||||
###########################################
|
||||
# Этап 2: Финальный образ (Runtime)
|
||||
###########################################
|
||||
FROM python:3.11.9-alpine as runtime
|
||||
FROM python:3.11.9-slim as runtime
|
||||
|
||||
# Минимальные рантайм-зависимости
|
||||
RUN apk add --no-cache \
|
||||
libstdc++ \
|
||||
sqlite-libs
|
||||
RUN apt-get update && apt-get install --no-install-recommends -y \
|
||||
libgomp1 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Создаем пользователя
|
||||
RUN addgroup -g 1001 deploy && adduser -D -u 1001 -G deploy deploy
|
||||
RUN groupadd -g 1001 deploy && useradd -r -u 1001 -g deploy deploy
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Копируем зависимости
|
||||
COPY --from=builder --chown=1001:1001 /install /usr/local/lib/python3.11/site-packages
|
||||
COPY --from=builder --chown=deploy:deploy /install /usr/local/lib/python3.11/site-packages
|
||||
|
||||
# Создаем структуру папок
|
||||
RUN mkdir -p database logs voice_users && \
|
||||
chown -R 1001:1001 /app
|
||||
# Создаем структуру папок (включая директории для ML моделей)
|
||||
RUN mkdir -p database logs voice_users data/models && \
|
||||
chown -R deploy:deploy /app
|
||||
|
||||
# Устанавливаем переменные для HuggingFace (кеш моделей внутри /app)
|
||||
ENV HF_HOME=/app/data/models
|
||||
ENV TRANSFORMERS_CACHE=/app/data/models
|
||||
ENV HF_HUB_CACHE=/app/data/models
|
||||
|
||||
# Копируем исходный код
|
||||
COPY --chown=1001:1001 . .
|
||||
COPY --chown=deploy:deploy . .
|
||||
|
||||
USER 1001
|
||||
USER deploy
|
||||
|
||||
# Healthcheck
|
||||
HEALTHCHECK --interval=30s --timeout=15s --start-period=10s --retries=5 \
|
||||
|
||||
@@ -211,6 +211,23 @@ class AsyncBotDB:
|
||||
helper_message_id, status
|
||||
)
|
||||
|
||||
# Методы для ML Scoring
|
||||
async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]:
|
||||
"""Получает текст поста по message_id."""
|
||||
return await self.factory.posts.get_post_text_by_message_id(message_id)
|
||||
|
||||
async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool:
|
||||
"""Обновляет ML-скоры для поста."""
|
||||
return await self.factory.posts.update_ml_scores(message_id, ml_scores_json)
|
||||
|
||||
async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]:
|
||||
"""Получает тексты одобренных постов для обучения RAG."""
|
||||
return await self.factory.posts.get_approved_posts_texts(limit)
|
||||
|
||||
async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]:
|
||||
"""Получает тексты отклоненных постов для обучения RAG."""
|
||||
return await self.factory.posts.get_declined_posts_texts(limit)
|
||||
|
||||
# Методы для работы с черным списком
|
||||
async def set_user_blacklist(
|
||||
self,
|
||||
|
||||
@@ -357,3 +357,126 @@ class PostRepository(DatabaseConnection):
|
||||
post_content = await self._execute_query_with_result(query, (published_message_id,))
|
||||
self.logger.info(f"Получен контент опубликованного поста: {len(post_content)} элементов для published_message_id={published_message_id}")
|
||||
return post_content
|
||||
|
||||
# ============================================
|
||||
# Методы для работы с ML-скорингом
|
||||
# ============================================
|
||||
|
||||
async def update_ml_scores(self, message_id: int, ml_scores_json: str) -> bool:
|
||||
"""
|
||||
Обновляет ML-скоры для поста.
|
||||
|
||||
Args:
|
||||
message_id: ID сообщения в группе модерации
|
||||
ml_scores_json: JSON строка со скорами
|
||||
|
||||
Returns:
|
||||
True если обновлено успешно
|
||||
"""
|
||||
try:
|
||||
query = "UPDATE post_from_telegram_suggest SET ml_scores = ? WHERE message_id = ?"
|
||||
await self._execute_query(query, (ml_scores_json, message_id))
|
||||
self.logger.info(f"ML-скоры обновлены для message_id={message_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Ошибка обновления ML-скоров для message_id={message_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_ml_scores_by_message_id(self, message_id: int) -> Optional[str]:
|
||||
"""
|
||||
Получает ML-скоры для поста.
|
||||
|
||||
Args:
|
||||
message_id: ID сообщения
|
||||
|
||||
Returns:
|
||||
JSON строка со скорами или None
|
||||
"""
|
||||
query = "SELECT ml_scores FROM post_from_telegram_suggest WHERE message_id = ?"
|
||||
rows = await self._execute_query_with_result(query, (message_id,))
|
||||
if rows and rows[0][0]:
|
||||
return rows[0][0]
|
||||
return None
|
||||
|
||||
async def get_post_text_by_message_id(self, message_id: int) -> Optional[str]:
|
||||
"""
|
||||
Получает текст поста по message_id.
|
||||
|
||||
Args:
|
||||
message_id: ID сообщения
|
||||
|
||||
Returns:
|
||||
Текст поста или None
|
||||
"""
|
||||
query = "SELECT text FROM post_from_telegram_suggest WHERE message_id = ?"
|
||||
rows = await self._execute_query_with_result(query, (message_id,))
|
||||
if rows and rows[0][0]:
|
||||
return rows[0][0]
|
||||
return None
|
||||
|
||||
async def get_approved_posts_texts(self, limit: int = 1000) -> List[str]:
|
||||
"""
|
||||
Получает тексты опубликованных постов для обучения RAG.
|
||||
|
||||
Args:
|
||||
limit: Максимальное количество постов
|
||||
|
||||
Returns:
|
||||
Список текстов
|
||||
"""
|
||||
query = """
|
||||
SELECT text FROM post_from_telegram_suggest
|
||||
WHERE status = 'approved'
|
||||
AND text IS NOT NULL
|
||||
AND text != ''
|
||||
AND text != '^'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
rows = await self._execute_query_with_result(query, (limit,))
|
||||
texts = [row[0] for row in rows if row[0]]
|
||||
self.logger.info(f"Получено {len(texts)} опубликованных постов для обучения")
|
||||
return texts
|
||||
|
||||
async def get_declined_posts_texts(self, limit: int = 1000) -> List[str]:
|
||||
"""
|
||||
Получает тексты отклоненных постов для обучения RAG.
|
||||
|
||||
Args:
|
||||
limit: Максимальное количество постов
|
||||
|
||||
Returns:
|
||||
Список текстов
|
||||
"""
|
||||
query = """
|
||||
SELECT text FROM post_from_telegram_suggest
|
||||
WHERE status = 'declined'
|
||||
AND text IS NOT NULL
|
||||
AND text != ''
|
||||
AND text != '^'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
rows = await self._execute_query_with_result(query, (limit,))
|
||||
texts = [row[0] for row in rows if row[0]]
|
||||
self.logger.info(f"Получено {len(texts)} отклоненных постов для обучения")
|
||||
return texts
|
||||
|
||||
async def update_vector_hash(self, message_id: int, vector_hash: str) -> bool:
|
||||
"""
|
||||
Обновляет хеш вектора для поста (для кеширования).
|
||||
|
||||
Args:
|
||||
message_id: ID сообщения
|
||||
vector_hash: Хеш вектора
|
||||
|
||||
Returns:
|
||||
True если обновлено успешно
|
||||
"""
|
||||
try:
|
||||
query = "UPDATE post_from_telegram_suggest SET vector_hash = ? WHERE message_id = ?"
|
||||
await self._execute_query(query, (vector_hash, message_id))
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Ошибка обновления vector_hash для message_id={message_id}: {e}")
|
||||
return False
|
||||
|
||||
17
env.example
17
env.example
@@ -35,3 +35,20 @@ METRICS_PORT=8080
|
||||
# Logging
|
||||
LOG_LEVEL=INFO
|
||||
LOG_RETENTION_DAYS=30
|
||||
|
||||
# ML Scoring - RAG (ruBERT)
|
||||
# Включает локальное векторное сравнение с использованием ruBERT
|
||||
RAG_ENABLED=false
|
||||
RAG_MODEL=DeepPavlov/rubert-base-cased
|
||||
RAG_CACHE_DIR=data/models
|
||||
RAG_VECTORS_PATH=data/vectors.npz
|
||||
RAG_MAX_EXAMPLES=10000
|
||||
RAG_SCORE_MULTIPLIER=5
|
||||
|
||||
# ML Scoring - DeepSeek API
|
||||
# Включает оценку постов через DeepSeek API
|
||||
DEEPSEEK_ENABLED=false
|
||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
||||
DEEPSEEK_API_URL=https://api.deepseek.com/v1/chat/completions
|
||||
DEEPSEEK_MODEL=deepseek-chat
|
||||
DEEPSEEK_TIMEOUT=30
|
||||
|
||||
@@ -16,6 +16,7 @@ from helper_bot.keyboards.keyboards import (create_keyboard_for_approve_ban,
|
||||
create_keyboard_for_ban_reason,
|
||||
create_keyboard_with_pagination,
|
||||
get_reply_keyboard_admin)
|
||||
from helper_bot.utils.base_dependency_factory import get_global_instance
|
||||
# Local imports - metrics
|
||||
from helper_bot.utils.metrics import db_query_time, track_errors, track_time
|
||||
from logs.custom_logger import logger
|
||||
@@ -137,6 +138,69 @@ async def get_banned_users(
|
||||
await handle_admin_error(message, e, state, "get_banned_users")
|
||||
|
||||
|
||||
@admin_router.message(
|
||||
ChatTypeFilter(chat_type=["private"]),
|
||||
StateFilter("ADMIN"),
|
||||
F.text == '📊 ML Статистика'
|
||||
)
|
||||
@track_time("get_ml_stats", "admin_handlers")
|
||||
@track_errors("admin_handlers", "get_ml_stats")
|
||||
async def get_ml_stats(
|
||||
message: types.Message,
|
||||
state: FSMContext,
|
||||
**kwargs
|
||||
):
|
||||
"""Получение статистики ML-скоринга"""
|
||||
try:
|
||||
logger.info(f"Запрос ML статистики от пользователя: {message.from_user.full_name}")
|
||||
|
||||
bdf = get_global_instance()
|
||||
scoring_manager = bdf.get_scoring_manager()
|
||||
|
||||
if not scoring_manager:
|
||||
await message.answer("📊 ML Scoring отключен\n\nДля включения установите RAG_ENABLED=true или DEEPSEEK_ENABLED=true в .env")
|
||||
return
|
||||
|
||||
stats = scoring_manager.get_stats()
|
||||
|
||||
# Формируем текст статистики
|
||||
lines = ["📊 <b>ML Scoring Статистика</b>\n"]
|
||||
|
||||
# RAG статистика
|
||||
if "rag" in stats:
|
||||
rag = stats["rag"]
|
||||
lines.append("🤖 <b>RAG (ruBERT):</b>")
|
||||
lines.append(f" • Статус: {'✅ Включен' if rag.get('enabled') else '❌ Отключен'}")
|
||||
lines.append(f" • Модель: {rag.get('model_name', 'N/A')}")
|
||||
lines.append(f" • Модель загружена: {'✅' if rag.get('model_loaded') else '❌'}")
|
||||
|
||||
vs = rag.get("vector_store", {})
|
||||
lines.append(f" • Положительных примеров: {vs.get('positive_count', 0)}")
|
||||
lines.append(f" • Отрицательных примеров: {vs.get('negative_count', 0)}")
|
||||
lines.append(f" • Всего примеров: {vs.get('total_count', 0)}")
|
||||
lines.append(f" • Макс. примеров: {vs.get('max_examples', 'N/A')}")
|
||||
lines.append("")
|
||||
|
||||
# DeepSeek статистика
|
||||
if "deepseek" in stats:
|
||||
ds = stats["deepseek"]
|
||||
lines.append("🔮 <b>DeepSeek API:</b>")
|
||||
lines.append(f" • Статус: {'✅ Включен' if ds.get('enabled') else '❌ Отключен'}")
|
||||
lines.append(f" • Модель: {ds.get('model', 'N/A')}")
|
||||
lines.append(f" • Таймаут: {ds.get('timeout', 'N/A')}с")
|
||||
lines.append("")
|
||||
|
||||
# Если ничего не включено
|
||||
if "rag" not in stats and "deepseek" not in stats:
|
||||
lines.append("⚠️ Ни один сервис не настроен")
|
||||
|
||||
await message.answer("\n".join(lines), parse_mode="HTML")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения ML статистики: {e}")
|
||||
await message.answer(f"❌ Ошибка получения статистики: {str(e)}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ХЕНДЛЕРЫ ПРОЦЕССА БАНА
|
||||
# ============================================================================
|
||||
|
||||
@@ -15,7 +15,8 @@ def get_post_publish_service() -> PostPublishService:
|
||||
db = bdf.get_db()
|
||||
settings = bdf.settings
|
||||
s3_storage = bdf.get_s3_storage()
|
||||
return PostPublishService(None, db, settings, s3_storage)
|
||||
scoring_manager = bdf.get_scoring_manager()
|
||||
return PostPublishService(None, db, settings, s3_storage, scoring_manager)
|
||||
|
||||
|
||||
def get_ban_service() -> BanService:
|
||||
|
||||
@@ -29,12 +29,13 @@ from .exceptions import (BanError, PostNotFoundError, PublishError,
|
||||
|
||||
|
||||
class PostPublishService:
|
||||
def __init__(self, bot: Bot, db, settings: Dict[str, Any], s3_storage=None):
|
||||
def __init__(self, bot: Bot, db, settings: Dict[str, Any], s3_storage=None, scoring_manager=None):
|
||||
# bot может быть None - в этом случае используем бота из контекста сообщения
|
||||
self.bot = bot
|
||||
self.db = db
|
||||
self.settings = settings
|
||||
self.s3_storage = s3_storage
|
||||
self.scoring_manager = scoring_manager
|
||||
self.group_for_posts = settings['Telegram']['group_for_posts']
|
||||
self.main_public = settings['Telegram']['main_public']
|
||||
self.important_logs = settings['Telegram']['important_logs']
|
||||
@@ -393,6 +394,9 @@ class PostPublishService:
|
||||
"""Отклонение одиночного поста"""
|
||||
author_id = await self._get_author_id(call.message.message_id)
|
||||
|
||||
# Обучаем RAG на отклоненном посте перед удалением
|
||||
await self._train_on_declined(call.message.message_id)
|
||||
|
||||
updated_rows = await self.db.update_status_by_message_id(call.message.message_id, "declined")
|
||||
if updated_rows == 0:
|
||||
logger.error(f"Не удалось обновить статус поста message_id={call.message.message_id} на 'declined'")
|
||||
@@ -485,6 +489,9 @@ class PostPublishService:
|
||||
@track_errors("post_publish_service", "_delete_post_and_notify_author")
|
||||
async def _delete_post_and_notify_author(self, call: CallbackQuery, author_id: int) -> None:
|
||||
"""Удаление поста и уведомление автора"""
|
||||
# Получаем текст поста для обучения RAG перед удалением
|
||||
await self._train_on_published(call.message.message_id)
|
||||
|
||||
await self._get_bot(call.message).delete_message(chat_id=self.group_for_posts, message_id=call.message.message_id)
|
||||
|
||||
try:
|
||||
@@ -494,6 +501,32 @@ class PostPublishService:
|
||||
raise UserBlockedBotError("Пользователь заблокировал бота")
|
||||
raise
|
||||
|
||||
async def _train_on_published(self, message_id: int) -> None:
|
||||
"""Обучает RAG на опубликованном посте."""
|
||||
if not self.scoring_manager:
|
||||
return
|
||||
|
||||
try:
|
||||
text = await self.db.get_post_text_by_message_id(message_id)
|
||||
if text and text.strip() and text != "^":
|
||||
await self.scoring_manager.on_post_published(text)
|
||||
logger.debug(f"RAG обучен на опубликованном посте: {message_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка обучения RAG на опубликованном посте {message_id}: {e}")
|
||||
|
||||
async def _train_on_declined(self, message_id: int) -> None:
|
||||
"""Обучает RAG на отклоненном посте."""
|
||||
if not self.scoring_manager:
|
||||
return
|
||||
|
||||
try:
|
||||
text = await self.db.get_post_text_by_message_id(message_id)
|
||||
if text and text.strip() and text != "^":
|
||||
await self.scoring_manager.on_post_declined(text)
|
||||
logger.debug(f"RAG обучен на отклоненном посте: {message_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка обучения RAG на отклоненном посте {message_id}: {e}")
|
||||
|
||||
@track_time("_delete_media_group_and_notify_author", "post_publish_service")
|
||||
@track_errors("post_publish_service", "_delete_media_group_and_notify_author")
|
||||
@track_media_processing("media_group")
|
||||
|
||||
@@ -35,11 +35,11 @@ sleep = asyncio.sleep
|
||||
class PrivateHandlers:
|
||||
"""Main handler class for private messages"""
|
||||
|
||||
def __init__(self, db: AsyncBotDB, settings: BotSettings, s3_storage=None):
|
||||
def __init__(self, db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None):
|
||||
self.db = db
|
||||
self.settings = settings
|
||||
self.user_service = UserService(db, settings)
|
||||
self.post_service = PostService(db, settings, s3_storage)
|
||||
self.post_service = PostService(db, settings, s3_storage, scoring_manager)
|
||||
self.sticker_service = StickerService(settings)
|
||||
|
||||
self.router = Router()
|
||||
@@ -240,18 +240,24 @@ class PrivateHandlers:
|
||||
|
||||
|
||||
# Factory function to create handlers with dependencies
|
||||
def create_private_handlers(db: AsyncBotDB, settings: BotSettings, s3_storage=None) -> PrivateHandlers:
|
||||
def create_private_handlers(db: AsyncBotDB, settings: BotSettings, s3_storage=None, scoring_manager=None) -> PrivateHandlers:
|
||||
"""Create private handlers instance with dependencies"""
|
||||
return PrivateHandlers(db, settings, s3_storage)
|
||||
return PrivateHandlers(db, settings, s3_storage, scoring_manager)
|
||||
|
||||
|
||||
# Legacy router for backward compatibility
|
||||
private_router = Router()
|
||||
|
||||
# Флаг инициализации для защиты от повторного вызова
|
||||
_legacy_router_initialized = False
|
||||
|
||||
# Initialize with global dependencies (for backward compatibility)
|
||||
def init_legacy_router():
|
||||
"""Initialize legacy router with global dependencies"""
|
||||
global private_router
|
||||
global private_router, _legacy_router_initialized
|
||||
|
||||
if _legacy_router_initialized:
|
||||
return
|
||||
|
||||
from helper_bot.utils.base_dependency_factory import get_global_instance
|
||||
|
||||
@@ -269,11 +275,13 @@ def init_legacy_router():
|
||||
|
||||
db = bdf.get_db()
|
||||
s3_storage = bdf.get_s3_storage()
|
||||
handlers = create_private_handlers(db, settings, s3_storage)
|
||||
scoring_manager = bdf.get_scoring_manager()
|
||||
handlers = create_private_handlers(db, settings, s3_storage, scoring_manager)
|
||||
|
||||
# Instead of trying to copy handlers, we'll use the new router directly
|
||||
# This maintains backward compatibility while using the new architecture
|
||||
private_router = handlers.router
|
||||
_legacy_router_initialized = True
|
||||
|
||||
# Initialize legacy router
|
||||
init_legacy_router()
|
||||
|
||||
@@ -128,10 +128,11 @@ class UserService:
|
||||
class PostService:
|
||||
"""Service for post-related operations"""
|
||||
|
||||
def __init__(self, db: DatabaseProtocol, settings: BotSettings, s3_storage=None) -> None:
|
||||
def __init__(self, db: DatabaseProtocol, settings: BotSettings, s3_storage=None, scoring_manager=None) -> None:
|
||||
self.db = db
|
||||
self.settings = settings
|
||||
self.s3_storage = s3_storage
|
||||
self.scoring_manager = scoring_manager
|
||||
|
||||
async def _save_media_background(self, sent_message: types.Message, bot_db: Any, s3_storage) -> None:
|
||||
"""Сохраняет медиа в фоне, чтобы не блокировать ответ пользователю"""
|
||||
@@ -142,18 +143,65 @@ class PostService:
|
||||
except Exception as e:
|
||||
logger.error(f"_save_media_background: Ошибка при сохранении медиа для поста {sent_message.message_id}: {e}")
|
||||
|
||||
async def _get_scores(self, text: str) -> tuple:
|
||||
"""
|
||||
Получает скоры для текста поста.
|
||||
|
||||
Returns:
|
||||
Tuple (deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json)
|
||||
"""
|
||||
if not self.scoring_manager or not text or not text.strip():
|
||||
return None, None, None, None, None
|
||||
|
||||
try:
|
||||
scores = await self.scoring_manager.score_post(text)
|
||||
|
||||
# Формируем JSON для сохранения в БД
|
||||
import json
|
||||
ml_scores_json = json.dumps(scores.to_json_dict()) if scores.has_any_score() else None
|
||||
|
||||
# Получаем данные от RAG
|
||||
rag_confidence = scores.rag.confidence if scores.rag else None
|
||||
rag_score_pos_only = scores.rag.metadata.get("score_pos_only") if scores.rag else None
|
||||
|
||||
return scores.deepseek_score, scores.rag_score, rag_confidence, rag_score_pos_only, ml_scores_json
|
||||
except Exception as e:
|
||||
logger.error(f"PostService: Ошибка получения скоров: {e}")
|
||||
return None, None, None, None, None
|
||||
|
||||
async def _save_scores_background(self, message_id: int, ml_scores_json: str) -> None:
|
||||
"""Сохраняет скоры в БД в фоне."""
|
||||
if ml_scores_json:
|
||||
try:
|
||||
await self.db.update_ml_scores(message_id, ml_scores_json)
|
||||
except Exception as e:
|
||||
logger.error(f"PostService: Ошибка сохранения скоров для {message_id}: {e}")
|
||||
|
||||
@track_time("handle_text_post", "post_service")
|
||||
@track_errors("post_service", "handle_text_post")
|
||||
@db_query_time("handle_text_post", "posts", "insert")
|
||||
async def handle_text_post(self, message: types.Message, first_name: str) -> None:
|
||||
"""Handle text post submission"""
|
||||
post_text = get_text_message(message.text.lower(), first_name, message.from_user.username)
|
||||
raw_text = message.text or ""
|
||||
|
||||
# Получаем скоры для текста
|
||||
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_text)
|
||||
|
||||
# Формируем текст с учетом скоров
|
||||
post_text = get_text_message(
|
||||
message.text.lower(),
|
||||
first_name,
|
||||
message.from_user.username,
|
||||
deepseek_score=deepseek_score,
|
||||
rag_score=rag_score,
|
||||
rag_confidence=rag_confidence,
|
||||
rag_score_pos_only=rag_score_pos_only,
|
||||
)
|
||||
markup = get_reply_keyboard_for_post()
|
||||
|
||||
sent_message = await send_text_message(self.settings.group_for_posts, message, post_text, markup)
|
||||
|
||||
# Сохраняем сырой текст и определяем анонимность
|
||||
raw_text = message.text or ""
|
||||
# Определяем анонимность
|
||||
is_anonymous = determine_anonymity(raw_text)
|
||||
|
||||
post = TelegramPost(
|
||||
@@ -165,22 +213,38 @@ class PostService:
|
||||
)
|
||||
await self.db.add_post(post)
|
||||
|
||||
# Сохраняем скоры в фоне
|
||||
if ml_scores_json:
|
||||
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
|
||||
|
||||
@track_time("handle_photo_post", "post_service")
|
||||
@track_errors("post_service", "handle_photo_post")
|
||||
@db_query_time("handle_photo_post", "posts", "insert")
|
||||
async def handle_photo_post(self, message: types.Message, first_name: str) -> None:
|
||||
"""Handle photo post submission"""
|
||||
raw_caption = message.caption or ""
|
||||
|
||||
# Получаем скоры для текста
|
||||
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
|
||||
|
||||
post_caption = ""
|
||||
if message.caption:
|
||||
post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username)
|
||||
post_caption = get_text_message(
|
||||
message.caption.lower(),
|
||||
first_name,
|
||||
message.from_user.username,
|
||||
deepseek_score=deepseek_score,
|
||||
rag_score=rag_score,
|
||||
rag_confidence=rag_confidence,
|
||||
rag_score_pos_only=rag_score_pos_only,
|
||||
)
|
||||
|
||||
markup = get_reply_keyboard_for_post()
|
||||
sent_message = await send_photo_message(
|
||||
self.settings.group_for_posts, message, message.photo[-1].file_id, post_caption, markup
|
||||
)
|
||||
|
||||
# Сохраняем сырой caption и определяем анонимность
|
||||
raw_caption = message.caption or ""
|
||||
# Определяем анонимность
|
||||
is_anonymous = determine_anonymity(raw_caption)
|
||||
|
||||
post = TelegramPost(
|
||||
@@ -191,25 +255,40 @@ class PostService:
|
||||
is_anonymous=is_anonymous
|
||||
)
|
||||
await self.db.add_post(post)
|
||||
# Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю
|
||||
|
||||
# Сохраняем медиа и скоры в фоне
|
||||
asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage))
|
||||
if ml_scores_json:
|
||||
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
|
||||
|
||||
@track_time("handle_video_post", "post_service")
|
||||
@track_errors("post_service", "handle_video_post")
|
||||
@db_query_time("handle_video_post", "posts", "insert")
|
||||
async def handle_video_post(self, message: types.Message, first_name: str) -> None:
|
||||
"""Handle video post submission"""
|
||||
raw_caption = message.caption or ""
|
||||
|
||||
# Получаем скоры для текста
|
||||
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
|
||||
|
||||
post_caption = ""
|
||||
if message.caption:
|
||||
post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username)
|
||||
post_caption = get_text_message(
|
||||
message.caption.lower(),
|
||||
first_name,
|
||||
message.from_user.username,
|
||||
deepseek_score=deepseek_score,
|
||||
rag_score=rag_score,
|
||||
rag_confidence=rag_confidence,
|
||||
rag_score_pos_only=rag_score_pos_only,
|
||||
)
|
||||
|
||||
markup = get_reply_keyboard_for_post()
|
||||
sent_message = await send_video_message(
|
||||
self.settings.group_for_posts, message, message.video.file_id, post_caption, markup
|
||||
)
|
||||
|
||||
# Сохраняем сырой caption и определяем анонимность
|
||||
raw_caption = message.caption or ""
|
||||
# Определяем анонимность
|
||||
is_anonymous = determine_anonymity(raw_caption)
|
||||
|
||||
post = TelegramPost(
|
||||
@@ -220,8 +299,11 @@ class PostService:
|
||||
is_anonymous=is_anonymous
|
||||
)
|
||||
await self.db.add_post(post)
|
||||
# Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю
|
||||
|
||||
# Сохраняем медиа и скоры в фоне
|
||||
asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage))
|
||||
if ml_scores_json:
|
||||
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
|
||||
|
||||
@track_time("handle_video_note_post", "post_service")
|
||||
@track_errors("post_service", "handle_video_note_post")
|
||||
@@ -253,17 +335,29 @@ class PostService:
|
||||
@db_query_time("handle_audio_post", "posts", "insert")
|
||||
async def handle_audio_post(self, message: types.Message, first_name: str) -> None:
|
||||
"""Handle audio post submission"""
|
||||
raw_caption = message.caption or ""
|
||||
|
||||
# Получаем скоры для текста
|
||||
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
|
||||
|
||||
post_caption = ""
|
||||
if message.caption:
|
||||
post_caption = get_text_message(message.caption.lower(), first_name, message.from_user.username)
|
||||
post_caption = get_text_message(
|
||||
message.caption.lower(),
|
||||
first_name,
|
||||
message.from_user.username,
|
||||
deepseek_score=deepseek_score,
|
||||
rag_score=rag_score,
|
||||
rag_confidence=rag_confidence,
|
||||
rag_score_pos_only=rag_score_pos_only,
|
||||
)
|
||||
|
||||
markup = get_reply_keyboard_for_post()
|
||||
sent_message = await send_audio_message(
|
||||
self.settings.group_for_posts, message, message.audio.file_id, post_caption, markup
|
||||
)
|
||||
|
||||
# Сохраняем сырой caption и определяем анонимность
|
||||
raw_caption = message.caption or ""
|
||||
# Определяем анонимность
|
||||
is_anonymous = determine_anonymity(raw_caption)
|
||||
|
||||
post = TelegramPost(
|
||||
@@ -274,8 +368,11 @@ class PostService:
|
||||
is_anonymous=is_anonymous
|
||||
)
|
||||
await self.db.add_post(post)
|
||||
# Сохраняем медиа в фоне, чтобы не блокировать ответ пользователю
|
||||
|
||||
# Сохраняем медиа и скоры в фоне
|
||||
asyncio.create_task(self._save_media_background(sent_message, self.db, self.s3_storage))
|
||||
if ml_scores_json:
|
||||
asyncio.create_task(self._save_scores_background(sent_message.message_id, ml_scores_json))
|
||||
|
||||
@track_time("handle_voice_post", "post_service")
|
||||
@track_errors("post_service", "handle_voice_post")
|
||||
@@ -310,10 +407,23 @@ class PostService:
|
||||
"""Handle media group post submission"""
|
||||
post_caption = " "
|
||||
raw_caption = ""
|
||||
ml_scores_json = None
|
||||
|
||||
if album and album[0].caption:
|
||||
raw_caption = album[0].caption or ""
|
||||
post_caption = get_text_message(album[0].caption.lower(), first_name, message.from_user.username)
|
||||
|
||||
# Получаем скоры для текста
|
||||
deepseek_score, rag_score, rag_confidence, rag_score_pos_only, ml_scores_json = await self._get_scores(raw_caption)
|
||||
|
||||
post_caption = get_text_message(
|
||||
album[0].caption.lower(),
|
||||
first_name,
|
||||
message.from_user.username,
|
||||
deepseek_score=deepseek_score,
|
||||
rag_score=rag_score,
|
||||
rag_confidence=rag_confidence,
|
||||
rag_score_pos_only=rag_score_pos_only,
|
||||
)
|
||||
|
||||
is_anonymous = determine_anonymity(raw_caption)
|
||||
media_group = await prepare_media_group_from_middlewares(album, post_caption)
|
||||
@@ -333,6 +443,10 @@ class PostService:
|
||||
)
|
||||
await self.db.add_post(main_post)
|
||||
|
||||
# Сохраняем скоры в фоне
|
||||
if ml_scores_json:
|
||||
asyncio.create_task(self._save_scores_background(main_post_id, ml_scores_json))
|
||||
|
||||
for msg_id in media_group_message_ids:
|
||||
await self.db.add_message_link(main_post_id, msg_id)
|
||||
|
||||
|
||||
@@ -47,6 +47,9 @@ def get_reply_keyboard_admin():
|
||||
)
|
||||
builder.row(
|
||||
types.KeyboardButton(text="Разбан (список)"),
|
||||
types.KeyboardButton(text="📊 ML Статистика")
|
||||
)
|
||||
builder.row(
|
||||
types.KeyboardButton(text="Вернуться в бота")
|
||||
)
|
||||
markup = builder.as_markup(resize_keyboard=True, one_time_keyboard=True)
|
||||
|
||||
@@ -78,6 +78,22 @@ async def start_bot(bdf):
|
||||
|
||||
await bot.delete_webhook(drop_pending_updates=True)
|
||||
|
||||
# Загружаем примеры для RAG из базы данных
|
||||
scoring_manager = bdf.get_scoring_manager()
|
||||
if scoring_manager and scoring_manager.rag_service and scoring_manager.rag_service.is_enabled:
|
||||
try:
|
||||
db = bdf.get_db()
|
||||
positive_texts = await db.get_approved_posts_texts(limit=5000)
|
||||
negative_texts = await db.get_declined_posts_texts(limit=5000)
|
||||
|
||||
if positive_texts or negative_texts:
|
||||
await scoring_manager.load_examples_from_db(positive_texts, negative_texts)
|
||||
logging.info(f"RAG: Загружено {len(positive_texts)} положительных и {len(negative_texts)} отрицательных примеров")
|
||||
else:
|
||||
logging.warning("RAG: Нет примеров в базе данных для загрузки")
|
||||
except Exception as e:
|
||||
logging.error(f"Ошибка загрузки примеров для RAG: {e}")
|
||||
|
||||
# Запускаем HTTP сервер для метрик параллельно с ботом
|
||||
metrics_host = bdf.settings.get('Metrics', {}).get('host', '0.0.0.0')
|
||||
metrics_port = bdf.settings.get('Metrics', {}).get('port', 8080)
|
||||
|
||||
5
helper_bot/services/__init__.py
Normal file
5
helper_bot/services/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Сервисы приложения.
|
||||
|
||||
Содержит бизнес-логику, не связанную напрямую с handlers.
|
||||
"""
|
||||
42
helper_bot/services/scoring/__init__.py
Normal file
42
helper_bot/services/scoring/__init__.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Сервисы для ML-скоринга постов.
|
||||
|
||||
Включает:
|
||||
- RAGService - локальное векторное сравнение с ruBERT
|
||||
- DeepSeekService - интеграция с DeepSeek API
|
||||
- ScoringManager - объединение всех сервисов скоринга
|
||||
- VectorStore - in-memory хранилище векторов
|
||||
"""
|
||||
|
||||
from .base import ScoringResult, ScoringServiceProtocol, CombinedScore
|
||||
from .exceptions import (
|
||||
ScoringError,
|
||||
ModelNotLoadedError,
|
||||
VectorStoreError,
|
||||
DeepSeekAPIError,
|
||||
InsufficientExamplesError,
|
||||
TextTooShortError,
|
||||
)
|
||||
from .vector_store import VectorStore
|
||||
from .rag_service import RAGService
|
||||
from .deepseek_service import DeepSeekService
|
||||
from .scoring_manager import ScoringManager
|
||||
|
||||
__all__ = [
|
||||
# Базовые классы
|
||||
"ScoringResult",
|
||||
"ScoringServiceProtocol",
|
||||
"CombinedScore",
|
||||
# Исключения
|
||||
"ScoringError",
|
||||
"ModelNotLoadedError",
|
||||
"VectorStoreError",
|
||||
"DeepSeekAPIError",
|
||||
"InsufficientExamplesError",
|
||||
"TextTooShortError",
|
||||
# Сервисы
|
||||
"VectorStore",
|
||||
"RAGService",
|
||||
"DeepSeekService",
|
||||
"ScoringManager",
|
||||
]
|
||||
155
helper_bot/services/scoring/base.py
Normal file
155
helper_bot/services/scoring/base.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
Базовые классы и протоколы для сервисов скоринга.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Protocol, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoringResult:
|
||||
"""
|
||||
Результат оценки поста от одного сервиса.
|
||||
|
||||
Attributes:
|
||||
score: Оценка от 0.0 до 1.0 (вероятность публикации)
|
||||
source: Источник оценки ("deepseek", "rag", etc.)
|
||||
model: Название используемой модели
|
||||
confidence: Уверенность в оценке (опционально)
|
||||
timestamp: Время получения оценки
|
||||
metadata: Дополнительные данные
|
||||
"""
|
||||
score: float
|
||||
source: str
|
||||
model: str
|
||||
confidence: Optional[float] = None
|
||||
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Валидация score в диапазоне [0.0, 1.0]."""
|
||||
if not 0.0 <= self.score <= 1.0:
|
||||
raise ValueError(f"Score должен быть в диапазоне [0.0, 1.0], получено: {self.score}")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Преобразует результат в словарь для сохранения в JSON."""
|
||||
result = {
|
||||
"score": round(self.score, 4),
|
||||
"model": self.model,
|
||||
"ts": self.timestamp,
|
||||
}
|
||||
if self.confidence is not None:
|
||||
result["confidence"] = round(self.confidence, 4)
|
||||
if self.metadata:
|
||||
result["metadata"] = self.metadata
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, source: str, data: Dict[str, Any]) -> "ScoringResult":
|
||||
"""Создает ScoringResult из словаря."""
|
||||
return cls(
|
||||
score=data["score"],
|
||||
source=source,
|
||||
model=data.get("model", "unknown"),
|
||||
confidence=data.get("confidence"),
|
||||
timestamp=data.get("ts", int(datetime.now().timestamp())),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CombinedScore:
|
||||
"""
|
||||
Объединенный результат от всех сервисов скоринга.
|
||||
|
||||
Attributes:
|
||||
deepseek: Результат от DeepSeek API (None если отключен/ошибка)
|
||||
rag: Результат от RAG сервиса (None если отключен/ошибка)
|
||||
errors: Словарь с ошибками по источникам
|
||||
"""
|
||||
deepseek: Optional[ScoringResult] = None
|
||||
rag: Optional[ScoringResult] = None
|
||||
errors: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def deepseek_score(self) -> Optional[float]:
|
||||
"""Возвращает только числовой скор от DeepSeek."""
|
||||
return self.deepseek.score if self.deepseek else None
|
||||
|
||||
@property
|
||||
def rag_score(self) -> Optional[float]:
|
||||
"""Возвращает только числовой скор от RAG."""
|
||||
return self.rag.score if self.rag else None
|
||||
|
||||
def to_json_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Преобразует в словарь для сохранения в ml_scores колонку.
|
||||
|
||||
Формат:
|
||||
{
|
||||
"deepseek": {"score": 0.75, "model": "...", "ts": ...},
|
||||
"rag": {"score": 0.90, "model": "...", "ts": ...}
|
||||
}
|
||||
"""
|
||||
result = {}
|
||||
if self.deepseek:
|
||||
result["deepseek"] = self.deepseek.to_dict()
|
||||
if self.rag:
|
||||
result["rag"] = self.rag.to_dict()
|
||||
return result
|
||||
|
||||
def has_any_score(self) -> bool:
|
||||
"""Проверяет, есть ли хотя бы один успешный скор."""
|
||||
return self.deepseek is not None or self.rag is not None
|
||||
|
||||
|
||||
class ScoringServiceProtocol(Protocol):
|
||||
"""
|
||||
Протокол для сервисов скоринга.
|
||||
|
||||
Любой сервис скоринга должен реализовывать эти методы.
|
||||
"""
|
||||
|
||||
@property
|
||||
def source_name(self) -> str:
|
||||
"""Возвращает имя источника ("deepseek", "rag", etc.)."""
|
||||
...
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Проверяет, включен ли сервис."""
|
||||
...
|
||||
|
||||
async def calculate_score(self, text: str) -> ScoringResult:
|
||||
"""
|
||||
Рассчитывает скор для текста поста.
|
||||
|
||||
Args:
|
||||
text: Текст поста для оценки
|
||||
|
||||
Returns:
|
||||
ScoringResult с оценкой
|
||||
|
||||
Raises:
|
||||
ScoringError: При ошибке расчета
|
||||
"""
|
||||
...
|
||||
|
||||
async def add_positive_example(self, text: str) -> None:
|
||||
"""
|
||||
Добавляет текст как положительный пример (опубликованный пост).
|
||||
|
||||
Args:
|
||||
text: Текст опубликованного поста
|
||||
"""
|
||||
...
|
||||
|
||||
async def add_negative_example(self, text: str) -> None:
|
||||
"""
|
||||
Добавляет текст как отрицательный пример (отклоненный пост).
|
||||
|
||||
Args:
|
||||
text: Текст отклоненного поста
|
||||
"""
|
||||
...
|
||||
358
helper_bot/services/scoring/deepseek_service.py
Normal file
358
helper_bot/services/scoring/deepseek_service.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
DeepSeek API сервис для скоринга постов.
|
||||
|
||||
Использует DeepSeek API для семантической оценки релевантности поста.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Optional, List
|
||||
|
||||
import httpx
|
||||
|
||||
from logs.custom_logger import logger
|
||||
from helper_bot.utils.metrics import track_time, track_errors
|
||||
|
||||
from .base import ScoringResult
|
||||
from .exceptions import DeepSeekAPIError, ScoringError, TextTooShortError
|
||||
|
||||
|
||||
class DeepSeekService:
|
||||
"""
|
||||
Сервис для оценки постов через DeepSeek API.
|
||||
|
||||
Отправляет текст поста в DeepSeek с промптом для оценки
|
||||
и получает числовой скор релевантности.
|
||||
|
||||
Attributes:
|
||||
api_key: API ключ DeepSeek
|
||||
api_url: URL API эндпоинта
|
||||
model: Название модели
|
||||
timeout: Таймаут запроса в секундах
|
||||
"""
|
||||
|
||||
# Промпт для оценки поста
|
||||
SCORING_PROMPT = """Роль: Ты — строгий и внимательный модератор сообщества в социальной сети, ориентированного на знакомства между людьми. Твоя задача — оценить, можно ли опубликовать пост, основываясь на четких правилах.
|
||||
|
||||
Контекст группы: Это группа для поиска и знакомства с людьми. Пользователи могут искать кого угодно: случайно увиденных на улице, в транспорте, в кафе, старых знакомых, новых друзей или пару. Это главная и единственная цель группы.
|
||||
|
||||
---
|
||||
|
||||
ПРАВИЛА ЗАПРЕТА (пост НЕ ДОЛЖЕН быть опубликован, если содержит это):
|
||||
|
||||
1. Запрещенные законом тематики: Любые призывы, обсуждение или поиск чего-либо незаконного (наркотики, оружие, мошенничество, насилие и т.д.).
|
||||
2. Поиск и утеря животных, найденные предметы: Запрещены посты про потерявшихся/найденных кошек, собак, хомяков, а также про потерянные/найденные телефоны, ключи, сумки и т.п.
|
||||
3. Конкуренция (Дайвинчик): Любое упоминание группы/проекта/чата "Дайвинчик" или любых других групп-конкурентов. Запрещены призывы переходить в другие сообщества.
|
||||
4. Сбор больших компаний и групп: Запрещены посты с целью собрать большую тусовку, компанию, группу для похода, вечеринки, игры и т.д. (например, "собираем команду для футбола", "кто хочет на квартиру?").
|
||||
5. Организация чатов и других сообществ: Запрещено создание или реклама сторонних чатов, каналов, групп в телеграме, дискорде и т.п.
|
||||
|
||||
---
|
||||
|
||||
ПРАВИЛА РАЗРЕШЕНИЯ (пост МОЖЕТ быть опубликован, если):
|
||||
|
||||
· Цель — найти конкретного человека или познакомиться с кем-то новым.
|
||||
· Формат: Описание человека, обстоятельств встречи, примет, места и времени. Или прямой призыв к знакомству.
|
||||
· Примеры ДОПУСТИМЫХ постов (ориентируйся на них):
|
||||
· "мальчики нефоры/патлатые, гоу знакомиться😻 анон"
|
||||
· "ищу девочку, ехала на 21 автобусе примерно в 15:20. села на детской поликлинике и вышла в заречье вся в черной одежде и с черным баулом"
|
||||
· "ищу мальчика ехали на 35 автобусе часов в 7 вечера я была с девочками,у нас с тобой еще куртки одинаковые ,я рядом с тобой сидела,напиши в комментарии если у тебя нету девочки. анон админу любви."
|
||||
|
||||
---
|
||||
|
||||
ИНСТРУКЦИЯ ПО ОЦЕНКЕ:
|
||||
|
||||
Проанализируй полученный пост и присвой ему итоговый Вес (Score) от 0.0 до 1.0, где:
|
||||
|
||||
· 1.0 — Пост полностью соответствует правилам. Цель — найти/познакомиться с человеком. Ничего из списка запретов не нарушено. Можно публиковать.
|
||||
· 0.0 — Пост категорически нарушает правила. Содержит явные признаки одного или нескольких пунктов из списка запрета. Публиковать НЕЛЬЗЯ.
|
||||
· 0.2 - 0.8 — Пост находится в "серой зоне". Присваивай промежуточный вес, оценивая степень риска и соответствия цели группы.
|
||||
· Ближе к 0.2: Сильно сомнительный пост, есть явные признаки запрещенной темы (например, упоминание "собраться компанией", косвенная реклама другого места).
|
||||
· 0.5: Нейтральный или неочевидный пост. Нужно проверить, нет ли скрытого смысла, нарушающего правила.
|
||||
· Ближе к 0.8: В целом допустимый пост, но с небольшими странностями или двусмысленностями, не нарушающими правила напрямую.
|
||||
---
|
||||
{text}
|
||||
---
|
||||
|
||||
Ответь ТОЛЬКО числом от 0.0 до 1.0, без дополнительных объяснений.
|
||||
Пример ответа: 0.75"""
|
||||
|
||||
DEFAULT_API_URL = "https://api.deepseek.com/v1/chat/completions"
|
||||
DEFAULT_MODEL = "deepseek-chat"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
timeout: int = 30,
|
||||
enabled: bool = True,
|
||||
min_text_length: int = 3,
|
||||
max_retries: int = 3,
|
||||
):
|
||||
"""
|
||||
Инициализация DeepSeek сервиса.
|
||||
|
||||
Args:
|
||||
api_key: API ключ DeepSeek
|
||||
api_url: URL API эндпоинта
|
||||
model: Название модели
|
||||
timeout: Таймаут запроса в секундах
|
||||
enabled: Включен ли сервис
|
||||
min_text_length: Минимальная длина текста для обработки
|
||||
max_retries: Максимальное количество повторных попыток
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url or self.DEFAULT_API_URL
|
||||
self.model = model or self.DEFAULT_MODEL
|
||||
self.timeout = timeout
|
||||
self._enabled = enabled and bool(api_key)
|
||||
self.min_text_length = min_text_length
|
||||
self.max_retries = max_retries
|
||||
|
||||
# HTTP клиент (создается лениво)
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
if not api_key and enabled:
|
||||
logger.warning("DeepSeekService: API ключ не указан, сервис отключен")
|
||||
self._enabled = False
|
||||
|
||||
logger.info(
|
||||
f"DeepSeekService инициализирован "
|
||||
f"(model={self.model}, enabled={self._enabled})"
|
||||
)
|
||||
|
||||
@property
|
||||
def source_name(self) -> str:
|
||||
"""Имя источника для результатов."""
|
||||
return "deepseek"
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Проверяет, включен ли сервис."""
|
||||
return self._enabled
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Получает или создает HTTP клиент."""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(self.timeout),
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Закрывает HTTP клиент."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""Очищает текст от лишних символов."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# Удаляем лишние пробелы и переносы строк
|
||||
clean = " ".join(text.split())
|
||||
|
||||
# Удаляем служебные символы
|
||||
if clean == "^":
|
||||
return ""
|
||||
|
||||
return clean.strip()
|
||||
|
||||
def _parse_score_response(self, response_text: str) -> float:
|
||||
"""
|
||||
Парсит ответ от DeepSeek и извлекает скор.
|
||||
|
||||
Args:
|
||||
response_text: Текст ответа от API
|
||||
|
||||
Returns:
|
||||
Числовой скор от 0.0 до 1.0
|
||||
|
||||
Raises:
|
||||
DeepSeekAPIError: Если не удалось распарсить ответ
|
||||
"""
|
||||
try:
|
||||
# Пытаемся найти число в ответе
|
||||
text = response_text.strip()
|
||||
|
||||
# Убираем возможные обрамления
|
||||
text = text.strip('"\'`')
|
||||
|
||||
# Пробуем распарсить как число
|
||||
score = float(text)
|
||||
|
||||
# Ограничиваем диапазон
|
||||
score = max(0.0, min(1.0, score))
|
||||
|
||||
return score
|
||||
|
||||
except ValueError:
|
||||
# Пробуем найти число в тексте
|
||||
import re
|
||||
matches = re.findall(r'0\.\d+|1\.0|0|1', text)
|
||||
if matches:
|
||||
score = float(matches[0])
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
logger.error(f"DeepSeekService: Не удалось распарсить ответ: {response_text}")
|
||||
raise DeepSeekAPIError(f"Не удалось распарсить скор из ответа: {response_text}")
|
||||
|
||||
@track_time("calculate_score", "deepseek_service")
|
||||
@track_errors("deepseek_service", "calculate_score")
|
||||
async def calculate_score(self, text: str) -> ScoringResult:
|
||||
"""
|
||||
Рассчитывает скор для текста поста через DeepSeek API.
|
||||
|
||||
Args:
|
||||
text: Текст поста для оценки
|
||||
|
||||
Returns:
|
||||
ScoringResult с оценкой
|
||||
|
||||
Raises:
|
||||
ScoringError: При ошибке расчета
|
||||
"""
|
||||
if not self._enabled:
|
||||
raise ScoringError("DeepSeek сервис отключен")
|
||||
|
||||
# Очищаем текст
|
||||
clean_text = self._clean_text(text)
|
||||
|
||||
if len(clean_text) < self.min_text_length:
|
||||
raise TextTooShortError(
|
||||
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
|
||||
)
|
||||
|
||||
# Формируем промпт
|
||||
prompt = self.SCORING_PROMPT.format(text=clean_text)
|
||||
|
||||
# Выполняем запрос с повторными попытками
|
||||
last_error = None
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
score = await self._make_api_request(prompt)
|
||||
|
||||
return ScoringResult(
|
||||
score=score,
|
||||
source=self.source_name,
|
||||
model=self.model,
|
||||
metadata={
|
||||
"text_length": len(clean_text),
|
||||
"attempt": attempt + 1,
|
||||
},
|
||||
)
|
||||
|
||||
except DeepSeekAPIError as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"DeepSeekService: Попытка {attempt + 1}/{self.max_retries} "
|
||||
f"не удалась: {e}"
|
||||
)
|
||||
if attempt < self.max_retries - 1:
|
||||
# Экспоненциальная задержка
|
||||
await asyncio.sleep(2 ** attempt)
|
||||
|
||||
raise ScoringError(f"Все попытки запроса к DeepSeek API не удались: {last_error}")
|
||||
|
||||
async def _make_api_request(self, prompt: str) -> float:
|
||||
"""
|
||||
Выполняет запрос к DeepSeek API.
|
||||
|
||||
Args:
|
||||
prompt: Промпт для отправки
|
||||
|
||||
Returns:
|
||||
Числовой скор от 0.0 до 1.0
|
||||
|
||||
Raises:
|
||||
DeepSeekAPIError: При ошибке API
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
"temperature": 0.1, # Низкая температура для детерминированности
|
||||
"max_tokens": 10, # Ожидаем только число
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post(self.api_url, json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Извлекаем ответ
|
||||
if "choices" not in data or not data["choices"]:
|
||||
raise DeepSeekAPIError("Пустой ответ от API")
|
||||
|
||||
response_text = data["choices"][0]["message"]["content"]
|
||||
|
||||
# Парсим скор
|
||||
score = self._parse_score_response(response_text)
|
||||
|
||||
logger.debug(f"DeepSeekService: Получен скор {score} для текста")
|
||||
return score
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_msg = f"HTTP ошибка {e.response.status_code}"
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
if "error" in error_data:
|
||||
error_msg = error_data["error"].get("message", error_msg)
|
||||
except Exception:
|
||||
pass
|
||||
raise DeepSeekAPIError(error_msg)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
raise DeepSeekAPIError(f"Таймаут запроса ({self.timeout}s)")
|
||||
|
||||
except Exception as e:
|
||||
raise DeepSeekAPIError(f"Ошибка запроса: {e}")
|
||||
|
||||
async def add_positive_example(self, text: str) -> None:
|
||||
"""
|
||||
Добавляет текст как положительный пример.
|
||||
|
||||
Для DeepSeek не требуется хранить примеры - оценка выполняется
|
||||
на основе промпта. Метод существует для совместимости с протоколом.
|
||||
|
||||
Args:
|
||||
text: Текст опубликованного поста
|
||||
"""
|
||||
# DeepSeek не использует примеры для обучения
|
||||
# Промпт уже содержит критерии оценки
|
||||
pass
|
||||
|
||||
async def add_negative_example(self, text: str) -> None:
|
||||
"""
|
||||
Добавляет текст как отрицательный пример.
|
||||
|
||||
Для DeepSeek не требуется хранить примеры - оценка выполняется
|
||||
на основе промпта. Метод существует для совместимости с протоколом.
|
||||
|
||||
Args:
|
||||
text: Текст отклоненного поста
|
||||
"""
|
||||
# DeepSeek не использует примеры для обучения
|
||||
pass
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Возвращает статистику сервиса."""
|
||||
return {
|
||||
"enabled": self._enabled,
|
||||
"model": self.model,
|
||||
"api_url": self.api_url,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
33
helper_bot/services/scoring/exceptions.py
Normal file
33
helper_bot/services/scoring/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Исключения для сервисов скоринга.
|
||||
"""
|
||||
|
||||
|
||||
class ScoringError(Exception):
|
||||
"""Базовое исключение для ошибок скоринга."""
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotLoadedError(ScoringError):
|
||||
"""Модель не загружена или недоступна."""
|
||||
pass
|
||||
|
||||
|
||||
class VectorStoreError(ScoringError):
|
||||
"""Ошибка при работе с хранилищем векторов."""
|
||||
pass
|
||||
|
||||
|
||||
class DeepSeekAPIError(ScoringError):
|
||||
"""Ошибка при обращении к DeepSeek API."""
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientExamplesError(ScoringError):
|
||||
"""Недостаточно примеров для расчета скора."""
|
||||
pass
|
||||
|
||||
|
||||
class TextTooShortError(ScoringError):
|
||||
"""Текст слишком короткий для векторизации."""
|
||||
pass
|
||||
507
helper_bot/services/scoring/rag_service.py
Normal file
507
helper_bot/services/scoring/rag_service.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""
|
||||
RAG сервис для скоринга постов с использованием ruBERT.
|
||||
|
||||
Использует модель DeepPavlov/rubert-base-cased для создания эмбеддингов
|
||||
и сравнивает их с эталонными примерами через VectorStore.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from logs.custom_logger import logger
|
||||
from helper_bot.utils.metrics import track_time, track_errors
|
||||
|
||||
from .base import ScoringResult
|
||||
from .vector_store import VectorStore
|
||||
from .exceptions import (
|
||||
ModelNotLoadedError,
|
||||
ScoringError,
|
||||
InsufficientExamplesError,
|
||||
TextTooShortError,
|
||||
)
|
||||
|
||||
|
||||
class RAGService:
|
||||
"""
|
||||
RAG сервис для оценки постов на основе векторного сходства.
|
||||
|
||||
Использует ruBERT для создания эмбеддингов текста и сравнивает
|
||||
их с эталонными примерами (опубликованные vs отклоненные посты).
|
||||
|
||||
Attributes:
|
||||
model_name: Название модели HuggingFace
|
||||
vector_store: Хранилище векторов
|
||||
min_text_length: Минимальная длина текста для обработки
|
||||
"""
|
||||
|
||||
# Название модели по умолчанию
|
||||
DEFAULT_MODEL = "DeepPavlov/rubert-base-cased"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
vector_store: Optional[VectorStore] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
enabled: bool = True,
|
||||
min_text_length: int = 3,
|
||||
):
|
||||
"""
|
||||
Инициализация RAG сервиса.
|
||||
|
||||
Args:
|
||||
model_name: Название модели HuggingFace (по умолчанию ruBERT)
|
||||
vector_store: Хранилище векторов (создается автоматически если не передано)
|
||||
cache_dir: Директория для кеширования модели
|
||||
enabled: Включен ли сервис
|
||||
min_text_length: Минимальная длина текста для обработки
|
||||
"""
|
||||
self.model_name = model_name or self.DEFAULT_MODEL
|
||||
self.cache_dir = cache_dir
|
||||
self._enabled = enabled
|
||||
self.min_text_length = min_text_length
|
||||
|
||||
# Модель и токенизатор загружаются лениво
|
||||
self._model = None
|
||||
self._tokenizer = None
|
||||
self._model_loaded = False
|
||||
|
||||
# Хранилище векторов
|
||||
self.vector_store = vector_store or VectorStore()
|
||||
|
||||
logger.info(f"RAGService инициализирован (model={self.model_name}, enabled={enabled})")
|
||||
|
||||
@property
|
||||
def source_name(self) -> str:
|
||||
"""Имя источника для результатов."""
|
||||
return "rag"
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""Проверяет, включен ли сервис."""
|
||||
return self._enabled
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
"""Проверяет, загружена ли модель."""
|
||||
return self._model_loaded
|
||||
|
||||
async def load_model(self) -> None:
|
||||
"""
|
||||
Загружает модель и токенизатор.
|
||||
|
||||
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
|
||||
"""
|
||||
if self._model_loaded:
|
||||
return
|
||||
|
||||
if not self._enabled:
|
||||
logger.warning("RAGService: Сервис отключен, модель не загружается")
|
||||
return
|
||||
|
||||
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
|
||||
|
||||
try:
|
||||
# Загрузка в отдельном потоке
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._load_model_sync)
|
||||
|
||||
self._model_loaded = True
|
||||
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
|
||||
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
|
||||
|
||||
def _load_model_sync(self) -> None:
|
||||
"""Синхронная загрузка модели (вызывается в executor)."""
|
||||
logger.info("RAGService: Начало _load_model_sync, импорт transformers...")
|
||||
from transformers import AutoTokenizer, AutoModel
|
||||
import torch
|
||||
|
||||
# Определяем устройство
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"RAGService: Устройство определено: {self._device}")
|
||||
|
||||
# Загружаем токенизатор
|
||||
logger.info(f"RAGService: Загрузка токенизатора из {self.model_name}...")
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
logger.info("RAGService: Токенизатор загружен")
|
||||
|
||||
# Загружаем модель
|
||||
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
|
||||
self._model = AutoModel.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
logger.info("RAGService: Модель загружена, перенос на устройство...")
|
||||
self._model.to(self._device)
|
||||
self._model.eval() # Режим инференса
|
||||
|
||||
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
|
||||
|
||||
def _get_embedding_sync(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Получает эмбеддинг текста (синхронно).
|
||||
|
||||
Использует [CLS] токен как представление всего текста.
|
||||
|
||||
Args:
|
||||
text: Текст для векторизации
|
||||
|
||||
Returns:
|
||||
Numpy массив с эмбеддингом (768 измерений для ruBERT)
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Токенизация с ограничением длины
|
||||
inputs = self._tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинг
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# Используем [CLS] токен (первый токен)
|
||||
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
return embedding.flatten()
|
||||
|
||||
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
|
||||
"""
|
||||
Получает эмбеддинги для батча текстов (синхронно).
|
||||
|
||||
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для векторизации
|
||||
batch_size: Размер батча (по умолчанию 16)
|
||||
|
||||
Returns:
|
||||
Список numpy массивов с эмбеддингами
|
||||
"""
|
||||
import torch
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Токенизация батча
|
||||
inputs = self._tokenizer(
|
||||
batch_texts,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинги
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# [CLS] токен для каждого текста в батче
|
||||
batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
# Разбиваем на отдельные эмбеддинги
|
||||
for j in range(len(batch_texts)):
|
||||
all_embeddings.append(batch_embeddings[j])
|
||||
|
||||
if i > 0 and i % (batch_size * 10) == 0:
|
||||
logger.info(f"RAGService: Обработано {i}/{len(texts)} текстов")
|
||||
|
||||
return all_embeddings
|
||||
|
||||
async def get_embeddings_batch(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
|
||||
"""
|
||||
Получает эмбеддинги для батча текстов (асинхронно).
|
||||
|
||||
Args:
|
||||
texts: Список текстов для векторизации
|
||||
batch_size: Размер батча
|
||||
|
||||
Returns:
|
||||
Список numpy массивов с эмбеддингами
|
||||
"""
|
||||
if not self._model_loaded:
|
||||
await self.load_model()
|
||||
|
||||
if not self._model_loaded:
|
||||
raise ModelNotLoadedError("Модель не загружена")
|
||||
|
||||
# Очищаем тексты
|
||||
clean_texts = [self._clean_text(text) for text in texts]
|
||||
|
||||
# Выполняем батч-обработку в thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
embeddings = await loop.run_in_executor(
|
||||
None,
|
||||
self._get_embeddings_batch_sync,
|
||||
clean_texts,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def get_embedding(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Получает эмбеддинг текста (асинхронно).
|
||||
|
||||
Args:
|
||||
text: Текст для векторизации
|
||||
|
||||
Returns:
|
||||
Numpy массив с эмбеддингом
|
||||
|
||||
Raises:
|
||||
ModelNotLoadedError: Если модель не загружена
|
||||
TextTooShortError: Если текст слишком короткий
|
||||
"""
|
||||
if not self._model_loaded:
|
||||
await self.load_model()
|
||||
|
||||
if not self._model_loaded:
|
||||
raise ModelNotLoadedError("Модель не загружена")
|
||||
|
||||
# Очищаем текст
|
||||
clean_text = self._clean_text(text)
|
||||
|
||||
if len(clean_text) < self.min_text_length:
|
||||
raise TextTooShortError(
|
||||
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
|
||||
)
|
||||
|
||||
# Выполняем в отдельном потоке
|
||||
loop = asyncio.get_event_loop()
|
||||
embedding = await loop.run_in_executor(
|
||||
None,
|
||||
self._get_embedding_sync,
|
||||
clean_text
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""Очищает текст от лишних символов."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# Удаляем лишние пробелы и переносы строк
|
||||
clean = " ".join(text.split())
|
||||
|
||||
# Удаляем служебные символы (например "^" для helper сообщений)
|
||||
if clean == "^":
|
||||
return ""
|
||||
|
||||
return clean.strip()
|
||||
|
||||
@track_time("calculate_score", "rag_service")
|
||||
@track_errors("rag_service", "calculate_score")
|
||||
async def calculate_score(self, text: str) -> ScoringResult:
|
||||
"""
|
||||
Рассчитывает скор для текста поста.
|
||||
|
||||
Args:
|
||||
text: Текст поста для оценки
|
||||
|
||||
Returns:
|
||||
ScoringResult с оценкой
|
||||
|
||||
Raises:
|
||||
ScoringError: При ошибке расчета
|
||||
"""
|
||||
if not self._enabled:
|
||||
raise ScoringError("RAG сервис отключен")
|
||||
|
||||
try:
|
||||
# Получаем эмбеддинг текста
|
||||
embedding = await self.get_embedding(text)
|
||||
|
||||
# Логируем первые элементы вектора для отладки
|
||||
logger.info(
|
||||
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
|
||||
f"text_preview='{text[:30]}'"
|
||||
)
|
||||
|
||||
# Рассчитываем скор через VectorStore
|
||||
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
|
||||
|
||||
return ScoringResult(
|
||||
score=score,
|
||||
source=self.source_name,
|
||||
model=self.model_name,
|
||||
confidence=confidence,
|
||||
metadata={
|
||||
"positive_examples": self.vector_store.positive_count,
|
||||
"negative_examples": self.vector_store.negative_count,
|
||||
"score_pos_only": score_pos_only, # Для сравнения
|
||||
},
|
||||
)
|
||||
|
||||
except InsufficientExamplesError:
|
||||
# Не достаточно примеров - возвращаем нейтральный скор
|
||||
logger.warning("RAGService: Недостаточно примеров для расчета скора")
|
||||
raise
|
||||
|
||||
except TextTooShortError:
|
||||
logger.warning(f"RAGService: Текст слишком короткий для оценки")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка расчета скора: {e}")
|
||||
raise ScoringError(f"Ошибка расчета скора: {e}")
|
||||
|
||||
@track_time("add_positive_example", "rag_service")
|
||||
async def add_positive_example(self, text: str) -> None:
|
||||
"""
|
||||
Добавляет текст как положительный пример (опубликованный пост).
|
||||
|
||||
Args:
|
||||
text: Текст опубликованного поста
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) < self.min_text_length:
|
||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||
return
|
||||
|
||||
# Получаем эмбеддинг
|
||||
embedding = await self.get_embedding(clean_text)
|
||||
|
||||
# Вычисляем хеш для дедупликации
|
||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||
|
||||
# Добавляем в хранилище
|
||||
added = self.vector_store.add_positive(embedding, text_hash)
|
||||
|
||||
if added:
|
||||
logger.info(f"RAGService: Добавлен положительный пример")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
|
||||
|
||||
@track_time("add_negative_example", "rag_service")
|
||||
async def add_negative_example(self, text: str) -> None:
|
||||
"""
|
||||
Добавляет текст как отрицательный пример (отклоненный пост).
|
||||
|
||||
Args:
|
||||
text: Текст отклоненного поста
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
try:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) < self.min_text_length:
|
||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||
return
|
||||
|
||||
# Получаем эмбеддинг
|
||||
embedding = await self.get_embedding(clean_text)
|
||||
|
||||
# Вычисляем хеш для дедупликации
|
||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||
|
||||
# Добавляем в хранилище
|
||||
added = self.vector_store.add_negative(embedding, text_hash)
|
||||
|
||||
if added:
|
||||
logger.info(f"RAGService: Добавлен отрицательный пример")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
|
||||
|
||||
async def load_examples_from_db(
|
||||
self,
|
||||
positive_texts: list[str],
|
||||
negative_texts: list[str],
|
||||
batch_size: int = 16,
|
||||
) -> None:
|
||||
"""
|
||||
Загружает примеры из базы данных с батч-обработкой.
|
||||
|
||||
Используется при запуске бота для восстановления VectorStore.
|
||||
Батч-обработка ускоряет загрузку в 10-20 раз.
|
||||
|
||||
Args:
|
||||
positive_texts: Список текстов опубликованных постов
|
||||
negative_texts: Список текстов отклоненных постов
|
||||
batch_size: Размер батча для обработки (по умолчанию 16)
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"RAGService: Загрузка примеров из БД с батч-обработкой "
|
||||
f"(positive: {len(positive_texts)}, negative: {len(negative_texts)}, batch_size: {batch_size})"
|
||||
)
|
||||
|
||||
# Убеждаемся что модель загружена
|
||||
await self.load_model()
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# Фильтруем и очищаем положительные тексты
|
||||
if positive_texts:
|
||||
clean_positive = []
|
||||
positive_hashes = []
|
||||
for text in positive_texts:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) >= self.min_text_length:
|
||||
clean_positive.append(clean_text)
|
||||
positive_hashes.append(VectorStore.compute_text_hash(clean_text))
|
||||
|
||||
if clean_positive:
|
||||
logger.info(f"RAGService: Обработка {len(clean_positive)} положительных примеров батчами...")
|
||||
positive_embeddings = await self.get_embeddings_batch(clean_positive, batch_size)
|
||||
self.vector_store.add_positive_batch(positive_embeddings, positive_hashes)
|
||||
|
||||
# Фильтруем и очищаем отрицательные тексты
|
||||
if negative_texts:
|
||||
clean_negative = []
|
||||
negative_hashes = []
|
||||
for text in negative_texts:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) >= self.min_text_length:
|
||||
clean_negative.append(clean_text)
|
||||
negative_hashes.append(VectorStore.compute_text_hash(clean_text))
|
||||
|
||||
if clean_negative:
|
||||
logger.info(f"RAGService: Обработка {len(clean_negative)} отрицательных примеров батчами...")
|
||||
negative_embeddings = await self.get_embeddings_batch(clean_negative, batch_size)
|
||||
self.vector_store.add_negative_batch(negative_embeddings, negative_hashes)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
f"RAGService: Загрузка завершена за {elapsed:.1f} сек "
|
||||
f"(positive: {self.vector_store.positive_count}, "
|
||||
f"negative: {self.vector_store.negative_count})"
|
||||
)
|
||||
|
||||
def save_vectors(self) -> None:
|
||||
"""Сохраняет векторы на диск."""
|
||||
if self.vector_store.storage_path:
|
||||
self.vector_store.save_to_disk()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Возвращает статистику сервиса."""
|
||||
return {
|
||||
"enabled": self._enabled,
|
||||
"model_name": self.model_name,
|
||||
"model_loaded": self._model_loaded,
|
||||
"vector_store": self.vector_store.get_stats(),
|
||||
}
|
||||
242
helper_bot/services/scoring/scoring_manager.py
Normal file
242
helper_bot/services/scoring/scoring_manager.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
Менеджер для объединения всех сервисов скоринга.
|
||||
|
||||
Координирует работу RAGService и DeepSeekService,
|
||||
выполняет параллельные запросы и агрегирует результаты.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
|
||||
from logs.custom_logger import logger
|
||||
from helper_bot.utils.metrics import track_time, track_errors
|
||||
|
||||
from .base import CombinedScore, ScoringResult
|
||||
from .rag_service import RAGService
|
||||
from .deepseek_service import DeepSeekService
|
||||
from .vector_store import VectorStore
|
||||
from .exceptions import ScoringError, InsufficientExamplesError, TextTooShortError
|
||||
|
||||
|
||||
class ScoringManager:
|
||||
"""
|
||||
Менеджер для управления всеми сервисами скоринга.
|
||||
|
||||
Объединяет RAGService и DeepSeekService, выполняет параллельные
|
||||
запросы и агрегирует результаты в единый CombinedScore.
|
||||
|
||||
Attributes:
|
||||
rag_service: Сервис RAG с ruBERT
|
||||
deepseek_service: Сервис DeepSeek API
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rag_service: Optional[RAGService] = None,
|
||||
deepseek_service: Optional[DeepSeekService] = None,
|
||||
):
|
||||
"""
|
||||
Инициализация менеджера.
|
||||
|
||||
Args:
|
||||
rag_service: Сервис RAG (создается автоматически если не передан)
|
||||
deepseek_service: Сервис DeepSeek (создается автоматически если не передан)
|
||||
"""
|
||||
self.rag_service = rag_service
|
||||
self.deepseek_service = deepseek_service
|
||||
|
||||
logger.info(
|
||||
f"ScoringManager инициализирован "
|
||||
f"(rag={rag_service is not None and rag_service.is_enabled}, "
|
||||
f"deepseek={deepseek_service is not None and deepseek_service.is_enabled})"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_any_enabled(self) -> bool:
|
||||
"""Проверяет, включен ли хотя бы один сервис."""
|
||||
rag_enabled = self.rag_service is not None and self.rag_service.is_enabled
|
||||
deepseek_enabled = self.deepseek_service is not None and self.deepseek_service.is_enabled
|
||||
return rag_enabled or deepseek_enabled
|
||||
|
||||
@track_time("score_post", "scoring_manager")
|
||||
@track_errors("scoring_manager", "score_post")
|
||||
async def score_post(self, text: str) -> CombinedScore:
|
||||
"""
|
||||
Рассчитывает скоры для текста поста от всех сервисов.
|
||||
|
||||
Выполняет запросы параллельно для минимизации задержки.
|
||||
|
||||
Args:
|
||||
text: Текст поста для оценки
|
||||
|
||||
Returns:
|
||||
CombinedScore с результатами от всех сервисов
|
||||
"""
|
||||
result = CombinedScore()
|
||||
|
||||
if not text or not text.strip():
|
||||
logger.debug("ScoringManager: Пустой текст, пропускаем скоринг")
|
||||
return result
|
||||
|
||||
# Собираем задачи для параллельного выполнения
|
||||
tasks = []
|
||||
task_names = []
|
||||
|
||||
# RAG сервис
|
||||
if self.rag_service and self.rag_service.is_enabled:
|
||||
tasks.append(self._get_rag_score(text))
|
||||
task_names.append("rag")
|
||||
|
||||
# DeepSeek сервис
|
||||
if self.deepseek_service and self.deepseek_service.is_enabled:
|
||||
tasks.append(self._get_deepseek_score(text))
|
||||
task_names.append("deepseek")
|
||||
|
||||
if not tasks:
|
||||
logger.debug("ScoringManager: Нет активных сервисов для скоринга")
|
||||
return result
|
||||
|
||||
# Выполняем параллельно
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Обрабатываем результаты
|
||||
for name, res in zip(task_names, results):
|
||||
if isinstance(res, Exception):
|
||||
error_msg = str(res)
|
||||
result.errors[name] = error_msg
|
||||
logger.warning(f"ScoringManager: Ошибка от {name}: {error_msg}")
|
||||
elif res is not None:
|
||||
if name == "rag":
|
||||
result.rag = res
|
||||
elif name == "deepseek":
|
||||
result.deepseek = res
|
||||
|
||||
logger.info(
|
||||
f"ScoringManager: Скоринг завершен "
|
||||
f"(rag={result.rag_score}, deepseek={result.deepseek_score})"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _get_rag_score(self, text: str) -> Optional[ScoringResult]:
|
||||
"""Получает скор от RAG сервиса."""
|
||||
try:
|
||||
return await self.rag_service.calculate_score(text)
|
||||
except InsufficientExamplesError:
|
||||
# Недостаточно примеров - это не ошибка, просто нет данных
|
||||
logger.info("ScoringManager: RAG - недостаточно примеров")
|
||||
return None
|
||||
except TextTooShortError:
|
||||
# Текст слишком короткий - пропускаем
|
||||
logger.debug("ScoringManager: RAG - текст слишком короткий")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"ScoringManager: RAG ошибка: {e}")
|
||||
raise
|
||||
|
||||
async def _get_deepseek_score(self, text: str) -> Optional[ScoringResult]:
|
||||
"""Получает скор от DeepSeek сервиса."""
|
||||
try:
|
||||
return await self.deepseek_service.calculate_score(text)
|
||||
except TextTooShortError:
|
||||
# Текст слишком короткий - пропускаем
|
||||
logger.debug("ScoringManager: DeepSeek - текст слишком короткий")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"ScoringManager: DeepSeek ошибка: {e}")
|
||||
raise
|
||||
|
||||
@track_time("on_post_published", "scoring_manager")
|
||||
async def on_post_published(self, text: str) -> None:
|
||||
"""
|
||||
Вызывается при публикации поста.
|
||||
|
||||
Добавляет текст как положительный пример для обучения RAG.
|
||||
|
||||
Args:
|
||||
text: Текст опубликованного поста
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return
|
||||
|
||||
tasks = []
|
||||
|
||||
if self.rag_service and self.rag_service.is_enabled:
|
||||
tasks.append(self.rag_service.add_positive_example(text))
|
||||
|
||||
if self.deepseek_service and self.deepseek_service.is_enabled:
|
||||
tasks.append(self.deepseek_service.add_positive_example(text))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.info("ScoringManager: Добавлен положительный пример")
|
||||
|
||||
@track_time("on_post_declined", "scoring_manager")
|
||||
async def on_post_declined(self, text: str) -> None:
|
||||
"""
|
||||
Вызывается при отклонении поста.
|
||||
|
||||
Добавляет текст как отрицательный пример для обучения RAG.
|
||||
|
||||
Args:
|
||||
text: Текст отклоненного поста
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
return
|
||||
|
||||
tasks = []
|
||||
|
||||
if self.rag_service and self.rag_service.is_enabled:
|
||||
tasks.append(self.rag_service.add_negative_example(text))
|
||||
|
||||
if self.deepseek_service and self.deepseek_service.is_enabled:
|
||||
tasks.append(self.deepseek_service.add_negative_example(text))
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.info("ScoringManager: Добавлен отрицательный пример")
|
||||
|
||||
async def load_examples_from_db(
|
||||
self,
|
||||
positive_texts: List[str],
|
||||
negative_texts: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
Загружает примеры из базы данных при запуске бота.
|
||||
|
||||
Args:
|
||||
positive_texts: Список текстов опубликованных постов
|
||||
negative_texts: Список текстов отклоненных постов
|
||||
"""
|
||||
if self.rag_service and self.rag_service.is_enabled:
|
||||
await self.rag_service.load_examples_from_db(
|
||||
positive_texts,
|
||||
negative_texts
|
||||
)
|
||||
|
||||
def save_vectors(self) -> None:
|
||||
"""Сохраняет векторы RAG на диск."""
|
||||
if self.rag_service:
|
||||
self.rag_service.save_vectors()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Закрывает ресурсы всех сервисов."""
|
||||
if self.deepseek_service:
|
||||
await self.deepseek_service.close()
|
||||
|
||||
# Сохраняем векторы перед закрытием
|
||||
self.save_vectors()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Возвращает статистику всех сервисов."""
|
||||
stats = {
|
||||
"any_enabled": self.is_any_enabled,
|
||||
}
|
||||
|
||||
if self.rag_service:
|
||||
stats["rag"] = self.rag_service.get_stats()
|
||||
|
||||
if self.deepseek_service:
|
||||
stats["deepseek"] = self.deepseek_service.get_stats()
|
||||
|
||||
return stats
|
||||
399
helper_bot/services/scoring/vector_store.py
Normal file
399
helper_bot/services/scoring/vector_store.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
In-memory хранилище векторов на numpy.
|
||||
|
||||
Хранит векторные представления постов для быстрого сравнения.
|
||||
Поддерживает персистентность через сохранение/загрузку с диска.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
from logs.custom_logger import logger
|
||||
from .exceptions import VectorStoreError, InsufficientExamplesError
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""
|
||||
In-memory хранилище векторов для RAG.
|
||||
|
||||
Хранит отдельно положительные (опубликованные) и отрицательные (отклоненные)
|
||||
примеры. Использует косинусное сходство для расчета скора.
|
||||
|
||||
Attributes:
|
||||
vector_dim: Размерность векторов (768 для ruBERT)
|
||||
max_examples: Максимальное количество примеров каждого типа
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_dim: int = 768,
|
||||
max_examples: int = 10000,
|
||||
storage_path: Optional[str] = None,
|
||||
score_multiplier: float = 5.0,
|
||||
):
|
||||
"""
|
||||
Инициализация хранилища.
|
||||
|
||||
Args:
|
||||
vector_dim: Размерность векторов
|
||||
max_examples: Максимальное количество примеров каждого типа
|
||||
storage_path: Путь для сохранения/загрузки векторов (опционально)
|
||||
score_multiplier: Множитель для усиления разницы в скорах
|
||||
"""
|
||||
self.vector_dim = vector_dim
|
||||
self.max_examples = max_examples
|
||||
self.storage_path = storage_path
|
||||
self.score_multiplier = score_multiplier
|
||||
|
||||
# Инициализируем пустые массивы
|
||||
# Используем список для динамического добавления, потом конвертируем в numpy
|
||||
self._positive_vectors: list = []
|
||||
self._negative_vectors: list = []
|
||||
self._positive_hashes: list = [] # Хеши текстов для дедупликации
|
||||
self._negative_hashes: list = []
|
||||
|
||||
# Lock для потокобезопасности
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Пытаемся загрузить сохраненные векторы
|
||||
if storage_path and os.path.exists(storage_path):
|
||||
self._load_from_disk()
|
||||
|
||||
@property
|
||||
def positive_count(self) -> int:
|
||||
"""Количество положительных примеров."""
|
||||
return len(self._positive_vectors)
|
||||
|
||||
@property
|
||||
def negative_count(self) -> int:
|
||||
"""Количество отрицательных примеров."""
|
||||
return len(self._negative_vectors)
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
"""Общее количество примеров."""
|
||||
return self.positive_count + self.negative_count
|
||||
|
||||
@staticmethod
|
||||
def compute_text_hash(text: str) -> str:
|
||||
"""Вычисляет хеш текста для дедупликации."""
|
||||
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
||||
|
||||
def _normalize_vector(self, vector: np.ndarray) -> np.ndarray:
|
||||
"""Нормализует вектор для косинусного сходства."""
|
||||
norm = np.linalg.norm(vector)
|
||||
if norm == 0:
|
||||
return vector
|
||||
return vector / norm
|
||||
|
||||
def add_positive(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Добавляет положительный пример (опубликованный пост).
|
||||
|
||||
Args:
|
||||
vector: Векторное представление текста
|
||||
text_hash: Хеш текста для дедупликации (опционально)
|
||||
|
||||
Returns:
|
||||
True если добавлен, False если дубликат или превышен лимит
|
||||
"""
|
||||
with self._lock:
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._positive_hashes:
|
||||
logger.debug(f"VectorStore: Пропуск дубликата положительного примера")
|
||||
return False
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._positive_vectors) >= self.max_examples:
|
||||
# Удаляем самый старый пример (FIFO)
|
||||
self._positive_vectors.pop(0)
|
||||
self._positive_hashes.pop(0)
|
||||
logger.debug("VectorStore: Удален старый положительный пример (лимит)")
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._positive_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._positive_hashes.append(text_hash)
|
||||
|
||||
logger.info(f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})")
|
||||
return True
|
||||
|
||||
def add_positive_batch(
|
||||
self,
|
||||
vectors: List[np.ndarray],
|
||||
text_hashes: Optional[List[str]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Добавляет батч положительных примеров.
|
||||
|
||||
Args:
|
||||
vectors: Список векторов
|
||||
text_hashes: Список хешей текстов для дедупликации
|
||||
|
||||
Returns:
|
||||
Количество добавленных примеров
|
||||
"""
|
||||
if text_hashes is None:
|
||||
text_hashes = [None] * len(vectors)
|
||||
|
||||
added = 0
|
||||
with self._lock:
|
||||
for vector, text_hash in zip(vectors, text_hashes):
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._positive_hashes:
|
||||
continue
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._positive_vectors) >= self.max_examples:
|
||||
self._positive_vectors.pop(0)
|
||||
self._positive_hashes.pop(0)
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._positive_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._positive_hashes.append(text_hash)
|
||||
added += 1
|
||||
|
||||
logger.info(f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})")
|
||||
return added
|
||||
|
||||
def add_negative(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Добавляет отрицательный пример (отклоненный пост).
|
||||
|
||||
Args:
|
||||
vector: Векторное представление текста
|
||||
text_hash: Хеш текста для дедупликации (опционально)
|
||||
|
||||
Returns:
|
||||
True если добавлен, False если дубликат или превышен лимит
|
||||
"""
|
||||
with self._lock:
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._negative_hashes:
|
||||
logger.debug(f"VectorStore: Пропуск дубликата отрицательного примера")
|
||||
return False
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._negative_vectors) >= self.max_examples:
|
||||
# Удаляем самый старый пример (FIFO)
|
||||
self._negative_vectors.pop(0)
|
||||
self._negative_hashes.pop(0)
|
||||
logger.debug("VectorStore: Удален старый отрицательный пример (лимит)")
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._negative_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._negative_hashes.append(text_hash)
|
||||
|
||||
logger.info(f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})")
|
||||
return True
|
||||
|
||||
def add_negative_batch(
|
||||
self,
|
||||
vectors: List[np.ndarray],
|
||||
text_hashes: Optional[List[str]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Добавляет батч отрицательных примеров.
|
||||
|
||||
Args:
|
||||
vectors: Список векторов
|
||||
text_hashes: Список хешей текстов для дедупликации
|
||||
|
||||
Returns:
|
||||
Количество добавленных примеров
|
||||
"""
|
||||
if text_hashes is None:
|
||||
text_hashes = [None] * len(vectors)
|
||||
|
||||
added = 0
|
||||
with self._lock:
|
||||
for vector, text_hash in zip(vectors, text_hashes):
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._negative_hashes:
|
||||
continue
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._negative_vectors) >= self.max_examples:
|
||||
self._negative_vectors.pop(0)
|
||||
self._negative_hashes.pop(0)
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._negative_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._negative_hashes.append(text_hash)
|
||||
added += 1
|
||||
|
||||
logger.info(f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})")
|
||||
return added
|
||||
|
||||
def calculate_similarity_score(self, vector: np.ndarray) -> Tuple[float, float]:
|
||||
"""
|
||||
Рассчитывает скор на основе сходства с примерами.
|
||||
|
||||
Алгоритм:
|
||||
1. Вычисляем среднее косинусное сходство с положительными примерами
|
||||
2. Вычисляем среднее косинусное сходство с отрицательными примерами
|
||||
3. Финальный скор = pos_sim / (pos_sim + neg_sim + eps)
|
||||
|
||||
Args:
|
||||
vector: Векторное представление нового поста
|
||||
|
||||
Returns:
|
||||
Tuple (score, confidence):
|
||||
- score: Оценка от 0.0 до 1.0
|
||||
- confidence: Уверенность (зависит от количества примеров)
|
||||
|
||||
Raises:
|
||||
InsufficientExamplesError: Если недостаточно примеров
|
||||
"""
|
||||
with self._lock:
|
||||
if self.positive_count == 0:
|
||||
raise InsufficientExamplesError(
|
||||
"Нет положительных примеров для сравнения"
|
||||
)
|
||||
|
||||
# Нормализуем входной вектор
|
||||
normalized = self._normalize_vector(vector)
|
||||
|
||||
# Конвертируем в numpy массивы для быстрых вычислений
|
||||
pos_matrix = np.array(self._positive_vectors)
|
||||
|
||||
# Косинусное сходство с положительными примерами
|
||||
# Для нормализованных векторов это просто скалярное произведение
|
||||
pos_similarities = np.dot(pos_matrix, normalized)
|
||||
pos_sim = float(np.mean(pos_similarities))
|
||||
|
||||
# Косинусное сходство с отрицательными примерами
|
||||
if self.negative_count > 0:
|
||||
neg_matrix = np.array(self._negative_vectors)
|
||||
neg_similarities = np.dot(neg_matrix, normalized)
|
||||
neg_sim = float(np.mean(neg_similarities))
|
||||
else:
|
||||
# Если нет отрицательных примеров, используем нейтральное значение
|
||||
neg_sim = pos_sim # Нейтральный скор = 0.5
|
||||
|
||||
# === Вариант 1: neg/pos (разница между положительными и отрицательными) ===
|
||||
diff = pos_sim - neg_sim
|
||||
score_neg_pos = 0.5 + (diff * self.score_multiplier)
|
||||
score_neg_pos = max(0.0, min(1.0, score_neg_pos))
|
||||
|
||||
# === Вариант 2: pos only (только положительные, топ-k ближайших) ===
|
||||
# Берём топ-5 ближайших положительных примеров
|
||||
top_k = min(5, len(pos_similarities))
|
||||
top_k_sim = float(np.mean(np.sort(pos_similarities)[-top_k:]))
|
||||
# Нормализуем: 0.85 -> 0.0, 0.95 -> 1.0 (типичный диапазон для BERT)
|
||||
score_pos_only = (top_k_sim - 0.85) / 0.10
|
||||
score_pos_only = max(0.0, min(1.0, score_pos_only))
|
||||
|
||||
# Основной скор — neg/pos (можно будет переключить позже)
|
||||
score = score_neg_pos
|
||||
|
||||
# Confidence зависит от количества примеров (100% при 1000 примерах)
|
||||
total_examples = self.positive_count + self.negative_count
|
||||
confidence = min(1.0, total_examples / 1000)
|
||||
|
||||
logger.info(
|
||||
f"VectorStore: pos_sim={pos_sim:.4f}, neg_sim={neg_sim:.4f}, "
|
||||
f"top_k_sim={top_k_sim:.4f}, score_neg_pos={score_neg_pos:.4f}, "
|
||||
f"score_pos_only={score_pos_only:.4f}"
|
||||
)
|
||||
|
||||
return score, confidence, score_pos_only
|
||||
|
||||
def save_to_disk(self, path: Optional[str] = None) -> None:
|
||||
"""
|
||||
Сохраняет векторы на диск.
|
||||
|
||||
Args:
|
||||
path: Путь для сохранения (если не указан, используется storage_path)
|
||||
"""
|
||||
save_path = path or self.storage_path
|
||||
if not save_path:
|
||||
raise VectorStoreError("Путь для сохранения не указан")
|
||||
|
||||
with self._lock:
|
||||
# Создаем директорию если нужно
|
||||
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Сохраняем в npz формате
|
||||
np.savez_compressed(
|
||||
save_path,
|
||||
positive_vectors=np.array(self._positive_vectors) if self._positive_vectors else np.array([]),
|
||||
negative_vectors=np.array(self._negative_vectors) if self._negative_vectors else np.array([]),
|
||||
positive_hashes=np.array(self._positive_hashes, dtype=object),
|
||||
negative_hashes=np.array(self._negative_hashes, dtype=object),
|
||||
vector_dim=self.vector_dim,
|
||||
max_examples=self.max_examples,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"VectorStore: Сохранено на диск ({self.positive_count} pos, "
|
||||
f"{self.negative_count} neg): {save_path}"
|
||||
)
|
||||
|
||||
def _load_from_disk(self) -> None:
|
||||
"""Загружает векторы с диска."""
|
||||
if not self.storage_path or not os.path.exists(self.storage_path):
|
||||
return
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
data = np.load(self.storage_path, allow_pickle=True)
|
||||
|
||||
# Загружаем векторы
|
||||
pos_vectors = data.get('positive_vectors', np.array([]))
|
||||
neg_vectors = data.get('negative_vectors', np.array([]))
|
||||
|
||||
if pos_vectors.size > 0:
|
||||
self._positive_vectors = list(pos_vectors)
|
||||
if neg_vectors.size > 0:
|
||||
self._negative_vectors = list(neg_vectors)
|
||||
|
||||
# Загружаем хеши
|
||||
pos_hashes = data.get('positive_hashes', np.array([]))
|
||||
neg_hashes = data.get('negative_hashes', np.array([]))
|
||||
|
||||
if pos_hashes.size > 0:
|
||||
self._positive_hashes = list(pos_hashes)
|
||||
if neg_hashes.size > 0:
|
||||
self._negative_hashes = list(neg_hashes)
|
||||
|
||||
logger.info(
|
||||
f"VectorStore: Загружено с диска ({self.positive_count} pos, "
|
||||
f"{self.negative_count} neg): {self.storage_path}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VectorStore: Ошибка загрузки с диска: {e}")
|
||||
# Продолжаем с пустым хранилищем
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Очищает все векторы."""
|
||||
with self._lock:
|
||||
self._positive_vectors.clear()
|
||||
self._negative_vectors.clear()
|
||||
self._positive_hashes.clear()
|
||||
self._negative_hashes.clear()
|
||||
logger.info("VectorStore: Хранилище очищено")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Возвращает статистику хранилища."""
|
||||
return {
|
||||
"positive_count": self.positive_count,
|
||||
"negative_count": self.negative_count,
|
||||
"total_count": self.total_count,
|
||||
"vector_dim": self.vector_dim,
|
||||
"max_examples": self.max_examples,
|
||||
"storage_path": self.storage_path,
|
||||
}
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
from database.async_db import AsyncBotDB
|
||||
from dotenv import load_dotenv
|
||||
from helper_bot.utils.s3_storage import S3StorageService
|
||||
from logs.custom_logger import logger
|
||||
|
||||
|
||||
class BaseDependencyFactory:
|
||||
@@ -15,6 +16,7 @@ class BaseDependencyFactory:
|
||||
load_dotenv(env_path)
|
||||
|
||||
self.settings = {}
|
||||
self._project_dir = project_dir
|
||||
|
||||
database_path = os.getenv('DATABASE_PATH', 'database/tg-bot-database.db')
|
||||
if not os.path.isabs(database_path):
|
||||
@@ -25,6 +27,9 @@ class BaseDependencyFactory:
|
||||
self._load_settings_from_env()
|
||||
self._init_s3_storage()
|
||||
|
||||
# ScoringManager инициализируется лениво
|
||||
self._scoring_manager = None
|
||||
|
||||
def _load_settings_from_env(self):
|
||||
"""Загружает настройки из переменных окружения."""
|
||||
self.settings['Telegram'] = {
|
||||
@@ -60,6 +65,23 @@ class BaseDependencyFactory:
|
||||
'region': os.getenv('S3_REGION', 'us-east-1')
|
||||
}
|
||||
|
||||
# Настройки ML-скоринга
|
||||
self.settings['Scoring'] = {
|
||||
# RAG (ruBERT)
|
||||
'rag_enabled': self._parse_bool(os.getenv('RAG_ENABLED', 'false')),
|
||||
'rag_model': os.getenv('RAG_MODEL', 'DeepPavlov/rubert-base-cased'),
|
||||
'rag_cache_dir': os.getenv('RAG_CACHE_DIR', 'data/models'),
|
||||
'rag_vectors_path': os.getenv('RAG_VECTORS_PATH', 'data/vectors.npz'),
|
||||
'rag_max_examples': self._parse_int(os.getenv('RAG_MAX_EXAMPLES', '10000')),
|
||||
'rag_score_multiplier': self._parse_float(os.getenv('RAG_SCORE_MULTIPLIER', '5.0')),
|
||||
# DeepSeek
|
||||
'deepseek_enabled': self._parse_bool(os.getenv('DEEPSEEK_ENABLED', 'false')),
|
||||
'deepseek_api_key': os.getenv('DEEPSEEK_API_KEY', ''),
|
||||
'deepseek_api_url': os.getenv('DEEPSEEK_API_URL', 'https://api.deepseek.com/v1/chat/completions'),
|
||||
'deepseek_model': os.getenv('DEEPSEEK_MODEL', 'deepseek-chat'),
|
||||
'deepseek_timeout': self._parse_int(os.getenv('DEEPSEEK_TIMEOUT', '30')),
|
||||
}
|
||||
|
||||
def _init_s3_storage(self):
|
||||
"""Инициализирует S3StorageService если S3 включен."""
|
||||
self.s3_storage = None
|
||||
@@ -85,6 +107,13 @@ class BaseDependencyFactory:
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
def _parse_float(self, value: str) -> float:
|
||||
"""Парсит строковое значение в float."""
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return 0.0
|
||||
|
||||
def get_settings(self):
|
||||
return self.settings
|
||||
|
||||
@@ -96,6 +125,100 @@ class BaseDependencyFactory:
|
||||
"""Возвращает S3StorageService если S3 включен, иначе None."""
|
||||
return self.s3_storage
|
||||
|
||||
def _init_scoring_manager(self):
|
||||
"""
|
||||
Инициализирует ScoringManager с RAG и DeepSeek сервисами.
|
||||
|
||||
Вызывается лениво при первом обращении к get_scoring_manager().
|
||||
"""
|
||||
from helper_bot.services.scoring import (
|
||||
ScoringManager,
|
||||
RAGService,
|
||||
DeepSeekService,
|
||||
VectorStore,
|
||||
)
|
||||
|
||||
scoring_config = self.settings['Scoring']
|
||||
|
||||
# Инициализация RAG сервиса
|
||||
rag_service = None
|
||||
if scoring_config['rag_enabled']:
|
||||
# Путь к векторам
|
||||
vectors_path = scoring_config['rag_vectors_path']
|
||||
if not os.path.isabs(vectors_path):
|
||||
vectors_path = os.path.join(self._project_dir, vectors_path)
|
||||
|
||||
# Путь к кешу моделей
|
||||
cache_dir = scoring_config['rag_cache_dir']
|
||||
if not os.path.isabs(cache_dir):
|
||||
cache_dir = os.path.join(self._project_dir, cache_dir)
|
||||
|
||||
# Создаем директории если нужно
|
||||
os.makedirs(os.path.dirname(vectors_path), exist_ok=True)
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Создаем VectorStore
|
||||
vector_store = VectorStore(
|
||||
vector_dim=768, # ruBERT dimension
|
||||
max_examples=scoring_config['rag_max_examples'],
|
||||
storage_path=vectors_path,
|
||||
score_multiplier=scoring_config['rag_score_multiplier'],
|
||||
)
|
||||
|
||||
# Создаем RAGService
|
||||
rag_service = RAGService(
|
||||
model_name=scoring_config['rag_model'],
|
||||
vector_store=vector_store,
|
||||
cache_dir=cache_dir,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
logger.info(f"RAGService инициализирован: {scoring_config['rag_model']}")
|
||||
|
||||
# Инициализация DeepSeek сервиса
|
||||
deepseek_service = None
|
||||
if scoring_config['deepseek_enabled'] and scoring_config['deepseek_api_key']:
|
||||
deepseek_service = DeepSeekService(
|
||||
api_key=scoring_config['deepseek_api_key'],
|
||||
api_url=scoring_config['deepseek_api_url'],
|
||||
model=scoring_config['deepseek_model'],
|
||||
timeout=scoring_config['deepseek_timeout'],
|
||||
enabled=True,
|
||||
)
|
||||
logger.info(f"DeepSeekService инициализирован: {scoring_config['deepseek_model']}")
|
||||
|
||||
# Создаем менеджер
|
||||
self._scoring_manager = ScoringManager(
|
||||
rag_service=rag_service,
|
||||
deepseek_service=deepseek_service,
|
||||
)
|
||||
|
||||
return self._scoring_manager
|
||||
|
||||
def get_scoring_manager(self):
|
||||
"""
|
||||
Возвращает ScoringManager для ML-скоринга постов.
|
||||
|
||||
Инициализируется лениво при первом вызове.
|
||||
|
||||
Returns:
|
||||
ScoringManager или None если скоринг полностью отключен
|
||||
"""
|
||||
if self._scoring_manager is None:
|
||||
scoring_config = self.settings.get('Scoring', {})
|
||||
|
||||
# Проверяем, включен ли хотя бы один сервис
|
||||
rag_enabled = scoring_config.get('rag_enabled', False)
|
||||
deepseek_enabled = scoring_config.get('deepseek_enabled', False)
|
||||
|
||||
if not rag_enabled and not deepseek_enabled:
|
||||
logger.info("Scoring полностью отключен (RAG и DeepSeek disabled)")
|
||||
return None
|
||||
|
||||
self._init_scoring_manager()
|
||||
|
||||
return self._scoring_manager
|
||||
|
||||
|
||||
_global_instance = None
|
||||
|
||||
|
||||
@@ -111,7 +111,16 @@ def determine_anonymity(post_text: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_text_message(post_text: str, first_name: str, username: str = None, is_anonymous: Optional[bool] = None):
|
||||
def get_text_message(
|
||||
post_text: str,
|
||||
first_name: str,
|
||||
username: str = None,
|
||||
is_anonymous: Optional[bool] = None,
|
||||
deepseek_score: Optional[float] = None,
|
||||
rag_score: Optional[float] = None,
|
||||
rag_confidence: Optional[float] = None,
|
||||
rag_score_pos_only: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Форматирует текст сообщения для публикации в зависимости от наличия ключевых слов "анон" и "неанон"
|
||||
или переданного параметра is_anonymous.
|
||||
@@ -121,6 +130,10 @@ def get_text_message(post_text: str, first_name: str, username: str = None, is_a
|
||||
first_name: Имя автора поста
|
||||
username: Юзернейм автора поста (может быть None)
|
||||
is_anonymous: Флаг анонимности (True - анонимно, False - не анонимно, None - legacy, определяется по тексту)
|
||||
deepseek_score: Скор от DeepSeek API (0.0-1.0, опционально)
|
||||
rag_score: Скор от RAG/ruBERT neg/pos (0.0-1.0, опционально)
|
||||
rag_confidence: Уверенность RAG модели (0.0-1.0, зависит от количества примеров)
|
||||
rag_score_pos_only: Скор RAG только по положительным примерам (0.0-1.0, опционально)
|
||||
|
||||
Returns:
|
||||
str: - Сформированный текст сообщения.
|
||||
@@ -137,21 +150,37 @@ def get_text_message(post_text: str, first_name: str, username: str = None, is_a
|
||||
else:
|
||||
author_info = f"{first_name} (Ник не указан)"
|
||||
|
||||
# Формируем базовый текст
|
||||
# Если передан is_anonymous, используем его, иначе определяем по тексту (legacy)
|
||||
# TODO: Уверен можно укоротить
|
||||
if is_anonymous is not None:
|
||||
if is_anonymous:
|
||||
return f'{safe_post_text}\n\nПост опубликован анонимно'
|
||||
final_text = f'{safe_post_text}\n\nПост опубликован анонимно'
|
||||
else:
|
||||
return f'{safe_post_text}\n\nАвтор поста: {author_info}'
|
||||
final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}'
|
||||
else:
|
||||
# Legacy: определяем по тексту
|
||||
if "неанон" in post_text or "не анон" in post_text:
|
||||
return f'{safe_post_text}\n\nАвтор поста: {author_info}'
|
||||
final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}'
|
||||
elif "анон" in post_text:
|
||||
return f'{safe_post_text}\n\nПост опубликован анонимно'
|
||||
final_text = f'{safe_post_text}\n\nПост опубликован анонимно'
|
||||
else:
|
||||
return f'{safe_post_text}\n\nАвтор поста: {author_info}'
|
||||
final_text = f'{safe_post_text}\n\nАвтор поста: {author_info}'
|
||||
|
||||
# Добавляем блок со скорами если есть
|
||||
if deepseek_score is not None or rag_score is not None or rag_score_pos_only is not None:
|
||||
scores_lines = ["\n📊 Уверенность в одобрении:"]
|
||||
if deepseek_score is not None:
|
||||
scores_lines.append(f"DeepSeek: {deepseek_score:.2f}")
|
||||
if rag_score is not None:
|
||||
rag_line = f"RAG neg/pos: {rag_score:.2f}"
|
||||
if rag_confidence is not None:
|
||||
rag_line += f" (уверенность: {rag_confidence:.0%})"
|
||||
scores_lines.append(rag_line)
|
||||
if rag_score_pos_only is not None:
|
||||
scores_lines.append(f"RAG pos only: {rag_score_pos_only:.2f}")
|
||||
final_text += "\n" + "\n".join(scores_lines)
|
||||
|
||||
return final_text
|
||||
|
||||
@track_time("download_file", "helper_func")
|
||||
@track_errors("helper_func", "download_file")
|
||||
|
||||
@@ -31,3 +31,9 @@ emoji~=2.8.0
|
||||
|
||||
# S3 Storage (для хранения медиафайлов опубликованных постов)
|
||||
aioboto3>=12.0.0
|
||||
|
||||
# ML Scoring (для оценки вероятности публикации постов)
|
||||
numpy>=1.24.0
|
||||
transformers>=4.30.0
|
||||
torch>=2.0.0
|
||||
httpx>=0.24.0
|
||||
93
scripts/add_ml_scores_columns.py
Normal file
93
scripts/add_ml_scores_columns.py
Normal file
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Миграция: Добавление колонок для ML-скоринга постов.
|
||||
|
||||
Добавляет:
|
||||
- ml_scores (TEXT/JSON) - JSON с результатами оценки от разных моделей
|
||||
- vector_hash (TEXT) - хеш текста для кеширования векторов
|
||||
|
||||
Структура ml_scores:
|
||||
{
|
||||
"deepseek": {"score": 0.75, "model": "deepseek-chat", "ts": 1706198400},
|
||||
"rag": {"score": 0.90, "model": "rubert-base-cased", "ts": 1706198400}
|
||||
}
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Добавляем корень проекта в путь
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
import aiosqlite
|
||||
|
||||
# Пытаемся импортировать logger, если не получается - используем стандартный
|
||||
try:
|
||||
from logs.custom_logger import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_DB_PATH = "database/tg-bot-database.db"
|
||||
|
||||
|
||||
async def column_exists(conn: aiosqlite.Connection, table: str, column: str) -> bool:
|
||||
"""Проверяет существование колонки в таблице."""
|
||||
cursor = await conn.execute(f"PRAGMA table_info({table})")
|
||||
columns = await cursor.fetchall()
|
||||
return any(col[1] == column for col in columns)
|
||||
|
||||
|
||||
async def main(db_path: str) -> None:
|
||||
"""
|
||||
Основная функция миграции.
|
||||
|
||||
Добавляет колонки ml_scores и vector_hash в таблицу post_from_telegram_suggest.
|
||||
Миграция идемпотентна - можно запускать повторно без ошибок.
|
||||
"""
|
||||
db_path = os.path.abspath(db_path)
|
||||
|
||||
if not os.path.exists(db_path):
|
||||
logger.error(f"База данных не найдена: {db_path}")
|
||||
return
|
||||
|
||||
async with aiosqlite.connect(db_path) as conn:
|
||||
await conn.execute("PRAGMA foreign_keys = ON")
|
||||
|
||||
# Проверяем и добавляем колонку ml_scores
|
||||
if not await column_exists(conn, "post_from_telegram_suggest", "ml_scores"):
|
||||
await conn.execute(
|
||||
"ALTER TABLE post_from_telegram_suggest ADD COLUMN ml_scores TEXT"
|
||||
)
|
||||
logger.info("Колонка ml_scores добавлена в post_from_telegram_suggest")
|
||||
else:
|
||||
logger.info("Колонка ml_scores уже существует")
|
||||
|
||||
# Проверяем и добавляем колонку vector_hash
|
||||
if not await column_exists(conn, "post_from_telegram_suggest", "vector_hash"):
|
||||
await conn.execute(
|
||||
"ALTER TABLE post_from_telegram_suggest ADD COLUMN vector_hash TEXT"
|
||||
)
|
||||
logger.info("Колонка vector_hash добавлена в post_from_telegram_suggest")
|
||||
else:
|
||||
logger.info("Колонка vector_hash уже существует")
|
||||
|
||||
await conn.commit()
|
||||
logger.info("Миграция add_ml_scores_columns завершена успешно")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Добавление колонок ml_scores и vector_hash для ML-скоринга"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db",
|
||||
default=os.environ.get("DATABASE_PATH", DEFAULT_DB_PATH),
|
||||
help="Путь к БД",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(args.db))
|
||||
390
tests/test_scoring_services.py
Normal file
390
tests/test_scoring_services.py
Normal file
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Тесты для сервисов ML-скоринга постов.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
# Импорты для тестирования базовых классов
|
||||
from helper_bot.services.scoring.base import ScoringResult, CombinedScore
|
||||
from helper_bot.services.scoring.exceptions import (
|
||||
ScoringError,
|
||||
InsufficientExamplesError,
|
||||
TextTooShortError,
|
||||
)
|
||||
|
||||
|
||||
class TestScoringResult:
|
||||
"""Тесты для ScoringResult."""
|
||||
|
||||
def test_create_valid_score(self):
|
||||
"""Тест создания валидного результата."""
|
||||
result = ScoringResult(
|
||||
score=0.75,
|
||||
source="rag",
|
||||
model="test-model",
|
||||
)
|
||||
assert result.score == 0.75
|
||||
assert result.source == "rag"
|
||||
assert result.model == "test-model"
|
||||
|
||||
def test_score_validation_lower_bound(self):
|
||||
"""Тест валидации нижней границы скора."""
|
||||
with pytest.raises(ValueError):
|
||||
ScoringResult(score=-0.1, source="test", model="test")
|
||||
|
||||
def test_score_validation_upper_bound(self):
|
||||
"""Тест валидации верхней границы скора."""
|
||||
with pytest.raises(ValueError):
|
||||
ScoringResult(score=1.1, source="test", model="test")
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Тест преобразования в словарь."""
|
||||
result = ScoringResult(
|
||||
score=0.7534,
|
||||
source="rag",
|
||||
model="test-model",
|
||||
confidence=0.85,
|
||||
timestamp=1234567890,
|
||||
)
|
||||
d = result.to_dict()
|
||||
|
||||
assert d["score"] == 0.7534 # Округлено до 4 знаков
|
||||
assert d["model"] == "test-model"
|
||||
assert d["ts"] == 1234567890
|
||||
assert d["confidence"] == 0.85
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Тест создания из словаря."""
|
||||
data = {
|
||||
"score": 0.75,
|
||||
"model": "test-model",
|
||||
"ts": 1234567890,
|
||||
"confidence": 0.9,
|
||||
}
|
||||
result = ScoringResult.from_dict("rag", data)
|
||||
|
||||
assert result.score == 0.75
|
||||
assert result.source == "rag"
|
||||
assert result.model == "test-model"
|
||||
assert result.timestamp == 1234567890
|
||||
assert result.confidence == 0.9
|
||||
|
||||
|
||||
class TestCombinedScore:
|
||||
"""Тесты для CombinedScore."""
|
||||
|
||||
def test_empty_combined_score(self):
|
||||
"""Тест пустого объединенного скора."""
|
||||
score = CombinedScore()
|
||||
|
||||
assert score.deepseek is None
|
||||
assert score.rag is None
|
||||
assert score.deepseek_score is None
|
||||
assert score.rag_score is None
|
||||
assert not score.has_any_score()
|
||||
|
||||
def test_combined_score_with_rag(self):
|
||||
"""Тест объединенного скора с RAG."""
|
||||
rag_result = ScoringResult(score=0.8, source="rag", model="rubert")
|
||||
score = CombinedScore(rag=rag_result)
|
||||
|
||||
assert score.rag_score == 0.8
|
||||
assert score.deepseek_score is None
|
||||
assert score.has_any_score()
|
||||
|
||||
def test_combined_score_with_both(self):
|
||||
"""Тест объединенного скора с обоими сервисами."""
|
||||
rag_result = ScoringResult(score=0.8, source="rag", model="rubert")
|
||||
deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat")
|
||||
score = CombinedScore(rag=rag_result, deepseek=deepseek_result)
|
||||
|
||||
assert score.rag_score == 0.8
|
||||
assert score.deepseek_score == 0.7
|
||||
assert score.has_any_score()
|
||||
|
||||
def test_to_json_dict(self):
|
||||
"""Тест преобразования в JSON словарь."""
|
||||
rag_result = ScoringResult(score=0.8, source="rag", model="rubert", timestamp=123)
|
||||
deepseek_result = ScoringResult(score=0.7, source="deepseek", model="deepseek-chat", timestamp=456)
|
||||
score = CombinedScore(rag=rag_result, deepseek=deepseek_result)
|
||||
|
||||
d = score.to_json_dict()
|
||||
|
||||
assert "rag" in d
|
||||
assert "deepseek" in d
|
||||
assert d["rag"]["score"] == 0.8
|
||||
assert d["deepseek"]["score"] == 0.7
|
||||
|
||||
# Проверяем что это валидный JSON
|
||||
json_str = json.dumps(d)
|
||||
assert json_str
|
||||
|
||||
|
||||
class TestVectorStore:
|
||||
"""Тесты для VectorStore (требует numpy)."""
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(self):
|
||||
"""Создает VectorStore для тестов."""
|
||||
try:
|
||||
import numpy as np
|
||||
from helper_bot.services.scoring.vector_store import VectorStore
|
||||
return VectorStore(vector_dim=768, max_examples=100)
|
||||
except ImportError:
|
||||
pytest.skip("numpy не установлен")
|
||||
|
||||
def test_add_positive_example(self, vector_store):
|
||||
"""Тест добавления положительного примера."""
|
||||
import numpy as np
|
||||
|
||||
vector = np.random.randn(768).astype(np.float32)
|
||||
result = vector_store.add_positive(vector, "hash1")
|
||||
|
||||
assert result is True
|
||||
assert vector_store.positive_count == 1
|
||||
|
||||
def test_add_duplicate_example(self, vector_store):
|
||||
"""Тест добавления дубликата."""
|
||||
import numpy as np
|
||||
|
||||
vector = np.random.randn(768).astype(np.float32)
|
||||
vector_store.add_positive(vector, "hash1")
|
||||
result = vector_store.add_positive(vector, "hash1") # Дубликат
|
||||
|
||||
assert result is False
|
||||
assert vector_store.positive_count == 1
|
||||
|
||||
def test_max_examples_limit(self, vector_store):
|
||||
"""Тест ограничения максимального количества примеров."""
|
||||
import numpy as np
|
||||
|
||||
# Добавляем больше чем max_examples
|
||||
for i in range(150):
|
||||
vector = np.random.randn(768).astype(np.float32)
|
||||
vector_store.add_positive(vector, f"hash_{i}")
|
||||
|
||||
assert vector_store.positive_count == 100 # max_examples
|
||||
|
||||
def test_calculate_similarity_no_examples(self, vector_store):
|
||||
"""Тест расчета скора без примеров."""
|
||||
import numpy as np
|
||||
|
||||
vector = np.random.randn(768).astype(np.float32)
|
||||
|
||||
with pytest.raises(InsufficientExamplesError):
|
||||
vector_store.calculate_similarity_score(vector)
|
||||
|
||||
def test_calculate_similarity_with_examples(self, vector_store):
|
||||
"""Тест расчета скора с примерами."""
|
||||
import numpy as np
|
||||
|
||||
# Добавляем положительные примеры
|
||||
for i in range(10):
|
||||
vector = np.random.randn(768).astype(np.float32)
|
||||
vector_store.add_positive(vector, f"pos_{i}")
|
||||
|
||||
# Добавляем отрицательные примеры
|
||||
for i in range(10):
|
||||
vector = np.random.randn(768).astype(np.float32)
|
||||
vector_store.add_negative(vector, f"neg_{i}")
|
||||
|
||||
# Рассчитываем скор для нового вектора
|
||||
test_vector = np.random.randn(768).astype(np.float32)
|
||||
score, confidence = vector_store.calculate_similarity_score(test_vector)
|
||||
|
||||
assert 0.0 <= score <= 1.0
|
||||
assert 0.0 <= confidence <= 1.0
|
||||
|
||||
def test_compute_text_hash(self, vector_store):
|
||||
"""Тест вычисления хеша текста."""
|
||||
from helper_bot.services.scoring.vector_store import VectorStore
|
||||
|
||||
hash1 = VectorStore.compute_text_hash("Привет мир")
|
||||
hash2 = VectorStore.compute_text_hash("Привет мир")
|
||||
hash3 = VectorStore.compute_text_hash("Другой текст")
|
||||
|
||||
assert hash1 == hash2
|
||||
assert hash1 != hash3
|
||||
|
||||
|
||||
class TestDeepSeekService:
|
||||
"""Тесты для DeepSeekService."""
|
||||
|
||||
@pytest.fixture
|
||||
def deepseek_service(self):
|
||||
"""Создает DeepSeekService для тестов."""
|
||||
from helper_bot.services.scoring.deepseek_service import DeepSeekService
|
||||
return DeepSeekService(
|
||||
api_key="test_key",
|
||||
enabled=True,
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
def test_service_disabled_without_key(self):
|
||||
"""Тест отключения сервиса без API ключа."""
|
||||
from helper_bot.services.scoring.deepseek_service import DeepSeekService
|
||||
service = DeepSeekService(api_key=None, enabled=True)
|
||||
|
||||
assert service.is_enabled is False
|
||||
|
||||
def test_parse_score_response_valid(self, deepseek_service):
|
||||
"""Тест парсинга валидного ответа."""
|
||||
assert deepseek_service._parse_score_response("0.75") == 0.75
|
||||
assert deepseek_service._parse_score_response("0.5") == 0.5
|
||||
assert deepseek_service._parse_score_response("1.0") == 1.0
|
||||
assert deepseek_service._parse_score_response("0") == 0.0
|
||||
|
||||
def test_parse_score_response_with_quotes(self, deepseek_service):
|
||||
"""Тест парсинга ответа с кавычками."""
|
||||
assert deepseek_service._parse_score_response('"0.75"') == 0.75
|
||||
assert deepseek_service._parse_score_response("'0.8'") == 0.8
|
||||
|
||||
def test_parse_score_response_with_text(self, deepseek_service):
|
||||
"""Тест парсинга ответа с текстом."""
|
||||
# Сервис должен найти число в тексте
|
||||
assert deepseek_service._parse_score_response("Score: 0.75") == 0.75
|
||||
|
||||
def test_clean_text(self, deepseek_service):
|
||||
"""Тест очистки текста."""
|
||||
assert deepseek_service._clean_text(" hello world ") == "hello world"
|
||||
assert deepseek_service._clean_text("^") == ""
|
||||
assert deepseek_service._clean_text("") == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_score_disabled(self):
|
||||
"""Тест расчета скора при отключенном сервисе."""
|
||||
from helper_bot.services.scoring.deepseek_service import DeepSeekService
|
||||
service = DeepSeekService(api_key=None, enabled=False)
|
||||
|
||||
with pytest.raises(ScoringError):
|
||||
await service.calculate_score("Test text")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_score_short_text(self, deepseek_service):
|
||||
"""Тест расчета скора для короткого текста."""
|
||||
with pytest.raises(TextTooShortError):
|
||||
await deepseek_service.calculate_score("ab")
|
||||
|
||||
|
||||
class TestScoringManager:
|
||||
"""Тесты для ScoringManager."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_service(self):
|
||||
"""Создает мок RAG сервиса."""
|
||||
mock = AsyncMock()
|
||||
mock.is_enabled = True
|
||||
mock.calculate_score = AsyncMock(return_value=ScoringResult(
|
||||
score=0.8,
|
||||
source="rag",
|
||||
model="rubert",
|
||||
))
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deepseek_service(self):
|
||||
"""Создает мок DeepSeek сервиса."""
|
||||
mock = AsyncMock()
|
||||
mock.is_enabled = True
|
||||
mock.calculate_score = AsyncMock(return_value=ScoringResult(
|
||||
score=0.7,
|
||||
source="deepseek",
|
||||
model="deepseek-chat",
|
||||
))
|
||||
return mock
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_post_both_services(self, mock_rag_service, mock_deepseek_service):
|
||||
"""Тест скоринга с обоими сервисами."""
|
||||
from helper_bot.services.scoring.scoring_manager import ScoringManager
|
||||
|
||||
manager = ScoringManager(
|
||||
rag_service=mock_rag_service,
|
||||
deepseek_service=mock_deepseek_service,
|
||||
)
|
||||
|
||||
result = await manager.score_post("Тестовый пост")
|
||||
|
||||
assert result.rag_score == 0.8
|
||||
assert result.deepseek_score == 0.7
|
||||
assert result.has_any_score()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_post_rag_only(self, mock_rag_service):
|
||||
"""Тест скоринга только с RAG."""
|
||||
from helper_bot.services.scoring.scoring_manager import ScoringManager
|
||||
|
||||
manager = ScoringManager(
|
||||
rag_service=mock_rag_service,
|
||||
deepseek_service=None,
|
||||
)
|
||||
|
||||
result = await manager.score_post("Тестовый пост")
|
||||
|
||||
assert result.rag_score == 0.8
|
||||
assert result.deepseek_score is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_post_empty_text(self, mock_rag_service):
|
||||
"""Тест скоринга пустого текста."""
|
||||
from helper_bot.services.scoring.scoring_manager import ScoringManager
|
||||
|
||||
manager = ScoringManager(rag_service=mock_rag_service)
|
||||
|
||||
result = await manager.score_post("")
|
||||
|
||||
assert not result.has_any_score()
|
||||
mock_rag_service.calculate_score.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_post_service_error(self, mock_rag_service, mock_deepseek_service):
|
||||
"""Тест обработки ошибки сервиса."""
|
||||
from helper_bot.services.scoring.scoring_manager import ScoringManager
|
||||
|
||||
# RAG выбрасывает ошибку
|
||||
mock_rag_service.calculate_score = AsyncMock(side_effect=Exception("Test error"))
|
||||
|
||||
manager = ScoringManager(
|
||||
rag_service=mock_rag_service,
|
||||
deepseek_service=mock_deepseek_service,
|
||||
)
|
||||
|
||||
result = await manager.score_post("Тестовый пост")
|
||||
|
||||
# DeepSeek должен вернуть результат
|
||||
assert result.deepseek_score == 0.7
|
||||
# RAG должен быть None с ошибкой
|
||||
assert result.rag_score is None
|
||||
assert "rag" in result.errors
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_post_published(self, mock_rag_service, mock_deepseek_service):
|
||||
"""Тест обучения на опубликованном посте."""
|
||||
from helper_bot.services.scoring.scoring_manager import ScoringManager
|
||||
|
||||
manager = ScoringManager(
|
||||
rag_service=mock_rag_service,
|
||||
deepseek_service=mock_deepseek_service,
|
||||
)
|
||||
|
||||
await manager.on_post_published("Опубликованный пост")
|
||||
|
||||
mock_rag_service.add_positive_example.assert_called_once_with("Опубликованный пост")
|
||||
mock_deepseek_service.add_positive_example.assert_called_once_with("Опубликованный пост")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_post_declined(self, mock_rag_service, mock_deepseek_service):
|
||||
"""Тест обучения на отклоненном посте."""
|
||||
from helper_bot.services.scoring.scoring_manager import ScoringManager
|
||||
|
||||
manager = ScoringManager(
|
||||
rag_service=mock_rag_service,
|
||||
deepseek_service=mock_deepseek_service,
|
||||
)
|
||||
|
||||
await manager.on_post_declined("Отклоненный пост")
|
||||
|
||||
mock_rag_service.add_negative_example.assert_called_once_with("Отклоненный пост")
|
||||
mock_deepseek_service.add_negative_example.assert_called_once_with("Отклоненный пост")
|
||||
Reference in New Issue
Block a user