Initial commit: RAG Service
This commit is contained in:
3
app/services/__init__.py
Normal file
3
app/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Сервисы RAG: ядро логики скоринга.
|
||||
"""
|
||||
488
app/services/rag_service.py
Normal file
488
app/services/rag_service.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
RAG сервис для скоринга постов с использованием ruBERT.
|
||||
|
||||
Использует модель DeepPavlov/rubert-base-cased для создания эмбеддингов
|
||||
и сравнивает их с эталонными примерами через VectorStore.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.exceptions import (
|
||||
InsufficientExamplesError,
|
||||
ModelNotLoadedError,
|
||||
ScoringError,
|
||||
TextTooShortError,
|
||||
)
|
||||
from app.storage.vector_store import VectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoringResult:
|
||||
"""
|
||||
Результат оценки поста.
|
||||
|
||||
Attributes:
|
||||
score: Оценка от 0.0 до 1.0 (вероятность публикации)
|
||||
confidence: Уверенность в оценке
|
||||
score_pos_only: Оценка только по положительным примерам
|
||||
positive_examples: Количество положительных примеров
|
||||
negative_examples: Количество отрицательных примеров
|
||||
model: Название используемой модели
|
||||
timestamp: Время получения оценки
|
||||
"""
|
||||
score: float
|
||||
confidence: float
|
||||
score_pos_only: float
|
||||
positive_examples: int
|
||||
negative_examples: int
|
||||
model: str
|
||||
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Преобразует результат в словарь."""
|
||||
return {
|
||||
"rag_score": round(self.score, 4),
|
||||
"rag_confidence": round(self.confidence, 4),
|
||||
"rag_score_pos_only": round(self.score_pos_only, 4),
|
||||
"meta": {
|
||||
"positive_examples": self.positive_examples,
|
||||
"negative_examples": self.negative_examples,
|
||||
"model": self.model,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RAGService:
|
||||
"""
|
||||
RAG сервис для оценки постов на основе векторного сходства.
|
||||
|
||||
Использует ruBERT для создания эмбеддингов текста и сравнивает
|
||||
их с эталонными примерами (опубликованные vs отклоненные посты).
|
||||
|
||||
Attributes:
|
||||
model_name: Название модели HuggingFace
|
||||
vector_store: Хранилище векторов
|
||||
min_text_length: Минимальная длина текста для обработки
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[Settings] = None,
|
||||
vector_store: Optional[VectorStore] = None,
|
||||
):
|
||||
"""
|
||||
Инициализация RAG сервиса.
|
||||
|
||||
Args:
|
||||
settings: Настройки сервиса (берутся из get_settings() если не переданы)
|
||||
vector_store: Хранилище векторов (создается автоматически если не передано)
|
||||
"""
|
||||
self._settings = settings or get_settings()
|
||||
self.model_name = self._settings.model_name
|
||||
self.cache_dir = self._settings.cache_dir
|
||||
self.min_text_length = self._settings.min_text_length
|
||||
|
||||
# Модель и токенизатор загружаются лениво
|
||||
self._model = None
|
||||
self._tokenizer = None
|
||||
self._device = None
|
||||
self._model_loaded = False
|
||||
|
||||
# Хранилище векторов
|
||||
self.vector_store = vector_store or VectorStore(
|
||||
vector_dim=self._settings.vector_dim,
|
||||
max_examples=self._settings.max_examples,
|
||||
storage_path=self._settings.vectors_path,
|
||||
score_multiplier=self._settings.score_multiplier,
|
||||
)
|
||||
|
||||
logger.info(f"RAGService инициализирован (model={self.model_name})")
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
"""Проверяет, загружена ли модель."""
|
||||
return self._model_loaded
|
||||
|
||||
async def load_model(self) -> None:
|
||||
"""
|
||||
Загружает модель и токенизатор.
|
||||
|
||||
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
|
||||
"""
|
||||
if self._model_loaded:
|
||||
return
|
||||
|
||||
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
|
||||
|
||||
try:
|
||||
# Загрузка в отдельном потоке
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._load_model_sync)
|
||||
|
||||
self._model_loaded = True
|
||||
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
|
||||
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
|
||||
|
||||
def _load_model_sync(self) -> None:
|
||||
"""Синхронная загрузка модели (вызывается в executor)."""
|
||||
logger.info("RAGService: Начало _load_model_sync, импорт transformers...")
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import torch
|
||||
|
||||
# Определяем устройство
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"RAGService: Устройство определено: {self._device}")
|
||||
|
||||
# Загружаем токенизатор
|
||||
logger.info(f"RAGService: Загрузка токенизатора из {self.model_name}...")
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
logger.info("RAGService: Токенизатор загружен")
|
||||
|
||||
# Загружаем модель
|
||||
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
|
||||
self._model = AutoModel.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
logger.info("RAGService: Модель загружена, перенос на устройство...")
|
||||
self._model.to(self._device)
|
||||
self._model.eval() # Режим инференса
|
||||
|
||||
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
|
||||
|
||||
def _get_embedding_sync(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Получает эмбеддинг текста (синхронно).
|
||||
|
||||
Использует [CLS] токен как представление всего текста.
|
||||
|
||||
Args:
|
||||
text: Текст для векторизации
|
||||
|
||||
Returns:
|
||||
Numpy массив с эмбеддингом (768 измерений для ruBERT)
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Токенизация с ограничением длины
|
||||
inputs = self._tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинг
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# Используем [CLS] токен (первый токен)
|
||||
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
return embedding.flatten()
|
||||
|
||||
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
|
||||
"""
|
||||
Получает эмбеддинги для батча текстов (синхронно).
|
||||
|
||||
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для векторизации
|
||||
batch_size: Размер батча
|
||||
|
||||
Returns:
|
||||
Список numpy массивов с эмбеддингами
|
||||
"""
|
||||
import torch
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Токенизация батча
|
||||
inputs = self._tokenizer(
|
||||
batch_texts,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинги
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# [CLS] токен для каждого текста в батче
|
||||
batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
# Разбиваем на отдельные эмбеддинги
|
||||
for j in range(len(batch_texts)):
|
||||
all_embeddings.append(batch_embeddings[j])
|
||||
|
||||
if i > 0 and i % (batch_size * 10) == 0:
|
||||
logger.info(f"RAGService: Обработано {i}/{len(texts)} текстов")
|
||||
|
||||
return all_embeddings
|
||||
|
||||
async def get_embeddings_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[np.ndarray]:
|
||||
"""
|
||||
Получает эмбеддинги для батча текстов (асинхронно).
|
||||
|
||||
Args:
|
||||
texts: Список текстов для векторизации
|
||||
batch_size: Размер батча (берется из настроек если не указан)
|
||||
|
||||
Returns:
|
||||
Список numpy массивов с эмбеддингами
|
||||
"""
|
||||
if not self._model_loaded:
|
||||
await self.load_model()
|
||||
|
||||
if not self._model_loaded:
|
||||
raise ModelNotLoadedError("Модель не загружена")
|
||||
|
||||
batch_size = batch_size or self._settings.batch_size
|
||||
|
||||
# Очищаем тексты
|
||||
clean_texts = [self._clean_text(text) for text in texts]
|
||||
|
||||
# Выполняем батч-обработку в thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
embeddings = await loop.run_in_executor(
|
||||
None,
|
||||
self._get_embeddings_batch_sync,
|
||||
clean_texts,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def get_embedding(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Получает эмбеддинг текста (асинхронно).
|
||||
|
||||
Args:
|
||||
text: Текст для векторизации
|
||||
|
||||
Returns:
|
||||
Numpy массив с эмбеддингом
|
||||
|
||||
Raises:
|
||||
ModelNotLoadedError: Если модель не загружена
|
||||
TextTooShortError: Если текст слишком короткий
|
||||
"""
|
||||
if not self._model_loaded:
|
||||
await self.load_model()
|
||||
|
||||
if not self._model_loaded:
|
||||
raise ModelNotLoadedError("Модель не загружена")
|
||||
|
||||
# Очищаем текст
|
||||
clean_text = self._clean_text(text)
|
||||
|
||||
if len(clean_text) < self.min_text_length:
|
||||
raise TextTooShortError(
|
||||
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
|
||||
)
|
||||
|
||||
# Выполняем в отдельном потоке
|
||||
loop = asyncio.get_event_loop()
|
||||
embedding = await loop.run_in_executor(
|
||||
None,
|
||||
self._get_embedding_sync,
|
||||
clean_text
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""Очищает текст от лишних символов."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# Удаляем лишние пробелы и переносы строк
|
||||
clean = " ".join(text.split())
|
||||
|
||||
# Удаляем служебные символы (например "^" для helper сообщений)
|
||||
if clean == "^":
|
||||
return ""
|
||||
|
||||
return clean.strip()
|
||||
|
||||
async def calculate_score(self, text: str) -> ScoringResult:
|
||||
"""
|
||||
Рассчитывает скор для текста поста.
|
||||
|
||||
Args:
|
||||
text: Текст поста для оценки
|
||||
|
||||
Returns:
|
||||
ScoringResult с оценкой
|
||||
|
||||
Raises:
|
||||
ScoringError: При ошибке расчета
|
||||
InsufficientExamplesError: Если недостаточно примеров
|
||||
TextTooShortError: Если текст слишком короткий
|
||||
"""
|
||||
try:
|
||||
# Получаем эмбеддинг текста
|
||||
embedding = await self.get_embedding(text)
|
||||
|
||||
# Логируем первые элементы вектора для отладки
|
||||
logger.debug(
|
||||
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
|
||||
f"text_preview='{text[:30]}'"
|
||||
)
|
||||
|
||||
# Рассчитываем скор через VectorStore
|
||||
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
|
||||
|
||||
return ScoringResult(
|
||||
score=score,
|
||||
confidence=confidence,
|
||||
score_pos_only=score_pos_only,
|
||||
positive_examples=self.vector_store.positive_count,
|
||||
negative_examples=self.vector_store.negative_count,
|
||||
model=self.model_name,
|
||||
)
|
||||
|
||||
except (InsufficientExamplesError, TextTooShortError):
|
||||
# Пробрасываем ожидаемые исключения
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка расчета скора: {e}")
|
||||
raise ScoringError(f"Ошибка расчета скора: {e}")
|
||||
|
||||
async def add_positive_example(self, text: str) -> bool:
|
||||
"""
|
||||
Добавляет текст как положительный пример (опубликованный пост).
|
||||
|
||||
Args:
|
||||
text: Текст опубликованного поста
|
||||
|
||||
Returns:
|
||||
True если пример добавлен, False если дубликат/короткий текст
|
||||
"""
|
||||
try:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) < self.min_text_length:
|
||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||
return False
|
||||
|
||||
# Получаем эмбеддинг
|
||||
embedding = await self.get_embedding(clean_text)
|
||||
|
||||
# Вычисляем хеш для дедупликации
|
||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||
|
||||
# Добавляем в хранилище
|
||||
added = self.vector_store.add_positive(embedding, text_hash)
|
||||
|
||||
if added:
|
||||
logger.info("RAGService: Добавлен положительный пример")
|
||||
|
||||
return added
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
|
||||
return False
|
||||
|
||||
async def add_negative_example(self, text: str) -> bool:
|
||||
"""
|
||||
Добавляет текст как отрицательный пример (отклоненный пост).
|
||||
|
||||
Args:
|
||||
text: Текст отклоненного поста
|
||||
|
||||
Returns:
|
||||
True если пример добавлен, False если дубликат/короткий текст
|
||||
"""
|
||||
try:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) < self.min_text_length:
|
||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||
return False
|
||||
|
||||
# Получаем эмбеддинг
|
||||
embedding = await self.get_embedding(clean_text)
|
||||
|
||||
# Вычисляем хеш для дедупликации
|
||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||
|
||||
# Добавляем в хранилище
|
||||
added = self.vector_store.add_negative(embedding, text_hash)
|
||||
|
||||
if added:
|
||||
logger.info("RAGService: Добавлен отрицательный пример")
|
||||
|
||||
return added
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
|
||||
return False
|
||||
|
||||
async def warmup(self) -> bool:
|
||||
"""
|
||||
Прогревает модель (загружает если не загружена).
|
||||
|
||||
Returns:
|
||||
True если модель загружена успешно
|
||||
"""
|
||||
try:
|
||||
await self.load_model()
|
||||
return self._model_loaded
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка прогрева модели: {e}")
|
||||
return False
|
||||
|
||||
def save_vectors(self) -> None:
|
||||
"""Сохраняет векторы на диск."""
|
||||
if self.vector_store.storage_path:
|
||||
self.vector_store.save_to_disk()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Возвращает статистику сервиса."""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"model_loaded": self._model_loaded,
|
||||
"device": self._device,
|
||||
"cache_dir": self.cache_dir,
|
||||
"vector_store": self.vector_store.get_stats(),
|
||||
}
|
||||
|
||||
|
||||
# Глобальный экземпляр сервиса (singleton)
|
||||
_rag_service: Optional[RAGService] = None
|
||||
|
||||
|
||||
def get_rag_service() -> RAGService:
|
||||
"""
|
||||
Возвращает глобальный экземпляр RAG сервиса.
|
||||
|
||||
Returns:
|
||||
RAGService: Экземпляр сервиса
|
||||
"""
|
||||
global _rag_service
|
||||
if _rag_service is None:
|
||||
_rag_service = RAGService()
|
||||
return _rag_service
|
||||
Reference in New Issue
Block a user