feat: add submitted collection, /similar and /submitted endpoints (Stage 4)

Made-with: Cursor
This commit is contained in:
2026-02-28 19:00:22 +03:00
parent 955f518429
commit a1d6d2d860
15 changed files with 1308 additions and 400 deletions

2
.gitignore vendored
View File

@@ -133,6 +133,8 @@ Thumbs.db
# Project specific
data/models/
data/vectors/*.npz
*.bak
*.tar.gz
# Keep data directories
!data/models/.gitkeep

View File

@@ -14,9 +14,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# Копируем зависимости
COPY requirements.txt .
# Устанавливаем зависимости
# --no-cache-dir для уменьшения размера образа
RUN pip install --no-cache-dir -r requirements.txt
# Устанавливаем зависимости (CPU-only torch для контейнеров без GPU)
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu \
&& pip install --no-cache-dir -r requirements.txt
# Копируем код приложения
COPY app/ ./app/

View File

@@ -24,7 +24,12 @@ from app.schemas import (
ScoreRequest,
ScoreResponse,
ScoringParamsResponse,
SimilarPostItem,
SimilarRequest,
SimilarResponse,
StatsResponse,
SubmittedRequest,
SubmittedResponse,
UpdateScoringParamsRequest,
VectorStoreStats,
WarmupResponse,
@@ -49,6 +54,7 @@ RAGServiceDep = Annotated[RAGService, Depends(get_service)]
# Health Check
# =============================================================================
@router.get(
"/health",
response_model=HealthResponse,
@@ -73,6 +79,7 @@ async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse
# Scoring
# =============================================================================
@router.post(
"/score",
response_model=ScoreResponse,
@@ -147,6 +154,7 @@ async def calculate_score(
# Examples
# =============================================================================
@router.post(
"/examples/positive",
response_model=ExampleResponse,
@@ -277,10 +285,117 @@ async def add_negative_example(
)
# =============================================================================
# Similar & Submitted
# =============================================================================
@router.post(
"/similar",
response_model=SimilarResponse,
responses={
400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
401: {"model": ErrorResponse, "description": "Не авторизован"},
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
},
summary="Поиск похожих постов",
tags=["similar"],
)
async def find_similar_posts(
request: SimilarRequest,
service: RAGServiceDep,
_auth: AuthDep,
) -> SimilarResponse:
"""
Ищет похожие submitted-посты за последние N часов.
Args:
request: Запрос с текстом, threshold и hours
service: RAG сервис
Returns:
SimilarResponse: Список похожих постов
"""
try:
similar = await service.find_similar_posts(
text=request.text,
threshold=request.threshold,
hours=request.hours,
)
return SimilarResponse(
similar_count=len(similar),
similar_posts=[SimilarPostItem(**item) for item in similar],
)
except TextTooShortError as e:
logger.warning(f"Текст слишком короткий: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"detail": str(e), "error_type": "TextTooShortError"},
)
except ModelNotLoadedError as e:
logger.error(f"Модель не загружена: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
)
@router.post(
"/submitted",
response_model=SubmittedResponse,
responses={
400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
401: {"model": ErrorResponse, "description": "Не авторизован"},
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
},
summary="Добавить submitted-пост",
tags=["submitted"],
)
async def add_submitted_post(
request: SubmittedRequest,
service: RAGServiceDep,
_auth: AuthDep,
) -> SubmittedResponse:
"""
Добавляет submitted-пост в коллекцию для индексации ботом.
Args:
request: Запрос с текстом, post_id и rag_score
service: RAG сервис
Returns:
SubmittedResponse: Результат добавления
"""
try:
added = await service.add_submitted_post(
text=request.text,
post_id=request.post_id,
rag_score=request.rag_score,
)
if added:
message = "Submitted-пост добавлен"
else:
message = "Пост не добавлен (дубликат или слишком короткий текст)"
return SubmittedResponse(
success=added,
message=message,
submitted_count=service.vector_store.submitted_count,
)
except ModelNotLoadedError as e:
logger.error(f"Модель не загружена: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
)
# =============================================================================
# Stats & Warmup
# =============================================================================
@router.get(
"/stats",
response_model=StatsResponse,
@@ -354,6 +469,7 @@ async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
# Scoring Parameters
# =============================================================================
@router.get(
"/scoring/params",
response_model=ScoringParamsResponse,

View File

@@ -5,7 +5,6 @@
import os
import secrets
from dataclasses import dataclass, field
from typing import Optional
@dataclass
@@ -20,53 +19,41 @@ class Settings:
model_name: str = field(
default_factory=lambda: os.getenv("RAG_MODEL", "sentence-transformers/all-MiniLM-L12-v2")
)
cache_dir: str = field(
default_factory=lambda: os.getenv("RAG_CACHE_DIR", "data/models")
)
cache_dir: str = field(default_factory=lambda: os.getenv("RAG_CACHE_DIR", "data/models"))
# VectorStore
vectors_path: str = field(
default_factory=lambda: os.getenv("RAG_VECTORS_PATH", "data/vectors/vectors.npz")
)
max_examples: int = field(
default_factory=lambda: int(os.getenv("RAG_MAX_EXAMPLES", "10000"))
max_examples: int = field(default_factory=lambda: int(os.getenv("RAG_MAX_EXAMPLES", "10000")))
max_submitted: int = field(default_factory=lambda: int(os.getenv("RAG_MAX_SUBMITTED", "5000")))
submitted_path: str = field(
default_factory=lambda: os.getenv("RAG_SUBMITTED_PATH", "data/vectors/submitted.npz")
)
score_multiplier: float = field(
default_factory=lambda: float(os.getenv("RAG_SCORE_MULTIPLIER", "5.0"))
)
# Батч-обработка
batch_size: int = field(
default_factory=lambda: int(os.getenv("RAG_BATCH_SIZE", "16"))
)
batch_size: int = field(default_factory=lambda: int(os.getenv("RAG_BATCH_SIZE", "16")))
# Минимальная длина текста
min_text_length: int = field(
default_factory=lambda: int(os.getenv("RAG_MIN_TEXT_LENGTH", "3"))
)
min_text_length: int = field(default_factory=lambda: int(os.getenv("RAG_MIN_TEXT_LENGTH", "3")))
# API настройки
api_host: str = field(
default_factory=lambda: os.getenv("RAG_API_HOST", "0.0.0.0")
)
api_port: int = field(
default_factory=lambda: int(os.getenv("RAG_API_PORT", "8000"))
)
api_host: str = field(default_factory=lambda: os.getenv("RAG_API_HOST", "0.0.0.0"))
api_port: int = field(default_factory=lambda: int(os.getenv("RAG_API_PORT", "8000")))
# Безопасность
# API ключ для авторизации (обязателен в продакшене!)
api_key: Optional[str] = field(
default_factory=lambda: os.getenv("RAG_API_KEY")
)
api_key: str | None = field(default_factory=lambda: os.getenv("RAG_API_KEY"))
# Разрешить запросы без ключа (только для разработки)
allow_no_auth: bool = field(
default_factory=lambda: os.getenv("RAG_ALLOW_NO_AUTH", "false").lower() == "true"
)
# Логирование
log_level: str = field(
default_factory=lambda: os.getenv("LOG_LEVEL", "INFO")
)
log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO"))
# Автосохранение (интервал в секундах, 0 = отключено)
autosave_interval: int = field(
@@ -88,7 +75,7 @@ class Settings:
# Глобальный экземпляр настроек
_settings: Optional[Settings] = None
_settings: Settings | None = None
def get_settings() -> Settings:

View File

@@ -5,29 +5,35 @@
class RAGServiceError(Exception):
"""Базовое исключение для ошибок RAG сервиса."""
pass
class ModelNotLoadedError(RAGServiceError):
"""Модель не загружена или недоступна."""
pass
class VectorStoreError(RAGServiceError):
"""Ошибка при работе с хранилищем векторов."""
pass
class InsufficientExamplesError(RAGServiceError):
"""Недостаточно примеров для расчета скора."""
pass
class TextTooShortError(RAGServiceError):
"""Текст слишком короткий для векторизации."""
pass
class ScoringError(RAGServiceError):
"""Ошибка при расчете скора."""
pass

View File

@@ -7,8 +7,8 @@ FastAPI приложение Embedding сервиса.
import asyncio
import logging
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@@ -18,6 +18,7 @@ from app.api.routes import router
from app.config import get_settings
from app.services.rag_service import RAGService, get_rag_service
# Настройка логирования
def setup_logging() -> None:
"""Настраивает логирование для приложения."""
@@ -40,7 +41,7 @@ def setup_logging() -> None:
logger = logging.getLogger(__name__)
# Глобальная задача автосохранения
_autosave_task: Optional[asyncio.Task] = None
_autosave_task: asyncio.Task | None = None
async def autosave_loop(service: RAGService, interval: int) -> None:
@@ -58,12 +59,19 @@ async def autosave_loop(service: RAGService, interval: int) -> None:
await asyncio.sleep(interval)
# Сохраняем только если есть данные
if service.vector_store.total_count > 0:
has_examples = service.vector_store.total_count > 0
has_submitted = service.vector_store.submitted_count > 0
if has_examples or has_submitted:
service.save_vectors()
logger.info(
f"Автосохранение: сохранено {service.vector_store.positive_count} pos, "
parts = []
if has_examples:
parts.append(
f"{service.vector_store.positive_count} pos, "
f"{service.vector_store.negative_count} neg"
)
if has_submitted:
parts.append(f"{service.vector_store.submitted_count} submitted")
logger.info(f"Автосохранение: сохранено {', '.join(parts)}")
else:
logger.debug("Автосохранение: нет данных для сохранения")
@@ -100,9 +108,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Запускаем автосохранение если включено
if settings.autosave_interval > 0:
_autosave_task = asyncio.create_task(
autosave_loop(service, settings.autosave_interval)
)
_autosave_task = asyncio.create_task(autosave_loop(service, settings.autosave_interval))
logger.info(f"Автосохранение включено: каждые {settings.autosave_interval} сек")
else:
logger.info("Автосохранение отключено")
@@ -176,12 +182,14 @@ app.add_middleware(
allow_headers=["*"],
)
# Простой healthcheck endpoint без авторизации (для Docker healthcheck)
@app.get("/health")
async def simple_health_check():
"""Простая проверка здоровья без авторизации (для Docker healthcheck)."""
return {"status": "ok"}
# Подключение роутов
app.include_router(router, prefix="/api/v1")

View File

@@ -2,36 +2,66 @@
Pydantic схемы для API Embedding сервиса.
"""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
# =============================================================================
# Запросы
# =============================================================================
class ScoreRequest(BaseModel):
"""Запрос на расчет скора."""
text: str = Field(..., min_length=1, description="Текст поста для оценки")
model_config = {
"json_schema_extra": {
"example": {
"text": "Это пример текста поста для оценки скоринга"
}
}
"json_schema_extra": {"example": {"text": "Это пример текста поста для оценки скоринга"}}
}
class ExampleRequest(BaseModel):
"""Запрос на добавление примера."""
text: str = Field(..., min_length=1, description="Текст примера")
model_config = {
"json_schema_extra": {"example": {"text": "Это пример опубликованного/отклоненного поста"}}
}
class SimilarRequest(BaseModel):
"""Запрос на поиск похожих постов."""
text: str = Field(..., min_length=1, description="Текст для поиска похожих")
threshold: float = Field(
default=0.9, ge=0.0, le=1.0, description="Минимальный порог similarity"
)
hours: int = Field(default=24, ge=1, le=168, description="Количество часов для фильтрации")
model_config = {
"json_schema_extra": {
"example": {
"text": "Это пример опубликованного/отклоненного поста"
"text": "Текст поста для поиска похожих",
"threshold": 0.9,
"hours": 24,
}
}
}
class SubmittedRequest(BaseModel):
"""Запрос на добавление submitted-поста."""
text: str = Field(..., min_length=1, description="Текст поста")
post_id: int | None = None
rag_score: float | None = None
model_config = {
"json_schema_extra": {
"example": {
"text": "Текст submitted-поста",
"post_id": 12345,
"rag_score": 0.85,
}
}
}
@@ -41,8 +71,10 @@ class ExampleRequest(BaseModel):
# Ответы
# =============================================================================
class ScoreMetadata(BaseModel):
"""Метаданные результата скоринга."""
positive_examples: int = Field(..., description="Количество положительных примеров")
negative_examples: int = Field(..., description="Количество отрицательных примеров")
model: str = Field(..., description="Название модели")
@@ -51,9 +83,12 @@ class ScoreMetadata(BaseModel):
class ScoreResponse(BaseModel):
"""Ответ с результатом скоринга."""
rag_score: float = Field(..., ge=0.0, le=1.0, description="Основной скор (neg/pos формула)")
rag_confidence: float = Field(..., ge=0.0, le=1.0, description="Уверенность в оценке")
rag_score_pos_only: float = Field(..., ge=0.0, le=1.0, description="Скор только по положительным примерам")
rag_score_pos_only: float = Field(
..., ge=0.0, le=1.0, description="Скор только по положительным примерам"
)
meta: ScoreMetadata = Field(..., description="Метаданные")
model_config = {
@@ -66,8 +101,8 @@ class ScoreResponse(BaseModel):
"positive_examples": 500,
"negative_examples": 350,
"model": "sentence-transformers/all-MiniLM-L12-v2",
"timestamp": 1706270000
}
"timestamp": 1706270000,
},
}
}
}
@@ -75,6 +110,7 @@ class ScoreResponse(BaseModel):
class ExampleResponse(BaseModel):
"""Ответ на добавление примера."""
success: bool = Field(..., description="Успешность добавления")
message: str = Field(..., description="Сообщение о результате")
positive_count: int = Field(..., description="Текущее количество положительных примеров")
@@ -86,7 +122,59 @@ class ExampleResponse(BaseModel):
"success": True,
"message": "Положительный пример добавлен",
"positive_count": 501,
"negative_count": 350
"negative_count": 350,
}
}
}
class SimilarPostItem(BaseModel):
"""Элемент похожего поста."""
similarity: float = Field(..., description="Косинусное сходство")
created_at: int = Field(..., description="Unix timestamp создания")
post_id: int | None = None
text: str = Field(..., description="Текст поста")
rag_score: float | None = None
class SimilarResponse(BaseModel):
"""Ответ с похожими постами."""
similar_count: int = Field(..., description="Количество найденных похожих постов")
similar_posts: list[SimilarPostItem] = Field(..., description="Список похожих постов")
model_config = {
"json_schema_extra": {
"example": {
"similar_count": 2,
"similar_posts": [
{
"similarity": 0.95,
"created_at": 1706270000,
"post_id": 123,
"text": "Похожий пост",
"rag_score": 0.85,
}
],
}
}
}
class SubmittedResponse(BaseModel):
"""Ответ на добавление submitted-поста."""
success: bool = Field(..., description="Успешность добавления")
message: str = Field(..., description="Сообщение о результате")
submitted_count: int = Field(..., description="Текущее количество submitted-постов")
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Submitted-пост добавлен",
"submitted_count": 42,
}
}
}
@@ -94,18 +182,22 @@ class ExampleResponse(BaseModel):
class VectorStoreStats(BaseModel):
"""Статистика хранилища векторов."""
positive_count: int = Field(..., description="Количество положительных примеров")
negative_count: int = Field(..., description="Количество отрицательных примеров")
total_count: int = Field(..., description="Общее количество примеров")
submitted_count: int = Field(default=0, description="Количество submitted-постов")
vector_dim: int = Field(..., description="Размерность векторов")
max_examples: int = Field(..., description="Максимальное количество примеров")
max_submitted: int = Field(default=5000, description="Максимальное количество submitted-постов")
class StatsResponse(BaseModel):
"""Ответ со статистикой сервиса."""
model_name: str = Field(..., description="Название модели")
model_loaded: bool = Field(..., description="Загружена ли модель")
device: Optional[str] = Field(None, description="Устройство (cpu/cuda)")
device: str | None = Field(None, description="Устройство (cpu/cuda)")
vector_store: VectorStoreStats = Field(..., description="Статистика хранилища векторов")
model_config = {
@@ -119,8 +211,8 @@ class StatsResponse(BaseModel):
"negative_count": 350,
"total_count": 850,
"vector_dim": 384,
"max_examples": 10000
}
"max_examples": 10000,
},
}
}
}
@@ -128,6 +220,7 @@ class StatsResponse(BaseModel):
class WarmupResponse(BaseModel):
"""Ответ на прогрев модели."""
success: bool = Field(..., description="Успешность загрузки")
model_loaded: bool = Field(..., description="Загружена ли модель")
message: str = Field(..., description="Сообщение о результате")
@@ -137,7 +230,7 @@ class WarmupResponse(BaseModel):
"example": {
"success": True,
"model_loaded": True,
"message": "Модель успешно загружена"
"message": "Модель успешно загружена",
}
}
}
@@ -145,6 +238,7 @@ class WarmupResponse(BaseModel):
class ErrorResponse(BaseModel):
"""Ответ с ошибкой."""
detail: str = Field(..., description="Описание ошибки")
error_type: str = Field(..., description="Тип ошибки")
@@ -152,7 +246,7 @@ class ErrorResponse(BaseModel):
"json_schema_extra": {
"example": {
"detail": "Недостаточно примеров для расчета скора",
"error_type": "InsufficientExamplesError"
"error_type": "InsufficientExamplesError",
}
}
}
@@ -160,23 +254,21 @@ class ErrorResponse(BaseModel):
class HealthResponse(BaseModel):
"""Ответ проверки здоровья сервиса."""
status: str = Field(..., description="Статус сервиса")
model_loaded: bool = Field(..., description="Загружена ли модель")
version: str = Field(..., description="Версия сервиса")
model_config = {
"json_schema_extra": {
"example": {
"status": "healthy",
"model_loaded": True,
"version": "0.1.0"
}
"example": {"status": "healthy", "model_loaded": True, "version": "0.1.0"}
}
}
class ScoringParamsResponse(BaseModel):
"""Ответ с текущими параметрами формулы расчета score."""
score_multiplier: float = Field(
...,
description=(
@@ -185,7 +277,7 @@ class ScoringParamsResponse(BaseModel):
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
"Рекомендуемое значение: 5.0"
)
),
)
k: int = Field(
...,
@@ -195,22 +287,16 @@ class ScoringParamsResponse(BaseModel):
"и вычисляет среднее косинусное сходство. "
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
"Рекомендуемое значение: 3"
)
),
)
model_config = {
"json_schema_extra": {
"example": {
"score_multiplier": 5.0,
"k": 3
}
}
}
model_config = {"json_schema_extra": {"example": {"score_multiplier": 5.0, "k": 3}}}
class UpdateScoringParamsRequest(BaseModel):
"""Запрос на обновление параметров формулы расчета score."""
score_multiplier: Optional[float] = Field(
score_multiplier: float | None = Field(
None,
gt=0,
description=(
@@ -219,9 +305,9 @@ class UpdateScoringParamsRequest(BaseModel):
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
"Должен быть > 0. Рекомендуемое значение: 5.0"
),
)
)
k: Optional[int] = Field(
k: int | None = Field(
None,
ge=1,
description=(
@@ -230,14 +316,7 @@ class UpdateScoringParamsRequest(BaseModel):
"и вычисляет среднее косинусное сходство. "
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
"Должно быть >= 1. Рекомендуемое значение: 3"
)
),
)
model_config = {
"json_schema_extra": {
"example": {
"score_multiplier": 5.0,
"k": 3
}
}
}
model_config = {"json_schema_extra": {"example": {"score_multiplier": 5.0, "k": 3}}}

View File

@@ -9,7 +9,7 @@ import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
@@ -39,6 +39,7 @@ class ScoringResult:
model: Название используемой модели
timestamp: Время получения оценки
"""
score: float
confidence: float
score_pos_only: float
@@ -47,7 +48,7 @@ class ScoringResult:
model: str
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Преобразует результат в словарь."""
return {
"rag_score": round(self.score, 4),
@@ -58,7 +59,7 @@ class ScoringResult:
"negative_examples": self.negative_examples,
"model": self.model,
"timestamp": self.timestamp,
}
},
}
@@ -77,8 +78,8 @@ class RAGService:
def __init__(
self,
settings: Optional[Settings] = None,
vector_store: Optional[VectorStore] = None,
settings: Settings | None = None,
vector_store: VectorStore | None = None,
):
"""
Инициализация RAG сервиса.
@@ -101,7 +102,9 @@ class RAGService:
self.vector_store = vector_store or VectorStore(
vector_dim=self._settings.vector_dim,
max_examples=self._settings.max_examples,
max_submitted=self._settings.max_submitted,
storage_path=self._settings.vectors_path,
submitted_path=self._settings.submitted_path,
score_multiplier=self._settings.score_multiplier,
k=3, # Фиксированное значение k для топ-k ближайших примеров
)
@@ -139,15 +142,17 @@ class RAGService:
def _load_model_sync(self) -> None:
"""Синхронная загрузка модели (вызывается в executor)."""
logger.info("RAGService: Начало _load_model_sync, импорт sentence_transformers...")
from sentence_transformers import SentenceTransformer
import torch
from sentence_transformers import SentenceTransformer
# Определяем устройство
self._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"RAGService: Устройство определено: {self._device}")
# Загружаем модель SentenceTransformer
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
logger.info(
f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)..."
)
self._model = SentenceTransformer(
self.model_name,
cache_folder=self.cache_dir,
@@ -171,7 +176,9 @@ class RAGService:
embedding = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
return embedding.flatten()
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
def _get_embeddings_batch_sync(
self, texts: list[str], batch_size: int = 16
) -> list[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (синхронно).
@@ -196,7 +203,9 @@ class RAGService:
# Преобразуем в список отдельных массивов
return [emb.flatten() for emb in embeddings]
async def get_embeddings_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[np.ndarray]:
async def get_embeddings_batch(
self, texts: list[str], batch_size: int | None = None
) -> list[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (асинхронно).
@@ -259,11 +268,7 @@ class RAGService:
# Выполняем в отдельном потоке
loop = asyncio.get_event_loop()
embedding = await loop.run_in_executor(
None,
self._get_embedding_sync,
clean_text
)
embedding = await loop.run_in_executor(None, self._get_embedding_sync, clean_text)
return embedding
@@ -302,12 +307,13 @@ class RAGService:
# Логируем первые элементы вектора для отладки
logger.debug(
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
f"text_preview='{text[:30]}'"
f"RAGService: embedding[:3]={embedding[:3].tolist()}, text_preview='{text[:30]}'"
)
# Рассчитываем скор через VectorStore
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(
embedding
)
return ScoringResult(
score=score,
@@ -394,6 +400,79 @@ class RAGService:
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
return False
async def add_submitted_post(
self,
text: str,
post_id: int | None = None,
rag_score: float | None = None,
) -> bool:
"""
Добавляет submitted-пост в коллекцию для индексации.
Args:
text: Текст поста
post_id: ID поста (опционально)
rag_score: RAG скор поста (опционально)
Returns:
True если добавлен, False если дубликат/короткий текст
"""
try:
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для submitted, пропускаем")
return False
embedding = await self.get_embedding(clean_text)
text_hash = VectorStore.compute_text_hash(clean_text)
created_at = int(datetime.now().timestamp())
added = self.vector_store.add_submitted(
vector=embedding,
text_hash=text_hash,
created_at=created_at,
post_id=post_id,
text=clean_text,
rag_score=rag_score,
)
if added:
logger.info("RAGService: Добавлен submitted-пост")
return added
except Exception as e:
logger.error(f"RAGService: Ошибка добавления submitted-поста: {e}")
return False
async def find_similar_posts(
self,
text: str,
threshold: float = 0.9,
hours: int = 24,
) -> list[dict[str, Any]]:
"""
Ищет похожие submitted-посты за последние N часов.
Args:
text: Текст для поиска
threshold: Минимальный порог similarity (0.0 - 1.0)
hours: Количество часов для фильтрации
Returns:
Список dict с полями: similarity, created_at, post_id, text, rag_score
"""
try:
embedding = await self.get_embedding(text)
return self.vector_store.find_similar_submitted(
vector=embedding,
threshold=threshold,
hours=hours,
)
except Exception as e:
logger.error(f"RAGService: Ошибка поиска похожих постов: {e}")
return []
async def warmup(self) -> bool:
"""
Прогревает модель (загружает если не загружена).
@@ -409,11 +488,13 @@ class RAGService:
return False
def save_vectors(self) -> None:
"""Сохраняет векторы на диск."""
"""Сохраняет векторы на диск (включая submitted)."""
if self.vector_store.storage_path:
self.vector_store.save_to_disk()
if self.vector_store.submitted_path:
self.vector_store.save_submitted_to_disk()
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""Возвращает статистику сервиса."""
return {
"model_name": self.model_name,
@@ -424,7 +505,7 @@ class RAGService:
# Глобальный экземпляр сервиса (singleton)
_rag_service: Optional[RAGService] = None
_rag_service: RAGService | None = None
def get_rag_service() -> RAGService:

View File

@@ -9,8 +9,9 @@ import hashlib
import logging
import os
import threading
import time
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Any
import numpy as np
@@ -35,7 +36,9 @@ class VectorStore:
self,
vector_dim: int = 384,
max_examples: int = 10000,
storage_path: Optional[str] = None,
max_submitted: int = 5000,
storage_path: str | None = None,
submitted_path: str | None = None,
score_multiplier: float = 5.0,
k: int = 3,
):
@@ -45,13 +48,17 @@ class VectorStore:
Args:
vector_dim: Размерность векторов
max_examples: Максимальное количество примеров каждого типа
max_submitted: Максимальное количество submitted-постов
storage_path: Путь для сохранения/загрузки векторов (опционально)
submitted_path: Путь для сохранения/загрузки submitted-постов (опционально)
score_multiplier: Множитель для масштабирования разницы в скорах
k: Количество ближайших примеров для расчета среднего сходства
"""
self.vector_dim = vector_dim
self.max_examples = max_examples
self.max_submitted = max_submitted
self.storage_path = storage_path
self.submitted_path = submitted_path
self.score_multiplier = score_multiplier
self.k = k
@@ -62,13 +69,22 @@ class VectorStore:
self._positive_hashes: list = [] # Хеши текстов для дедупликации
self._negative_hashes: list = []
# Submitted-посты (третья коллекция)
self._submitted_vectors: list = []
self._submitted_hashes: list = []
self._submitted_created_at: list = [] # Unix timestamps
self._submitted_post_ids: list = []
self._submitted_texts: list = []
self._submitted_rag_scores: list = []
# Lock для потокобезопасности
self._lock = threading.Lock()
# Пытаемся загрузить сохраненные векторы
# Всегда вызываем _load_from_disk если есть storage_path - он сам решит что загружать
if storage_path:
self._load_from_disk()
if submitted_path:
self._load_submitted_from_disk()
@property
def positive_count(self) -> int:
@@ -85,10 +101,15 @@ class VectorStore:
"""Общее количество примеров."""
return self.positive_count + self.negative_count
@property
def submitted_count(self) -> int:
"""Количество submitted-постов."""
return len(self._submitted_vectors)
@staticmethod
def compute_text_hash(text: str) -> str:
"""Вычисляет хеш текста для дедупликации."""
return hashlib.md5(text.encode('utf-8')).hexdigest()
return hashlib.md5(text.encode("utf-8")).hexdigest()
def _normalize_vector(self, vector: np.ndarray) -> np.ndarray:
"""Нормализует вектор для косинусного сходства."""
@@ -97,7 +118,7 @@ class VectorStore:
return vector
return vector / norm
def add_positive(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
def add_positive(self, vector: np.ndarray, text_hash: str | None = None) -> bool:
"""
Добавляет положительный пример (опубликованный пост).
@@ -127,13 +148,13 @@ class VectorStore:
if text_hash:
self._positive_hashes.append(text_hash)
logger.info(f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})")
logger.info(
f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})"
)
return True
def add_positive_batch(
self,
vectors: List[np.ndarray],
text_hashes: Optional[List[str]] = None
self, vectors: list[np.ndarray], text_hashes: list[str] | None = None
) -> int:
"""
Добавляет батч положительных примеров.
@@ -167,10 +188,12 @@ class VectorStore:
self._positive_hashes.append(text_hash)
added += 1
logger.info(f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})")
logger.info(
f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})"
)
return added
def add_negative(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
def add_negative(self, vector: np.ndarray, text_hash: str | None = None) -> bool:
"""
Добавляет отрицательный пример (отклоненный пост).
@@ -200,13 +223,13 @@ class VectorStore:
if text_hash:
self._negative_hashes.append(text_hash)
logger.info(f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})")
logger.info(
f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})"
)
return True
def add_negative_batch(
self,
vectors: List[np.ndarray],
text_hashes: Optional[List[str]] = None
self, vectors: list[np.ndarray], text_hashes: list[str] | None = None
) -> int:
"""
Добавляет батч отрицательных примеров.
@@ -240,10 +263,107 @@ class VectorStore:
self._negative_hashes.append(text_hash)
added += 1
logger.info(f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})")
logger.info(
f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})"
)
return added
def calculate_similarity_score(self, vector: np.ndarray) -> Tuple[float, float, float]:
def add_submitted(
self,
vector: np.ndarray,
text_hash: str,
created_at: int,
post_id: int | None = None,
text: str = "",
rag_score: float | None = None,
) -> bool:
"""
Добавляет submitted-пост в коллекцию.
Args:
vector: Векторное представление текста
text_hash: Хеш текста для дедупликации
created_at: Unix timestamp создания
post_id: ID поста (опционально)
text: Текст поста
rag_score: RAG скор поста (опционально)
Returns:
True если добавлен, False если дубликат
"""
with self._lock:
if text_hash in self._submitted_hashes:
logger.debug("VectorStore: Пропуск дубликата submitted-поста")
return False
if len(self._submitted_vectors) >= self.max_submitted:
self._submitted_vectors.pop(0)
self._submitted_hashes.pop(0)
self._submitted_created_at.pop(0)
self._submitted_post_ids.pop(0)
self._submitted_texts.pop(0)
self._submitted_rag_scores.pop(0)
logger.debug("VectorStore: Удален старый submitted-пост (лимит)")
normalized = self._normalize_vector(vector)
self._submitted_vectors.append(normalized)
self._submitted_hashes.append(text_hash)
self._submitted_created_at.append(created_at)
self._submitted_post_ids.append(post_id)
self._submitted_texts.append(text)
self._submitted_rag_scores.append(rag_score)
logger.info(f"VectorStore: Добавлен submitted-пост (всего: {self.submitted_count})")
return True
def find_similar_submitted(
self,
vector: np.ndarray,
threshold: float,
hours: int,
) -> list[dict[str, Any]]:
"""
Ищет похожие submitted-посты за последние N часов.
Args:
vector: Векторное представление запроса
threshold: Минимальный порог similarity (0.0 - 1.0)
hours: Количество часов для фильтрации (created_at >= now - hours*3600)
Returns:
Список dict с полями: similarity, created_at, post_id, text, rag_score
"""
with self._lock:
if self.submitted_count == 0:
return []
now = int(time.time())
cutoff = now - hours * 3600
normalized = self._normalize_vector(vector)
submitted_matrix = np.array(self._submitted_vectors)
similarities = np.dot(submitted_matrix, normalized)
results: list[dict[str, Any]] = []
for i, sim in enumerate(similarities):
if float(sim) < threshold:
continue
created_at = self._submitted_created_at[i]
if created_at < cutoff:
continue
results.append(
{
"similarity": float(sim),
"created_at": created_at,
"post_id": self._submitted_post_ids[i],
"text": self._submitted_texts[i],
"rag_score": self._submitted_rag_scores[i],
}
)
return sorted(results, key=lambda x: x["similarity"], reverse=True)
def calculate_similarity_score(self, vector: np.ndarray) -> tuple[float, float, float]:
"""
Рассчитывает скор на основе сходства с примерами.
@@ -266,15 +386,14 @@ class VectorStore:
"""
with self._lock:
if self.positive_count == 0:
raise InsufficientExamplesError(
"Нет положительных примеров для сравнения"
)
raise InsufficientExamplesError("Нет положительных примеров для сравнения")
# Нормализуем входной вектор
normalized = self._normalize_vector(vector)
normalized = self._normalize_vector(np.asarray(vector).flatten())
# Конвертируем в numpy массивы для быстрых вычислений
pos_matrix = np.array(self._positive_vectors)
# Используем vstack для гарантии одинаковой формы (совместимость со старым npz)
pos_matrix = np.vstack([np.asarray(v).flatten() for v in self._positive_vectors])
# Косинусное сходство с положительными примерами
# Для нормализованных векторов это просто скалярное произведение
@@ -282,7 +401,7 @@ class VectorStore:
# Косинусное сходство с отрицательными примерами
if self.negative_count > 0:
neg_matrix = np.array(self._negative_vectors)
neg_matrix = np.vstack([np.asarray(v).flatten() for v in self._negative_vectors])
neg_similarities = np.dot(neg_matrix, normalized)
else:
neg_similarities = np.array([])
@@ -345,7 +464,7 @@ class VectorStore:
return score, confidence, score_pos_only
def save_to_disk(self, path: Optional[str] = None) -> None:
def save_to_disk(self, path: str | None = None) -> None:
"""
Сохраняет векторы на диск.
@@ -363,8 +482,12 @@ class VectorStore:
# Сохраняем в npz формате
np.savez_compressed(
save_path,
positive_vectors=np.array(self._positive_vectors) if self._positive_vectors else np.array([]),
negative_vectors=np.array(self._negative_vectors) if self._negative_vectors else np.array([]),
positive_vectors=np.array(self._positive_vectors)
if self._positive_vectors
else np.array([]),
negative_vectors=np.array(self._negative_vectors)
if self._negative_vectors
else np.array([]),
positive_hashes=np.array(self._positive_hashes, dtype=object),
negative_hashes=np.array(self._negative_hashes, dtype=object),
vector_dim=self.vector_dim,
@@ -376,6 +499,87 @@ class VectorStore:
f"{self.negative_count} neg): {save_path}"
)
def save_submitted_to_disk(self, path: str | None = None) -> None:
"""
Сохраняет submitted-коллекцию на диск.
Args:
path: Путь для сохранения (если не указан, используется submitted_path)
"""
save_path = path or self.submitted_path
if not save_path:
raise VectorStoreError("Путь для сохранения submitted не указан")
with self._lock:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(
save_path,
vectors=np.array(self._submitted_vectors)
if self._submitted_vectors
else np.array([]),
hashes=np.array(self._submitted_hashes, dtype=object),
created_at=np.array(self._submitted_created_at)
if self._submitted_created_at
else np.array([]),
post_ids=np.array(self._submitted_post_ids, dtype=object),
texts=np.array(self._submitted_texts, dtype=object),
rag_scores=np.array(self._submitted_rag_scores, dtype=object),
)
logger.info(f"VectorStore: Сохранено submitted ({self.submitted_count}): {save_path}")
def _load_submitted_from_disk(self) -> None:
"""Загружает submitted-коллекцию с диска."""
if not self.submitted_path or not os.path.exists(self.submitted_path):
return
try:
with self._lock:
data = np.load(self.submitted_path, allow_pickle=True)
vectors = data.get("vectors", np.array([]))
if vectors.size > 0:
if len(vectors.shape) == 2:
self._submitted_vectors = [
self._normalize_vector(np.array(v)) for v in vectors
]
elif len(vectors.shape) == 1:
self._submitted_vectors = [self._normalize_vector(np.array(vectors))]
else:
self._submitted_vectors = []
else:
self._submitted_vectors = []
hashes = data.get("hashes", np.array([]))
self._submitted_hashes = list(hashes) if hashes.size > 0 else []
created_at = data.get("created_at", np.array([]))
self._submitted_created_at = list(created_at) if created_at.size > 0 else []
post_ids = data.get("post_ids", np.array([]))
self._submitted_post_ids = list(post_ids) if post_ids.size > 0 else []
texts = data.get("texts", np.array([]))
self._submitted_texts = list(texts) if texts.size > 0 else []
rag_scores = data.get("rag_scores", np.array([]))
self._submitted_rag_scores = list(rag_scores) if rag_scores.size > 0 else []
# Выравниваем длины (на случай поврежденных данных)
n = len(self._submitted_vectors)
self._submitted_hashes = self._submitted_hashes[:n]
self._submitted_created_at = self._submitted_created_at[:n]
self._submitted_post_ids = self._submitted_post_ids[:n]
self._submitted_texts = self._submitted_texts[:n]
self._submitted_rag_scores = self._submitted_rag_scores[:n]
logger.info(
f"VectorStore: Загружено submitted ({self.submitted_count}): {self.submitted_path}"
)
except Exception as e:
logger.error(f"VectorStore: Ошибка загрузки submitted с диска: {e}")
def _load_from_disk(self) -> None:
"""Загружает векторы с диска."""
if not self.storage_path:
@@ -388,7 +592,9 @@ class VectorStore:
negative_npy = storage_dir / "negative_embeddings.npy"
# Отладочное логирование
logger.info(f"VectorStore: Проверка путей - storage_dir={storage_dir}, positive_npy={positive_npy}, exists={positive_npy.exists()}, negative_npy={negative_npy}, exists={negative_npy.exists()}")
logger.info(
f"VectorStore: Проверка путей - storage_dir={storage_dir}, positive_npy={positive_npy}, exists={positive_npy.exists()}, negative_npy={negative_npy}, exists={negative_npy.exists()}"
)
# Проверяем наличие отдельных .npy файлов
if positive_npy.exists() or negative_npy.exists():
@@ -406,9 +612,13 @@ class VectorStore:
# Один вектор [dim]
self._positive_vectors = [pos_vectors]
else:
logger.warning(f"VectorStore: Неожиданная размерность positive_embeddings.npy: {pos_vectors.shape}")
logger.warning(
f"VectorStore: Неожиданная размерность positive_embeddings.npy: {pos_vectors.shape}"
)
self._positive_vectors = []
logger.info(f"VectorStore: Загружено {len(self._positive_vectors)} положительных векторов из {positive_npy}")
logger.info(
f"VectorStore: Загружено {len(self._positive_vectors)} положительных векторов из {positive_npy}"
)
# Загружаем отрицательные векторы
if negative_npy.exists():
@@ -422,13 +632,21 @@ class VectorStore:
# Один вектор [dim]
self._negative_vectors = [neg_vectors]
else:
logger.warning(f"VectorStore: Неожиданная размерность negative_embeddings.npy: {neg_vectors.shape}")
logger.warning(
f"VectorStore: Неожиданная размерность negative_embeddings.npy: {neg_vectors.shape}"
)
self._negative_vectors = []
logger.info(f"VectorStore: Загружено {len(self._negative_vectors)} отрицательных векторов из {negative_npy}")
logger.info(
f"VectorStore: Загружено {len(self._negative_vectors)} отрицательных векторов из {negative_npy}"
)
# Нормализуем загруженные векторы
self._positive_vectors = [self._normalize_vector(np.array(v)) for v in self._positive_vectors]
self._negative_vectors = [self._normalize_vector(np.array(v)) for v in self._negative_vectors]
self._positive_vectors = [
self._normalize_vector(np.array(v)) for v in self._positive_vectors
]
self._negative_vectors = [
self._normalize_vector(np.array(v)) for v in self._negative_vectors
]
logger.info(
f"VectorStore: Загружено с диска из .npy файлов ({self.positive_count} pos, "
@@ -438,12 +656,14 @@ class VectorStore:
# Если отдельных .npy файлов нет, пытаемся загрузить из старого формата .npz
if os.path.exists(self.storage_path):
logger.info(f"VectorStore: Загружаем из старого формата .npz: {self.storage_path}")
logger.info(
f"VectorStore: Загружаем из старого формата .npz: {self.storage_path}"
)
data = np.load(self.storage_path, allow_pickle=True)
# Загружаем векторы
pos_vectors = data.get('positive_vectors', np.array([]))
neg_vectors = data.get('negative_vectors', np.array([]))
pos_vectors = data.get("positive_vectors", np.array([]))
neg_vectors = data.get("negative_vectors", np.array([]))
if pos_vectors.size > 0:
self._positive_vectors = list(pos_vectors)
@@ -451,8 +671,8 @@ class VectorStore:
self._negative_vectors = list(neg_vectors)
# Загружаем хеши
pos_hashes = data.get('positive_hashes', np.array([]))
neg_hashes = data.get('negative_hashes', np.array([]))
pos_hashes = data.get("positive_hashes", np.array([]))
neg_hashes = data.get("negative_hashes", np.array([]))
if pos_hashes.size > 0:
self._positive_hashes = list(pos_hashes)
@@ -475,6 +695,12 @@ class VectorStore:
self._negative_vectors.clear()
self._positive_hashes.clear()
self._negative_hashes.clear()
self._submitted_vectors.clear()
self._submitted_hashes.clear()
self._submitted_created_at.clear()
self._submitted_post_ids.clear()
self._submitted_texts.clear()
self._submitted_rag_scores.clear()
logger.info("VectorStore: Хранилище очищено")
def get_stats(self) -> dict:
@@ -483,8 +709,10 @@ class VectorStore:
"positive_count": self.positive_count,
"negative_count": self.negative_count,
"total_count": self.total_count,
"submitted_count": self.submitted_count,
"vector_dim": self.vector_dim,
"max_examples": self.max_examples,
"max_submitted": self.max_submitted,
}
def get_scoring_params(self) -> dict:
@@ -496,8 +724,8 @@ class VectorStore:
def update_scoring_params(
self,
score_multiplier: Optional[float] = None,
k: Optional[int] = None,
score_multiplier: float | None = None,
k: int | None = None,
) -> dict:
"""
Обновляет параметры формулы расчета score.

View File

@@ -20,6 +20,8 @@ services:
- RAG_CACHE_DIR=/app/data/models
- RAG_VECTORS_PATH=/app/data/vectors/vectors.npz
- RAG_MAX_EXAMPLES=${RAG_MAX_EXAMPLES:-10000}
- RAG_MAX_SUBMITTED=${RAG_MAX_SUBMITTED:-5000}
- RAG_SUBMITTED_PATH=/app/data/vectors/submitted.npz
- RAG_SCORE_MULTIPLIER=${RAG_SCORE_MULTIPLIER:-5.0}
- RAG_BATCH_SIZE=${RAG_BATCH_SIZE:-16}
- RAG_MIN_TEXT_LENGTH=${RAG_MIN_TEXT_LENGTH:-3}

View File

@@ -7,6 +7,8 @@ RAG_CACHE_DIR=data/models
# VectorStore
RAG_VECTORS_PATH=data/vectors/vectors.npz
RAG_MAX_EXAMPLES=10000
RAG_MAX_SUBMITTED=5000
RAG_SUBMITTED_PATH=data/vectors/submitted.npz
RAG_SCORE_MULTIPLIER=5.0
# Батч-обработка

16
tests/conftest.py Normal file
View File

@@ -0,0 +1,16 @@
"""
Pytest fixtures для RAG сервиса.
"""
import os
import pytest
@pytest.fixture(autouse=True)
def allow_no_auth():
"""Разрешает запросы без API ключа в тестах."""
os.environ["RAG_ALLOW_NO_AUTH"] = "true"
yield
if "RAG_ALLOW_NO_AUTH" in os.environ:
del os.environ["RAG_ALLOW_NO_AUTH"]

View File

@@ -0,0 +1,169 @@
"""
Тесты API endpoints /similar и /submitted.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
from app.main import app
@pytest.fixture
def mock_rag_service():
"""Mock RAGService для тестов API без загрузки модели."""
service = MagicMock()
service.is_model_loaded = True
service.vector_store.submitted_count = 0
return service
@pytest.fixture
def client(mock_rag_service):
"""TestClient с переопределённым RAG сервисом."""
from app.api.routes import get_service
def override_get_service():
return mock_rag_service
app.dependency_overrides[get_service] = override_get_service
# get_rag_service используется при создании сервиса - get_service вызывает get_rag_service
# Смотрю routes: get_service возвращает get_rag_service(). Значит override get_service достаточно.
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()
def test_similar_endpoint(client, mock_rag_service):
"""POST /api/v1/similar возвращает похожие посты."""
mock_rag_service.find_similar_posts = AsyncMock(
return_value=[
{
"similarity": 0.95,
"created_at": 1706270000,
"post_id": 123,
"text": "Похожий пост",
"rag_score": 0.85,
}
]
)
response = client.post(
"/api/v1/similar",
json={"text": "Текст для поиска", "threshold": 0.9, "hours": 24},
)
assert response.status_code == 200
data = response.json()
assert data["similar_count"] == 1
assert len(data["similar_posts"]) == 1
assert data["similar_posts"][0]["similarity"] == 0.95
assert data["similar_posts"][0]["post_id"] == 123
assert data["similar_posts"][0]["text"] == "Похожий пост"
assert data["similar_posts"][0]["rag_score"] == 0.85
def test_similar_endpoint_empty(client, mock_rag_service):
"""POST /api/v1/similar с пустым результатом."""
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
response = client.post(
"/api/v1/similar",
json={"text": "Уникальный текст", "threshold": 0.99, "hours": 1},
)
assert response.status_code == 200
assert response.json()["similar_count"] == 0
assert response.json()["similar_posts"] == []
def test_similar_endpoint_default_params(client, mock_rag_service):
"""POST /api/v1/similar с дефолтными параметрами."""
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
response = client.post(
"/api/v1/similar",
json={"text": "Текст"},
)
assert response.status_code == 200
mock_rag_service.find_similar_posts.assert_called_once_with(
text="Текст",
threshold=0.9,
hours=24,
)
def test_submitted_endpoint_success(client, mock_rag_service):
"""POST /api/v1/submitted успешно добавляет пост."""
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
mock_rag_service.vector_store.submitted_count = 5
response = client.post(
"/api/v1/submitted",
json={"text": "Новый пост", "post_id": 42, "rag_score": 0.8},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "добавлен" in data["message"].lower()
assert data["submitted_count"] == 5
mock_rag_service.add_submitted_post.assert_called_once_with(
text="Новый пост",
post_id=42,
rag_score=0.8,
)
def test_submitted_endpoint_duplicate(client, mock_rag_service):
"""POST /api/v1/submitted при дубликате возвращает success=False."""
mock_rag_service.add_submitted_post = AsyncMock(return_value=False)
mock_rag_service.vector_store.submitted_count = 10
response = client.post(
"/api/v1/submitted",
json={"text": "Дубликат поста"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["submitted_count"] == 10
def test_submitted_endpoint_optional_fields(client, mock_rag_service):
"""POST /api/v1/submitted с только текстом (post_id, rag_score опциональны)."""
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
mock_rag_service.vector_store.submitted_count = 1
response = client.post(
"/api/v1/submitted",
json={"text": "Только текст"},
)
assert response.status_code == 200
mock_rag_service.add_submitted_post.assert_called_once_with(
text="Только текст",
post_id=None,
rag_score=None,
)
def test_similar_validation_empty_text(client):
"""POST /api/v1/similar с пустым текстом возвращает 422."""
response = client.post(
"/api/v1/similar",
json={"text": "", "threshold": 0.9, "hours": 24},
)
assert response.status_code == 422
def test_submitted_validation_empty_text(client):
"""POST /api/v1/submitted с пустым текстом возвращает 422."""
response = client.post(
"/api/v1/submitted",
json={"text": ""},
)
assert response.status_code == 422

View File

@@ -0,0 +1,212 @@
"""
Тесты для submitted-коллекции VectorStore.
"""
import numpy as np
import pytest
from app.storage.vector_store import VectorStore
@pytest.fixture
def vector_store(tmp_path):
"""VectorStore с временным путём для submitted."""
return VectorStore(
vector_dim=4,
max_examples=10,
max_submitted=5,
storage_path=None,
submitted_path=str(tmp_path / "submitted.npz"),
)
@pytest.fixture
def sample_vector():
"""Нормализованный вектор для тестов."""
v = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
return v / np.linalg.norm(v)
def test_add_submitted(vector_store, sample_vector):
"""Добавление submitted-поста."""
added = vector_store.add_submitted(
vector=sample_vector,
text_hash="abc123",
created_at=1000,
post_id=42,
text="Test post",
rag_score=0.85,
)
assert added is True
assert vector_store.submitted_count == 1
def test_add_submitted_duplicate(vector_store, sample_vector):
"""Дубликат по хешу не добавляется."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="same_hash",
created_at=1000,
text="First",
)
added = vector_store.add_submitted(
vector=sample_vector,
text_hash="same_hash",
created_at=2000,
text="Second",
)
assert added is False
assert vector_store.submitted_count == 1
def test_add_submitted_fifo(vector_store, sample_vector):
"""При превышении max_submitted удаляется самый старый (FIFO)."""
for i in range(7):
v = np.array(
[float(i + 1), 0.0, 0.0, 0.0], dtype=np.float32
) # i+1 чтобы избежать нулевого вектора
v = v / np.linalg.norm(v)
vector_store.add_submitted(
vector=v,
text_hash=f"hash_{i}",
created_at=1000 + i,
post_id=i,
text=f"Post {i}",
)
assert vector_store.submitted_count == 5 # max_submitted
# Должны остаться посты 2, 3, 4, 5, 6 (удалены 0, 1)
post_ids = vector_store._submitted_post_ids
assert 0 not in post_ids
assert 1 not in post_ids
assert 2 in post_ids
def test_find_similar_submitted_empty(vector_store, sample_vector):
"""Поиск в пустой коллекции возвращает пустой список."""
result = vector_store.find_similar_submitted(
vector=sample_vector,
threshold=0.5,
hours=24,
)
assert result == []
def test_find_similar_submitted(vector_store, sample_vector):
"""Поиск похожих постов с фильтром по времени и threshold."""
import time
now = int(time.time())
# Похожий вектор
similar_v = np.array([0.99, 0.01, 0.0, 0.0], dtype=np.float32)
similar_v = similar_v / np.linalg.norm(similar_v)
# Непохожий вектор
different_v = np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)
different_v = different_v / np.linalg.norm(different_v)
vector_store.add_submitted(
vector=similar_v,
text_hash="similar",
created_at=now - 3600, # 1 час назад
post_id=1,
text="Similar post",
rag_score=0.9,
)
vector_store.add_submitted(
vector=different_v,
text_hash="different",
created_at=now - 3600,
post_id=2,
text="Different post",
rag_score=0.5,
)
result = vector_store.find_similar_submitted(
vector=sample_vector,
threshold=0.9,
hours=24,
)
assert len(result) == 1
assert result[0]["post_id"] == 1
assert result[0]["text"] == "Similar post"
assert result[0]["similarity"] >= 0.9
def test_find_similar_submitted_time_filter(vector_store, sample_vector):
"""Фильтр по hours исключает старые посты."""
import time
now = int(time.time())
vector_store.add_submitted(
vector=sample_vector,
text_hash="old",
created_at=now - 48 * 3600, # 48 часов назад
post_id=1,
text="Old post",
)
vector_store.add_submitted(
vector=sample_vector,
text_hash="recent",
created_at=now - 3600, # 1 час назад
post_id=2,
text="Recent post",
)
result = vector_store.find_similar_submitted(
vector=sample_vector,
threshold=0.5,
hours=24,
)
assert len(result) == 1
assert result[0]["post_id"] == 2
def test_submitted_persistence(vector_store, sample_vector, tmp_path):
"""Сохранение и загрузка submitted-коллекции."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="persist",
created_at=12345,
post_id=999,
text="Persisted post",
rag_score=0.77,
)
vector_store.save_submitted_to_disk()
# Новый store загружает данные
store2 = VectorStore(
vector_dim=4,
max_submitted=5,
storage_path=None,
submitted_path=str(tmp_path / "submitted.npz"),
)
assert store2.submitted_count == 1
assert store2._submitted_post_ids[0] == 999
assert store2._submitted_texts[0] == "Persisted post"
assert store2._submitted_rag_scores[0] == 0.77
def test_get_stats_includes_submitted(vector_store, sample_vector):
"""get_stats включает submitted_count и max_submitted."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="stat",
created_at=1000,
text="For stats",
)
stats = vector_store.get_stats()
assert "submitted_count" in stats
assert stats["submitted_count"] == 1
assert "max_submitted" in stats
assert stats["max_submitted"] == 5
def test_clear_submitted(vector_store, sample_vector):
"""clear() очищает submitted-коллекцию."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="clear",
created_at=1000,
text="To clear",
)
vector_store.clear()
assert vector_store.submitted_count == 0