Files
rag-service/app/services/rag_service.py

488 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
RAG сервис для скоринга постов с использованием ruBERT.
Использует модель DeepPavlov/rubert-base-cased для создания эмбеддингов
и сравнивает их с эталонными примерами через VectorStore.
"""
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
import numpy as np
from app.config import Settings, get_settings
from app.exceptions import (
InsufficientExamplesError,
ModelNotLoadedError,
ScoringError,
TextTooShortError,
)
from app.storage.vector_store import VectorStore
logger = logging.getLogger(__name__)
@dataclass
class ScoringResult:
"""
Результат оценки поста.
Attributes:
score: Оценка от 0.0 до 1.0 (вероятность публикации)
confidence: Уверенность в оценке
score_pos_only: Оценка только по положительным примерам
positive_examples: Количество положительных примеров
negative_examples: Количество отрицательных примеров
model: Название используемой модели
timestamp: Время получения оценки
"""
score: float
confidence: float
score_pos_only: float
positive_examples: int
negative_examples: int
model: str
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
def to_dict(self) -> Dict[str, Any]:
"""Преобразует результат в словарь."""
return {
"rag_score": round(self.score, 4),
"rag_confidence": round(self.confidence, 4),
"rag_score_pos_only": round(self.score_pos_only, 4),
"meta": {
"positive_examples": self.positive_examples,
"negative_examples": self.negative_examples,
"model": self.model,
"timestamp": self.timestamp,
}
}
class RAGService:
"""
RAG сервис для оценки постов на основе векторного сходства.
Использует ruBERT для создания эмбеддингов текста и сравнивает
их с эталонными примерами (опубликованные vs отклоненные посты).
Attributes:
model_name: Название модели HuggingFace
vector_store: Хранилище векторов
min_text_length: Минимальная длина текста для обработки
"""
def __init__(
self,
settings: Optional[Settings] = None,
vector_store: Optional[VectorStore] = None,
):
"""
Инициализация RAG сервиса.
Args:
settings: Настройки сервиса (берутся из get_settings() если не переданы)
vector_store: Хранилище векторов (создается автоматически если не передано)
"""
self._settings = settings or get_settings()
self.model_name = self._settings.model_name
self.cache_dir = self._settings.cache_dir
self.min_text_length = self._settings.min_text_length
# Модель и токенизатор загружаются лениво
self._model = None
self._tokenizer = None
self._device = None
self._model_loaded = False
# Хранилище векторов
self.vector_store = vector_store or VectorStore(
vector_dim=self._settings.vector_dim,
max_examples=self._settings.max_examples,
storage_path=self._settings.vectors_path,
score_multiplier=self._settings.score_multiplier,
)
logger.info(f"RAGService инициализирован (model={self.model_name})")
@property
def is_model_loaded(self) -> bool:
"""Проверяет, загружена ли модель."""
return self._model_loaded
async def load_model(self) -> None:
"""
Загружает модель и токенизатор.
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
"""
if self._model_loaded:
return
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
try:
# Загрузка в отдельном потоке
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._load_model_sync)
self._model_loaded = True
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
except Exception as e:
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
def _load_model_sync(self) -> None:
"""Синхронная загрузка модели (вызывается в executor)."""
logger.info("RAGService: Начало _load_model_sync, импорт transformers...")
from transformers import AutoModel, AutoTokenizer
import torch
# Определяем устройство
self._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"RAGService: Устройство определено: {self._device}")
# Загружаем токенизатор
logger.info(f"RAGService: Загрузка токенизатора из {self.model_name}...")
self._tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
cache_dir=self.cache_dir,
)
logger.info("RAGService: Токенизатор загружен")
# Загружаем модель
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
self._model = AutoModel.from_pretrained(
self.model_name,
cache_dir=self.cache_dir,
)
logger.info("RAGService: Модель загружена, перенос на устройство...")
self._model.to(self._device)
self._model.eval() # Режим инференса
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
def _get_embedding_sync(self, text: str) -> np.ndarray:
"""
Получает эмбеддинг текста (синхронно).
Использует [CLS] токен как представление всего текста.
Args:
text: Текст для векторизации
Returns:
Numpy массив с эмбеддингом (768 измерений для ruBERT)
"""
import torch
# Токенизация с ограничением длины
inputs = self._tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True,
)
inputs = {k: v.to(self._device) for k, v in inputs.items()}
# Получаем эмбеддинг
with torch.no_grad():
outputs = self._model(**inputs)
# Используем [CLS] токен (первый токен)
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
return embedding.flatten()
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (синхронно).
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
Args:
texts: Список текстов для векторизации
batch_size: Размер батча
Returns:
Список numpy массивов с эмбеддингами
"""
import torch
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
# Токенизация батча
inputs = self._tokenizer(
batch_texts,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True,
)
inputs = {k: v.to(self._device) for k, v in inputs.items()}
# Получаем эмбеддинги
with torch.no_grad():
outputs = self._model(**inputs)
# [CLS] токен для каждого текста в батче
batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
# Разбиваем на отдельные эмбеддинги
for j in range(len(batch_texts)):
all_embeddings.append(batch_embeddings[j])
if i > 0 and i % (batch_size * 10) == 0:
logger.info(f"RAGService: Обработано {i}/{len(texts)} текстов")
return all_embeddings
async def get_embeddings_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (асинхронно).
Args:
texts: Список текстов для векторизации
batch_size: Размер батча (берется из настроек если не указан)
Returns:
Список numpy массивов с эмбеддингами
"""
if not self._model_loaded:
await self.load_model()
if not self._model_loaded:
raise ModelNotLoadedError("Модель не загружена")
batch_size = batch_size or self._settings.batch_size
# Очищаем тексты
clean_texts = [self._clean_text(text) for text in texts]
# Выполняем батч-обработку в thread pool
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
None,
self._get_embeddings_batch_sync,
clean_texts,
batch_size,
)
return embeddings
async def get_embedding(self, text: str) -> np.ndarray:
"""
Получает эмбеддинг текста (асинхронно).
Args:
text: Текст для векторизации
Returns:
Numpy массив с эмбеддингом
Raises:
ModelNotLoadedError: Если модель не загружена
TextTooShortError: Если текст слишком короткий
"""
if not self._model_loaded:
await self.load_model()
if not self._model_loaded:
raise ModelNotLoadedError("Модель не загружена")
# Очищаем текст
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
raise TextTooShortError(
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
)
# Выполняем в отдельном потоке
loop = asyncio.get_event_loop()
embedding = await loop.run_in_executor(
None,
self._get_embedding_sync,
clean_text
)
return embedding
def _clean_text(self, text: str) -> str:
"""Очищает текст от лишних символов."""
if not text:
return ""
# Удаляем лишние пробелы и переносы строк
clean = " ".join(text.split())
# Удаляем служебные символы (например "^" для helper сообщений)
if clean == "^":
return ""
return clean.strip()
async def calculate_score(self, text: str) -> ScoringResult:
"""
Рассчитывает скор для текста поста.
Args:
text: Текст поста для оценки
Returns:
ScoringResult с оценкой
Raises:
ScoringError: При ошибке расчета
InsufficientExamplesError: Если недостаточно примеров
TextTooShortError: Если текст слишком короткий
"""
try:
# Получаем эмбеддинг текста
embedding = await self.get_embedding(text)
# Логируем первые элементы вектора для отладки
logger.debug(
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
f"text_preview='{text[:30]}'"
)
# Рассчитываем скор через VectorStore
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
return ScoringResult(
score=score,
confidence=confidence,
score_pos_only=score_pos_only,
positive_examples=self.vector_store.positive_count,
negative_examples=self.vector_store.negative_count,
model=self.model_name,
)
except (InsufficientExamplesError, TextTooShortError):
# Пробрасываем ожидаемые исключения
raise
except Exception as e:
logger.error(f"RAGService: Ошибка расчета скора: {e}")
raise ScoringError(f"Ошибка расчета скора: {e}")
async def add_positive_example(self, text: str) -> bool:
"""
Добавляет текст как положительный пример (опубликованный пост).
Args:
text: Текст опубликованного поста
Returns:
True если пример добавлен, False если дубликат/короткий текст
"""
try:
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
return False
# Получаем эмбеддинг
embedding = await self.get_embedding(clean_text)
# Вычисляем хеш для дедупликации
text_hash = VectorStore.compute_text_hash(clean_text)
# Добавляем в хранилище
added = self.vector_store.add_positive(embedding, text_hash)
if added:
logger.info("RAGService: Добавлен положительный пример")
return added
except Exception as e:
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
return False
async def add_negative_example(self, text: str) -> bool:
"""
Добавляет текст как отрицательный пример (отклоненный пост).
Args:
text: Текст отклоненного поста
Returns:
True если пример добавлен, False если дубликат/короткий текст
"""
try:
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
return False
# Получаем эмбеддинг
embedding = await self.get_embedding(clean_text)
# Вычисляем хеш для дедупликации
text_hash = VectorStore.compute_text_hash(clean_text)
# Добавляем в хранилище
added = self.vector_store.add_negative(embedding, text_hash)
if added:
logger.info("RAGService: Добавлен отрицательный пример")
return added
except Exception as e:
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
return False
async def warmup(self) -> bool:
"""
Прогревает модель (загружает если не загружена).
Returns:
True если модель загружена успешно
"""
try:
await self.load_model()
return self._model_loaded
except Exception as e:
logger.error(f"RAGService: Ошибка прогрева модели: {e}")
return False
def save_vectors(self) -> None:
"""Сохраняет векторы на диск."""
if self.vector_store.storage_path:
self.vector_store.save_to_disk()
def get_stats(self) -> Dict[str, Any]:
"""Возвращает статистику сервиса."""
return {
"model_name": self.model_name,
"model_loaded": self._model_loaded,
"device": self._device,
"vector_store": self.vector_store.get_stats(),
}
# Глобальный экземпляр сервиса (singleton)
_rag_service: Optional[RAGService] = None
def get_rag_service() -> RAGService:
"""
Возвращает глобальный экземпляр RAG сервиса.
Returns:
RAGService: Экземпляр сервиса
"""
global _rag_service
if _rag_service is None:
_rag_service = RAGService()
return _rag_service