Замена RuBERT на sentence-transformers/all-MiniLM-L12-v2, упрощение формулы расчета, поддержка загрузки из отдельных .npy файлов
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
RAG сервис для скоринга постов с использованием ruBERT.
|
||||
RAG сервис для скоринга постов с использованием sentence-transformers.
|
||||
|
||||
Использует модель DeepPavlov/rubert-base-cased для создания эмбеддингов
|
||||
Использует модель sentence-transformers/all-MiniLM-L12-v2 для создания эмбеддингов
|
||||
и сравнивает их с эталонными примерами через VectorStore.
|
||||
"""
|
||||
|
||||
@@ -66,7 +66,7 @@ class RAGService:
|
||||
"""
|
||||
RAG сервис для оценки постов на основе векторного сходства.
|
||||
|
||||
Использует ruBERT для создания эмбеддингов текста и сравнивает
|
||||
Использует sentence-transformers для создания эмбеддингов текста и сравнивает
|
||||
их с эталонными примерами (опубликованные vs отклоненные посты).
|
||||
|
||||
Attributes:
|
||||
@@ -92,9 +92,8 @@ class RAGService:
|
||||
self.cache_dir = self._settings.cache_dir
|
||||
self.min_text_length = self._settings.min_text_length
|
||||
|
||||
# Модель и токенизатор загружаются лениво
|
||||
# Модель загружается лениво
|
||||
self._model = None
|
||||
self._tokenizer = None
|
||||
self._device = None
|
||||
self._model_loaded = False
|
||||
|
||||
@@ -104,6 +103,7 @@ class RAGService:
|
||||
max_examples=self._settings.max_examples,
|
||||
storage_path=self._settings.vectors_path,
|
||||
score_multiplier=self._settings.score_multiplier,
|
||||
k=3, # Фиксированное значение k для топ-k ближайших примеров
|
||||
)
|
||||
|
||||
logger.info(f"RAGService инициализирован (model={self.model_name})")
|
||||
@@ -138,64 +138,37 @@ class RAGService:
|
||||
|
||||
def _load_model_sync(self) -> None:
|
||||
"""Синхронная загрузка модели (вызывается в executor)."""
|
||||
logger.info("RAGService: Начало _load_model_sync, импорт transformers...")
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
logger.info("RAGService: Начало _load_model_sync, импорт sentence_transformers...")
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import torch
|
||||
|
||||
# Определяем устройство
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"RAGService: Устройство определено: {self._device}")
|
||||
|
||||
# Загружаем токенизатор
|
||||
logger.info(f"RAGService: Загрузка токенизатора из {self.model_name}...")
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
logger.info("RAGService: Токенизатор загружен")
|
||||
|
||||
# Загружаем модель
|
||||
# Загружаем модель SentenceTransformer
|
||||
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
|
||||
self._model = AutoModel.from_pretrained(
|
||||
self._model = SentenceTransformer(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
cache_folder=self.cache_dir,
|
||||
device=self._device,
|
||||
)
|
||||
logger.info("RAGService: Модель загружена, перенос на устройство...")
|
||||
self._model.to(self._device)
|
||||
self._model.eval() # Режим инференса
|
||||
|
||||
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
|
||||
|
||||
def _get_embedding_sync(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Получает эмбеддинг текста (синхронно).
|
||||
|
||||
Использует [CLS] токен как представление всего текста.
|
||||
Использует SentenceTransformer для получения нормализованного эмбеддинга.
|
||||
|
||||
Args:
|
||||
text: Текст для векторизации
|
||||
|
||||
Returns:
|
||||
Numpy массив с эмбеддингом (768 измерений для ruBERT)
|
||||
Numpy массив с эмбеддингом (384 измерений для all-MiniLM-L12-v2)
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Токенизация с ограничением длины
|
||||
inputs = self._tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинг
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# Используем [CLS] токен (первый токен)
|
||||
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
# SentenceTransformer автоматически нормализует эмбеддинги
|
||||
embedding = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
|
||||
return embedding.flatten()
|
||||
|
||||
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
|
||||
@@ -211,37 +184,17 @@ class RAGService:
|
||||
Returns:
|
||||
Список numpy массивов с эмбеддингами
|
||||
"""
|
||||
import torch
|
||||
# SentenceTransformer автоматически обрабатывает батчи
|
||||
embeddings = self._model.encode(
|
||||
texts,
|
||||
batch_size=batch_size,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Токенизация батча
|
||||
inputs = self._tokenizer(
|
||||
batch_texts,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинги
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# [CLS] токен для каждого текста в батче
|
||||
batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
# Разбиваем на отдельные эмбеддинги
|
||||
for j in range(len(batch_texts)):
|
||||
all_embeddings.append(batch_embeddings[j])
|
||||
|
||||
if i > 0 and i % (batch_size * 10) == 0:
|
||||
logger.info(f"RAGService: Обработано {i}/{len(texts)} текстов")
|
||||
|
||||
return all_embeddings
|
||||
# Преобразуем в список отдельных массивов
|
||||
return [emb.flatten() for emb in embeddings]
|
||||
|
||||
async def get_embeddings_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[np.ndarray]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user