Compare commits

...

1 Commits

Author SHA1 Message Date
a1d6d2d860 feat: add submitted collection, /similar and /submitted endpoints (Stage 4)
Made-with: Cursor
2026-02-28 19:00:22 +03:00
15 changed files with 1308 additions and 400 deletions

2
.gitignore vendored
View File

@@ -133,6 +133,8 @@ Thumbs.db
# Project specific
data/models/
data/vectors/*.npz
*.bak
*.tar.gz
# Keep data directories
!data/models/.gitkeep

View File

@@ -14,9 +14,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# Копируем зависимости
COPY requirements.txt .
# Устанавливаем зависимости
# --no-cache-dir для уменьшения размера образа
RUN pip install --no-cache-dir -r requirements.txt
# Устанавливаем зависимости (CPU-only torch для контейнеров без GPU)
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu \
&& pip install --no-cache-dir -r requirements.txt
# Копируем код приложения
COPY app/ ./app/

View File

@@ -24,14 +24,14 @@ async def verify_api_key(
) -> bool:
"""
Проверяет API ключ из заголовка запроса.
Args:
api_key: Ключ из заголовка X-API-Key
settings: Настройки приложения
Returns:
True если авторизация успешна
Raises:
HTTPException: Если ключ неверный или отсутствует
"""
@@ -47,7 +47,7 @@ async def verify_api_key(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="API ключ не настроен на сервере",
)
# Проверяем ключ
if api_key is None:
logger.warning("Запрос без API ключа")
@@ -56,14 +56,14 @@ async def verify_api_key(
detail="API ключ не предоставлен. Используйте заголовок X-API-Key",
headers={"WWW-Authenticate": "ApiKey"},
)
if api_key != settings.api_key:
logger.warning("Неверный API ключ")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Неверный API ключ",
)
return True

View File

@@ -24,7 +24,12 @@ from app.schemas import (
ScoreRequest,
ScoreResponse,
ScoringParamsResponse,
SimilarPostItem,
SimilarRequest,
SimilarResponse,
StatsResponse,
SubmittedRequest,
SubmittedResponse,
UpdateScoringParamsRequest,
VectorStoreStats,
WarmupResponse,
@@ -49,6 +54,7 @@ RAGServiceDep = Annotated[RAGService, Depends(get_service)]
# Health Check
# =============================================================================
@router.get(
"/health",
response_model=HealthResponse,
@@ -58,7 +64,7 @@ RAGServiceDep = Annotated[RAGService, Depends(get_service)]
async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse:
"""
Проверяет состояние сервиса.
Returns:
HealthResponse: Статус сервиса
"""
@@ -73,6 +79,7 @@ async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse
# Scoring
# =============================================================================
@router.post(
"/score",
response_model=ScoreResponse,
@@ -86,55 +93,55 @@ async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse
tags=["scoring"],
)
async def calculate_score(
request: ScoreRequest,
request: ScoreRequest,
service: RAGServiceDep,
_auth: AuthDep,
) -> ScoreResponse:
"""
Рассчитывает скор для текста поста.
Args:
request: Запрос с текстом
service: RAG сервис
Returns:
ScoreResponse: Результат скоринга
Raises:
HTTPException: При ошибке расчета
"""
try:
result = await service.calculate_score(request.text)
response_dict = result.to_dict()
return ScoreResponse(
rag_score=response_dict["rag_score"],
rag_confidence=response_dict["rag_confidence"],
rag_score_pos_only=response_dict["rag_score_pos_only"],
meta=ScoreMetadata(**response_dict["meta"]),
)
except TextTooShortError as e:
logger.warning(f"Текст слишком короткий: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"detail": str(e), "error_type": "TextTooShortError"},
)
except InsufficientExamplesError as e:
logger.warning(f"Недостаточно примеров: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"detail": str(e), "error_type": "InsufficientExamplesError"},
)
except ModelNotLoadedError as e:
logger.error(f"Модель не загружена: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
)
except ScoringError as e:
logger.error(f"Ошибка скоринга: {e}")
raise HTTPException(
@@ -147,6 +154,7 @@ async def calculate_score(
# Examples
# =============================================================================
@router.post(
"/examples/positive",
response_model=ExampleResponse,
@@ -159,27 +167,27 @@ async def calculate_score(
tags=["examples"],
)
async def add_positive_example(
request: ExampleRequest,
request: ExampleRequest,
service: RAGServiceDep,
_auth: AuthDep,
x_test_mode: str | None = Header(default=None, alias="X-Test-Mode"),
) -> ExampleResponse:
"""
Добавляет текст как положительный пример (опубликованный пост).
При наличии заголовка X-Test-Mode: true пример НЕ сохраняется (тестовый режим).
Args:
request: Запрос с текстом
service: RAG сервис
x_test_mode: Заголовок тестового режима
Returns:
ExampleResponse: Результат добавления
"""
# Тестовый режим — не сохраняем примеры
is_test = x_test_mode and x_test_mode.lower() == "true"
if is_test:
logger.info("Тестовый режим: положительный пример НЕ сохранён")
return ExampleResponse(
@@ -188,22 +196,22 @@ async def add_positive_example(
positive_count=service.vector_store.positive_count,
negative_count=service.vector_store.negative_count,
)
try:
added = await service.add_positive_example(request.text)
if added:
message = "Положительный пример добавлен"
else:
message = "Пример не добавлен (дубликат или слишком короткий текст)"
return ExampleResponse(
success=added,
message=message,
positive_count=service.vector_store.positive_count,
negative_count=service.vector_store.negative_count,
)
except ModelNotLoadedError as e:
logger.error(f"Модель не загружена: {e}")
raise HTTPException(
@@ -224,27 +232,27 @@ async def add_positive_example(
tags=["examples"],
)
async def add_negative_example(
request: ExampleRequest,
request: ExampleRequest,
service: RAGServiceDep,
_auth: AuthDep,
x_test_mode: str | None = Header(default=None, alias="X-Test-Mode"),
) -> ExampleResponse:
"""
Добавляет текст как отрицательный пример (отклоненный пост).
При наличии заголовка X-Test-Mode: true пример НЕ сохраняется (тестовый режим).
Args:
request: Запрос с текстом
service: RAG сервис
x_test_mode: Заголовок тестового режима
Returns:
ExampleResponse: Результат добавления
"""
# Тестовый режим — не сохраняем примеры
is_test = x_test_mode and x_test_mode.lower() == "true"
if is_test:
logger.info("Тестовый режим: отрицательный пример НЕ сохранён")
return ExampleResponse(
@@ -253,22 +261,128 @@ async def add_negative_example(
positive_count=service.vector_store.positive_count,
negative_count=service.vector_store.negative_count,
)
try:
added = await service.add_negative_example(request.text)
if added:
message = "Отрицательный пример добавлен"
else:
message = "Пример не добавлен (дубликат или слишком короткий текст)"
return ExampleResponse(
success=added,
message=message,
positive_count=service.vector_store.positive_count,
negative_count=service.vector_store.negative_count,
)
except ModelNotLoadedError as e:
logger.error(f"Модель не загружена: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
)
# =============================================================================
# Similar & Submitted
# =============================================================================
@router.post(
"/similar",
response_model=SimilarResponse,
responses={
400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
401: {"model": ErrorResponse, "description": "Не авторизован"},
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
},
summary="Поиск похожих постов",
tags=["similar"],
)
async def find_similar_posts(
request: SimilarRequest,
service: RAGServiceDep,
_auth: AuthDep,
) -> SimilarResponse:
"""
Ищет похожие submitted-посты за последние N часов.
Args:
request: Запрос с текстом, threshold и hours
service: RAG сервис
Returns:
SimilarResponse: Список похожих постов
"""
try:
similar = await service.find_similar_posts(
text=request.text,
threshold=request.threshold,
hours=request.hours,
)
return SimilarResponse(
similar_count=len(similar),
similar_posts=[SimilarPostItem(**item) for item in similar],
)
except TextTooShortError as e:
logger.warning(f"Текст слишком короткий: {e}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"detail": str(e), "error_type": "TextTooShortError"},
)
except ModelNotLoadedError as e:
logger.error(f"Модель не загружена: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
)
@router.post(
"/submitted",
response_model=SubmittedResponse,
responses={
400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
401: {"model": ErrorResponse, "description": "Не авторизован"},
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
},
summary="Добавить submitted-пост",
tags=["submitted"],
)
async def add_submitted_post(
request: SubmittedRequest,
service: RAGServiceDep,
_auth: AuthDep,
) -> SubmittedResponse:
"""
Добавляет submitted-пост в коллекцию для индексации ботом.
Args:
request: Запрос с текстом, post_id и rag_score
service: RAG сервис
Returns:
SubmittedResponse: Результат добавления
"""
try:
added = await service.add_submitted_post(
text=request.text,
post_id=request.post_id,
rag_score=request.rag_score,
)
if added:
message = "Submitted-пост добавлен"
else:
message = "Пост не добавлен (дубликат или слишком короткий текст)"
return SubmittedResponse(
success=added,
message=message,
submitted_count=service.vector_store.submitted_count,
)
except ModelNotLoadedError as e:
logger.error(f"Модель не загружена: {e}")
raise HTTPException(
@@ -281,6 +395,7 @@ async def add_negative_example(
# Stats & Warmup
# =============================================================================
@router.get(
"/stats",
response_model=StatsResponse,
@@ -294,15 +409,15 @@ async def add_negative_example(
async def get_stats(service: RAGServiceDep, _auth: AuthDep) -> StatsResponse:
"""
Возвращает статистику сервиса.
Args:
service: RAG сервис
Returns:
StatsResponse: Статистика
"""
stats = service.get_stats()
return StatsResponse(
model_name=stats["model_name"],
model_loaded=stats["model_loaded"],
@@ -325,15 +440,15 @@ async def get_stats(service: RAGServiceDep, _auth: AuthDep) -> StatsResponse:
async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
"""
Прогревает модель (загружает если не загружена).
Args:
service: RAG сервис
Returns:
WarmupResponse: Результат прогрева
"""
success = await service.warmup()
if success:
message = "Модель успешно загружена"
else:
@@ -342,7 +457,7 @@ async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={"detail": message, "error_type": "ModelNotLoadedError"},
)
return WarmupResponse(
success=success,
model_loaded=service.is_model_loaded,
@@ -354,6 +469,7 @@ async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
# Scoring Parameters
# =============================================================================
@router.get(
"/scoring/params",
response_model=ScoringParamsResponse,
@@ -370,10 +486,10 @@ async def get_scoring_params(
) -> ScoringParamsResponse:
"""
Возвращает текущие параметры формулы расчета score.
Args:
service: RAG сервис
Returns:
ScoringParamsResponse: Текущие параметры формулы
"""
@@ -399,17 +515,17 @@ async def update_scoring_params(
) -> ScoringParamsResponse:
"""
Обновляет параметры формулы расчета score.
Можно обновить один или несколько параметров одновременно.
Параметры, которые не указаны, остаются без изменений.
Args:
request: Запрос с новыми параметрами
service: RAG сервис
Returns:
ScoringParamsResponse: Обновленные параметры формулы
Raises:
HTTPException: При невалидных значениях параметров
"""

View File

@@ -5,82 +5,69 @@
import os
import secrets
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class Settings:
"""
Настройки RAG сервиса.
Все параметры загружаются из переменных окружения.
"""
# Модель
model_name: str = field(
default_factory=lambda: os.getenv("RAG_MODEL", "sentence-transformers/all-MiniLM-L12-v2")
)
cache_dir: str = field(
default_factory=lambda: os.getenv("RAG_CACHE_DIR", "data/models")
)
cache_dir: str = field(default_factory=lambda: os.getenv("RAG_CACHE_DIR", "data/models"))
# VectorStore
vectors_path: str = field(
default_factory=lambda: os.getenv("RAG_VECTORS_PATH", "data/vectors/vectors.npz")
)
max_examples: int = field(
default_factory=lambda: int(os.getenv("RAG_MAX_EXAMPLES", "10000"))
max_examples: int = field(default_factory=lambda: int(os.getenv("RAG_MAX_EXAMPLES", "10000")))
max_submitted: int = field(default_factory=lambda: int(os.getenv("RAG_MAX_SUBMITTED", "5000")))
submitted_path: str = field(
default_factory=lambda: os.getenv("RAG_SUBMITTED_PATH", "data/vectors/submitted.npz")
)
score_multiplier: float = field(
default_factory=lambda: float(os.getenv("RAG_SCORE_MULTIPLIER", "5.0"))
)
# Батч-обработка
batch_size: int = field(
default_factory=lambda: int(os.getenv("RAG_BATCH_SIZE", "16"))
)
batch_size: int = field(default_factory=lambda: int(os.getenv("RAG_BATCH_SIZE", "16")))
# Минимальная длина текста
min_text_length: int = field(
default_factory=lambda: int(os.getenv("RAG_MIN_TEXT_LENGTH", "3"))
)
min_text_length: int = field(default_factory=lambda: int(os.getenv("RAG_MIN_TEXT_LENGTH", "3")))
# API настройки
api_host: str = field(
default_factory=lambda: os.getenv("RAG_API_HOST", "0.0.0.0")
)
api_port: int = field(
default_factory=lambda: int(os.getenv("RAG_API_PORT", "8000"))
)
api_host: str = field(default_factory=lambda: os.getenv("RAG_API_HOST", "0.0.0.0"))
api_port: int = field(default_factory=lambda: int(os.getenv("RAG_API_PORT", "8000")))
# Безопасность
# API ключ для авторизации (обязателен в продакшене!)
api_key: Optional[str] = field(
default_factory=lambda: os.getenv("RAG_API_KEY")
)
api_key: str | None = field(default_factory=lambda: os.getenv("RAG_API_KEY"))
# Разрешить запросы без ключа (только для разработки)
allow_no_auth: bool = field(
default_factory=lambda: os.getenv("RAG_ALLOW_NO_AUTH", "false").lower() == "true"
)
# Логирование
log_level: str = field(
default_factory=lambda: os.getenv("LOG_LEVEL", "INFO")
)
log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO"))
# Автосохранение (интервал в секундах, 0 = отключено)
autosave_interval: int = field(
default_factory=lambda: int(os.getenv("RAG_AUTOSAVE_INTERVAL", "600")) # 10 минут
)
# Размерность векторов (384 для all-MiniLM-L12-v2)
vector_dim: int = 384
@property
def is_auth_required(self) -> bool:
"""Проверяет, требуется ли авторизация."""
return self.api_key is not None and not self.allow_no_auth
@staticmethod
def generate_api_key() -> str:
"""Генерирует случайный API ключ."""
@@ -88,13 +75,13 @@ class Settings:
# Глобальный экземпляр настроек
_settings: Optional[Settings] = None
_settings: Settings | None = None
def get_settings() -> Settings:
"""
Возвращает глобальный экземпляр настроек.
Returns:
Settings: Настройки приложения
"""

View File

@@ -5,29 +5,35 @@
class RAGServiceError(Exception):
"""Базовое исключение для ошибок RAG сервиса."""
pass
class ModelNotLoadedError(RAGServiceError):
"""Модель не загружена или недоступна."""
pass
class VectorStoreError(RAGServiceError):
"""Ошибка при работе с хранилищем векторов."""
pass
class InsufficientExamplesError(RAGServiceError):
"""Недостаточно примеров для расчета скора."""
pass
class TextTooShortError(RAGServiceError):
"""Текст слишком короткий для векторизации."""
pass
class ScoringError(RAGServiceError):
"""Ошибка при расчете скора."""
pass

View File

@@ -7,8 +7,8 @@ FastAPI приложение Embedding сервиса.
import asyncio
import logging
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@@ -18,11 +18,12 @@ from app.api.routes import router
from app.config import get_settings
from app.services.rag_service import RAGService, get_rag_service
# Настройка логирования
def setup_logging() -> None:
"""Настраивает логирование для приложения."""
settings = get_settings()
logging.basicConfig(
level=getattr(logging, settings.log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -30,7 +31,7 @@ def setup_logging() -> None:
logging.StreamHandler(sys.stdout),
],
)
# Уменьшаем логи от библиотек
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
@@ -40,33 +41,40 @@ def setup_logging() -> None:
logger = logging.getLogger(__name__)
# Глобальная задача автосохранения
_autosave_task: Optional[asyncio.Task] = None
_autosave_task: asyncio.Task | None = None
async def autosave_loop(service: RAGService, interval: int) -> None:
"""
Фоновая задача для периодического сохранения векторов.
Args:
service: RAG сервис
interval: Интервал сохранения в секундах
"""
logger.info(f"Автосохранение запущено (интервал: {interval} сек)")
while True:
try:
await asyncio.sleep(interval)
# Сохраняем только если есть данные
if service.vector_store.total_count > 0:
has_examples = service.vector_store.total_count > 0
has_submitted = service.vector_store.submitted_count > 0
if has_examples or has_submitted:
service.save_vectors()
logger.info(
f"Автосохранение: сохранено {service.vector_store.positive_count} pos, "
f"{service.vector_store.negative_count} neg"
)
parts = []
if has_examples:
parts.append(
f"{service.vector_store.positive_count} pos, "
f"{service.vector_store.negative_count} neg"
)
if has_submitted:
parts.append(f"{service.vector_store.submitted_count} submitted")
logger.info(f"Автосохранение: сохранено {', '.join(parts)}")
else:
logger.debug("Автосохранение: нет данных для сохранения")
except asyncio.CancelledError:
logger.info("Автосохранение остановлено")
break
@@ -79,43 +87,41 @@ async def autosave_loop(service: RAGService, interval: int) -> None:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""
Lifespan контекст для FastAPI.
При запуске:
- Настраивает логирование
- Прогревает модель (опционально)
При остановке:
- Сохраняет векторы на диск
"""
global _autosave_task
setup_logging()
logger.info(f"Embedding Service v{__version__} запускается...")
settings = get_settings()
logger.info(f"Настройки: model={settings.model_name}, vectors_path={settings.vectors_path}")
# Получаем сервис (создается singleton)
service = get_rag_service()
# Запускаем автосохранение если включено
if settings.autosave_interval > 0:
_autosave_task = asyncio.create_task(
autosave_loop(service, settings.autosave_interval)
)
_autosave_task = asyncio.create_task(autosave_loop(service, settings.autosave_interval))
logger.info(f"Автосохранение включено: каждые {settings.autosave_interval} сек")
else:
logger.info("Автосохранение отключено")
# Прогреваем модель при запуске (опционально)
# Можно раскомментировать если нужен автопрогрев
# logger.info("Прогрев модели при запуске...")
# await service.warmup()
logger.info("Embedding Service готов к работе")
yield
# Останавливаем автосохранение
if _autosave_task and not _autosave_task.done():
_autosave_task.cancel()
@@ -123,7 +129,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await _autosave_task
except asyncio.CancelledError:
pass
# При остановке сохраняем векторы
logger.info("Embedding Service останавливается, финальное сохранение векторов...")
try:
@@ -131,7 +137,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info("Векторы сохранены")
except Exception as e:
logger.error(f"Ошибка сохранения векторов: {e}")
logger.info("Embedding Service остановлен")
@@ -176,19 +182,21 @@ app.add_middleware(
allow_headers=["*"],
)
# Простой healthcheck endpoint без авторизации (для Docker healthcheck)
@app.get("/health")
async def simple_health_check():
"""Простая проверка здоровья без авторизации (для Docker healthcheck)."""
return {"status": "ok"}
# Подключение роутов
app.include_router(router, prefix="/api/v1")
if __name__ == "__main__":
import uvicorn
settings = get_settings()
uvicorn.run(
"app.main:app",

View File

@@ -2,36 +2,66 @@
Pydantic схемы для API Embedding сервиса.
"""
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
# =============================================================================
# Запросы
# =============================================================================
class ScoreRequest(BaseModel):
"""Запрос на расчет скора."""
text: str = Field(..., min_length=1, description="Текст поста для оценки")
model_config = {
"json_schema_extra": {
"example": {
"text": "Это пример текста поста для оценки скоринга"
}
}
"json_schema_extra": {"example": {"text": "Это пример текста поста для оценки скоринга"}}
}
class ExampleRequest(BaseModel):
"""Запрос на добавление примера."""
text: str = Field(..., min_length=1, description="Текст примера")
model_config = {
"json_schema_extra": {"example": {"text": "Это пример опубликованного/отклоненного поста"}}
}
class SimilarRequest(BaseModel):
"""Запрос на поиск похожих постов."""
text: str = Field(..., min_length=1, description="Текст для поиска похожих")
threshold: float = Field(
default=0.9, ge=0.0, le=1.0, description="Минимальный порог similarity"
)
hours: int = Field(default=24, ge=1, le=168, description="Количество часов для фильтрации")
model_config = {
"json_schema_extra": {
"example": {
"text": "Это пример опубликованного/отклоненного поста"
"text": "Текст поста для поиска похожих",
"threshold": 0.9,
"hours": 24,
}
}
}
class SubmittedRequest(BaseModel):
"""Запрос на добавление submitted-поста."""
text: str = Field(..., min_length=1, description="Текст поста")
post_id: int | None = None
rag_score: float | None = None
model_config = {
"json_schema_extra": {
"example": {
"text": "Текст submitted-поста",
"post_id": 12345,
"rag_score": 0.85,
}
}
}
@@ -41,8 +71,10 @@ class ExampleRequest(BaseModel):
# Ответы
# =============================================================================
class ScoreMetadata(BaseModel):
"""Метаданные результата скоринга."""
positive_examples: int = Field(..., description="Количество положительных примеров")
negative_examples: int = Field(..., description="Количество отрицательных примеров")
model: str = Field(..., description="Название модели")
@@ -51,11 +83,14 @@ class ScoreMetadata(BaseModel):
class ScoreResponse(BaseModel):
"""Ответ с результатом скоринга."""
rag_score: float = Field(..., ge=0.0, le=1.0, description="Основной скор (neg/pos формула)")
rag_confidence: float = Field(..., ge=0.0, le=1.0, description="Уверенность в оценке")
rag_score_pos_only: float = Field(..., ge=0.0, le=1.0, description="Скор только по положительным примерам")
rag_score_pos_only: float = Field(
..., ge=0.0, le=1.0, description="Скор только по положительным примерам"
)
meta: ScoreMetadata = Field(..., description="Метаданные")
model_config = {
"json_schema_extra": {
"example": {
@@ -66,8 +101,8 @@ class ScoreResponse(BaseModel):
"positive_examples": 500,
"negative_examples": 350,
"model": "sentence-transformers/all-MiniLM-L12-v2",
"timestamp": 1706270000
}
"timestamp": 1706270000,
},
}
}
}
@@ -75,18 +110,71 @@ class ScoreResponse(BaseModel):
class ExampleResponse(BaseModel):
"""Ответ на добавление примера."""
success: bool = Field(..., description="Успешность добавления")
message: str = Field(..., description="Сообщение о результате")
positive_count: int = Field(..., description="Текущее количество положительных примеров")
negative_count: int = Field(..., description="Текущее количество отрицательных примеров")
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Положительный пример добавлен",
"positive_count": 501,
"negative_count": 350
"negative_count": 350,
}
}
}
class SimilarPostItem(BaseModel):
"""Элемент похожего поста."""
similarity: float = Field(..., description="Косинусное сходство")
created_at: int = Field(..., description="Unix timestamp создания")
post_id: int | None = None
text: str = Field(..., description="Текст поста")
rag_score: float | None = None
class SimilarResponse(BaseModel):
"""Ответ с похожими постами."""
similar_count: int = Field(..., description="Количество найденных похожих постов")
similar_posts: list[SimilarPostItem] = Field(..., description="Список похожих постов")
model_config = {
"json_schema_extra": {
"example": {
"similar_count": 2,
"similar_posts": [
{
"similarity": 0.95,
"created_at": 1706270000,
"post_id": 123,
"text": "Похожий пост",
"rag_score": 0.85,
}
],
}
}
}
class SubmittedResponse(BaseModel):
"""Ответ на добавление submitted-поста."""
success: bool = Field(..., description="Успешность добавления")
message: str = Field(..., description="Сообщение о результате")
submitted_count: int = Field(..., description="Текущее количество submitted-постов")
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"message": "Submitted-пост добавлен",
"submitted_count": 42,
}
}
}
@@ -94,20 +182,24 @@ class ExampleResponse(BaseModel):
class VectorStoreStats(BaseModel):
"""Статистика хранилища векторов."""
positive_count: int = Field(..., description="Количество положительных примеров")
negative_count: int = Field(..., description="Количество отрицательных примеров")
total_count: int = Field(..., description="Общее количество примеров")
submitted_count: int = Field(default=0, description="Количество submitted-постов")
vector_dim: int = Field(..., description="Размерность векторов")
max_examples: int = Field(..., description="Максимальное количество примеров")
max_submitted: int = Field(default=5000, description="Максимальное количество submitted-постов")
class StatsResponse(BaseModel):
"""Ответ со статистикой сервиса."""
model_name: str = Field(..., description="Название модели")
model_loaded: bool = Field(..., description="Загружена ли модель")
device: Optional[str] = Field(None, description="Устройство (cpu/cuda)")
device: str | None = Field(None, description="Устройство (cpu/cuda)")
vector_store: VectorStoreStats = Field(..., description="Статистика хранилища векторов")
model_config = {
"json_schema_extra": {
"example": {
@@ -119,8 +211,8 @@ class StatsResponse(BaseModel):
"negative_count": 350,
"total_count": 850,
"vector_dim": 384,
"max_examples": 10000
}
"max_examples": 10000,
},
}
}
}
@@ -128,16 +220,17 @@ class StatsResponse(BaseModel):
class WarmupResponse(BaseModel):
"""Ответ на прогрев модели."""
success: bool = Field(..., description="Успешность загрузки")
model_loaded: bool = Field(..., description="Загружена ли модель")
message: str = Field(..., description="Сообщение о результате")
model_config = {
"json_schema_extra": {
"example": {
"success": True,
"model_loaded": True,
"message": "Модель успешно загружена"
"message": "Модель успешно загружена",
}
}
}
@@ -145,14 +238,15 @@ class WarmupResponse(BaseModel):
class ErrorResponse(BaseModel):
"""Ответ с ошибкой."""
detail: str = Field(..., description="Описание ошибки")
error_type: str = Field(..., description="Тип ошибки")
model_config = {
"json_schema_extra": {
"example": {
"detail": "Недостаточно примеров для расчета скора",
"error_type": "InsufficientExamplesError"
"error_type": "InsufficientExamplesError",
}
}
}
@@ -160,23 +254,21 @@ class ErrorResponse(BaseModel):
class HealthResponse(BaseModel):
"""Ответ проверки здоровья сервиса."""
status: str = Field(..., description="Статус сервиса")
model_loaded: bool = Field(..., description="Загружена ли модель")
version: str = Field(..., description="Версия сервиса")
model_config = {
"json_schema_extra": {
"example": {
"status": "healthy",
"model_loaded": True,
"version": "0.1.0"
}
"example": {"status": "healthy", "model_loaded": True, "version": "0.1.0"}
}
}
class ScoringParamsResponse(BaseModel):
"""Ответ с текущими параметрами формулы расчета score."""
score_multiplier: float = Field(
...,
description=(
@@ -185,7 +277,7 @@ class ScoringParamsResponse(BaseModel):
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
"Рекомендуемое значение: 5.0"
)
),
)
k: int = Field(
...,
@@ -195,22 +287,16 @@ class ScoringParamsResponse(BaseModel):
"и вычисляет среднее косинусное сходство. "
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
"Рекомендуемое значение: 3"
)
),
)
model_config = {
"json_schema_extra": {
"example": {
"score_multiplier": 5.0,
"k": 3
}
}
}
model_config = {"json_schema_extra": {"example": {"score_multiplier": 5.0, "k": 3}}}
class UpdateScoringParamsRequest(BaseModel):
"""Запрос на обновление параметров формулы расчета score."""
score_multiplier: Optional[float] = Field(
score_multiplier: float | None = Field(
None,
gt=0,
description=(
@@ -219,9 +305,9 @@ class UpdateScoringParamsRequest(BaseModel):
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
"Должен быть > 0. Рекомендуемое значение: 5.0"
)
),
)
k: Optional[int] = Field(
k: int | None = Field(
None,
ge=1,
description=(
@@ -230,14 +316,7 @@ class UpdateScoringParamsRequest(BaseModel):
"и вычисляет среднее косинусное сходство. "
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
"Должно быть >= 1. Рекомендуемое значение: 3"
)
),
)
model_config = {
"json_schema_extra": {
"example": {
"score_multiplier": 5.0,
"k": 3
}
}
}
model_config = {"json_schema_extra": {"example": {"score_multiplier": 5.0, "k": 3}}}

View File

@@ -9,7 +9,7 @@ import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any
import numpy as np
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class ScoringResult:
"""
Результат оценки поста.
Attributes:
score: Оценка от 0.0 до 1.0 (вероятность публикации)
confidence: Уверенность в оценке
@@ -39,6 +39,7 @@ class ScoringResult:
model: Название используемой модели
timestamp: Время получения оценки
"""
score: float
confidence: float
score_pos_only: float
@@ -46,8 +47,8 @@ class ScoringResult:
negative_examples: int
model: str
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Преобразует результат в словарь."""
return {
"rag_score": round(self.score, 4),
@@ -58,31 +59,31 @@ class ScoringResult:
"negative_examples": self.negative_examples,
"model": self.model,
"timestamp": self.timestamp,
}
},
}
class RAGService:
"""
RAG сервис для оценки постов на основе векторного сходства.
Использует sentence-transformers для создания эмбеддингов текста и сравнивает
их с эталонными примерами (опубликованные vs отклоненные посты).
Attributes:
model_name: Название модели HuggingFace
vector_store: Хранилище векторов
min_text_length: Минимальная длина текста для обработки
"""
def __init__(
self,
settings: Optional[Settings] = None,
vector_store: Optional[VectorStore] = None,
settings: Settings | None = None,
vector_store: VectorStore | None = None,
):
"""
Инициализация RAG сервиса.
Args:
settings: Настройки сервиса (берутся из get_settings() если не переданы)
vector_store: Хранилище векторов (создается автоматически если не передано)
@@ -91,96 +92,102 @@ class RAGService:
self.model_name = self._settings.model_name
self.cache_dir = self._settings.cache_dir
self.min_text_length = self._settings.min_text_length
# Модель загружается лениво
self._model = None
self._device = None
self._model_loaded = False
# Хранилище векторов
self.vector_store = vector_store or VectorStore(
vector_dim=self._settings.vector_dim,
max_examples=self._settings.max_examples,
max_submitted=self._settings.max_submitted,
storage_path=self._settings.vectors_path,
submitted_path=self._settings.submitted_path,
score_multiplier=self._settings.score_multiplier,
k=3, # Фиксированное значение k для топ-k ближайших примеров
)
logger.info(f"RAGService инициализирован (model={self.model_name})")
@property
def is_model_loaded(self) -> bool:
"""Проверяет, загружена ли модель."""
return self._model_loaded
async def load_model(self) -> None:
"""
Загружает модель и токенизатор.
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
"""
if self._model_loaded:
return
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
try:
# Загрузка в отдельном потоке
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._load_model_sync)
self._model_loaded = True
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
except Exception as e:
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
def _load_model_sync(self) -> None:
"""Синхронная загрузка модели (вызывается в executor)."""
logger.info("RAGService: Начало _load_model_sync, импорт sentence_transformers...")
from sentence_transformers import SentenceTransformer
import torch
from sentence_transformers import SentenceTransformer
# Определяем устройство
self._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"RAGService: Устройство определено: {self._device}")
# Загружаем модель SentenceTransformer
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
logger.info(
f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)..."
)
self._model = SentenceTransformer(
self.model_name,
cache_folder=self.cache_dir,
device=self._device,
)
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
def _get_embedding_sync(self, text: str) -> np.ndarray:
"""
Получает эмбеддинг текста (синхронно).
Использует SentenceTransformer для получения нормализованного эмбеддинга.
Args:
text: Текст для векторизации
Returns:
Numpy массив с эмбеддингом (384 измерений для all-MiniLM-L12-v2)
"""
# SentenceTransformer автоматически нормализует эмбеддинги
embedding = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
return embedding.flatten()
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
def _get_embeddings_batch_sync(
self, texts: list[str], batch_size: int = 16
) -> list[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (синхронно).
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
Args:
texts: Список текстов для векторизации
batch_size: Размер батча
Returns:
Список numpy массивов с эмбеддингами
"""
@@ -192,32 +199,34 @@ class RAGService:
normalize_embeddings=True,
show_progress_bar=False,
)
# Преобразуем в список отдельных массивов
return [emb.flatten() for emb in embeddings]
async def get_embeddings_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[np.ndarray]:
async def get_embeddings_batch(
self, texts: list[str], batch_size: int | None = None
) -> list[np.ndarray]:
"""
Получает эмбеддинги для батча текстов (асинхронно).
Args:
texts: Список текстов для векторизации
batch_size: Размер батча (берется из настроек если не указан)
Returns:
Список numpy массивов с эмбеддингами
"""
if not self._model_loaded:
await self.load_model()
if not self._model_loaded:
raise ModelNotLoadedError("Модель не загружена")
batch_size = batch_size or self._settings.batch_size
# Очищаем тексты
clean_texts = [self._clean_text(text) for text in texts]
# Выполняем батч-обработку в thread pool
loop = asyncio.get_event_loop()
embeddings = await loop.run_in_executor(
@@ -226,71 +235,67 @@ class RAGService:
clean_texts,
batch_size,
)
return embeddings
async def get_embedding(self, text: str) -> np.ndarray:
"""
Получает эмбеддинг текста (асинхронно).
Args:
text: Текст для векторизации
Returns:
Numpy массив с эмбеддингом
Raises:
ModelNotLoadedError: Если модель не загружена
TextTooShortError: Если текст слишком короткий
"""
if not self._model_loaded:
await self.load_model()
if not self._model_loaded:
raise ModelNotLoadedError("Модель не загружена")
# Очищаем текст
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
raise TextTooShortError(
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
)
# Выполняем в отдельном потоке
loop = asyncio.get_event_loop()
embedding = await loop.run_in_executor(
None,
self._get_embedding_sync,
clean_text
)
embedding = await loop.run_in_executor(None, self._get_embedding_sync, clean_text)
return embedding
def _clean_text(self, text: str) -> str:
"""Очищает текст от лишних символов."""
if not text:
return ""
# Удаляем лишние пробелы и переносы строк
clean = " ".join(text.split())
# Удаляем служебные символы (например "^" для helper сообщений)
if clean == "^":
return ""
return clean.strip()
async def calculate_score(self, text: str) -> ScoringResult:
"""
Рассчитывает скор для текста поста.
Args:
text: Текст поста для оценки
Returns:
ScoringResult с оценкой
Raises:
ScoringError: При ошибке расчета
InsufficientExamplesError: Если недостаточно примеров
@@ -299,16 +304,17 @@ class RAGService:
try:
# Получаем эмбеддинг текста
embedding = await self.get_embedding(text)
# Логируем первые элементы вектора для отладки
logger.debug(
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
f"text_preview='{text[:30]}'"
f"RAGService: embedding[:3]={embedding[:3].tolist()}, text_preview='{text[:30]}'"
)
# Рассчитываем скор через VectorStore
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(
embedding
)
return ScoringResult(
score=score,
confidence=confidence,
@@ -317,22 +323,22 @@ class RAGService:
negative_examples=self.vector_store.negative_count,
model=self.model_name,
)
except (InsufficientExamplesError, TextTooShortError):
# Пробрасываем ожидаемые исключения
raise
except Exception as e:
logger.error(f"RAGService: Ошибка расчета скора: {e}")
raise ScoringError(f"Ошибка расчета скора: {e}")
async def add_positive_example(self, text: str) -> bool:
"""
Добавляет текст как положительный пример (опубликованный пост).
Args:
text: Текст опубликованного поста
Returns:
True если пример добавлен, False если дубликат/короткий текст
"""
@@ -341,32 +347,32 @@ class RAGService:
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
return False
# Получаем эмбеддинг
embedding = await self.get_embedding(clean_text)
# Вычисляем хеш для дедупликации
text_hash = VectorStore.compute_text_hash(clean_text)
# Добавляем в хранилище
added = self.vector_store.add_positive(embedding, text_hash)
if added:
logger.info("RAGService: Добавлен положительный пример")
return added
except Exception as e:
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
return False
async def add_negative_example(self, text: str) -> bool:
"""
Добавляет текст как отрицательный пример (отклоненный пост).
Args:
text: Текст отклоненного поста
Returns:
True если пример добавлен, False если дубликат/короткий текст
"""
@@ -375,29 +381,102 @@ class RAGService:
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
return False
# Получаем эмбеддинг
embedding = await self.get_embedding(clean_text)
# Вычисляем хеш для дедупликации
text_hash = VectorStore.compute_text_hash(clean_text)
# Добавляем в хранилище
added = self.vector_store.add_negative(embedding, text_hash)
if added:
logger.info("RAGService: Добавлен отрицательный пример")
return added
except Exception as e:
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
return False
async def add_submitted_post(
self,
text: str,
post_id: int | None = None,
rag_score: float | None = None,
) -> bool:
"""
Добавляет submitted-пост в коллекцию для индексации.
Args:
text: Текст поста
post_id: ID поста (опционально)
rag_score: RAG скор поста (опционально)
Returns:
True если добавлен, False если дубликат/короткий текст
"""
try:
clean_text = self._clean_text(text)
if len(clean_text) < self.min_text_length:
logger.debug("RAGService: Текст слишком короткий для submitted, пропускаем")
return False
embedding = await self.get_embedding(clean_text)
text_hash = VectorStore.compute_text_hash(clean_text)
created_at = int(datetime.now().timestamp())
added = self.vector_store.add_submitted(
vector=embedding,
text_hash=text_hash,
created_at=created_at,
post_id=post_id,
text=clean_text,
rag_score=rag_score,
)
if added:
logger.info("RAGService: Добавлен submitted-пост")
return added
except Exception as e:
logger.error(f"RAGService: Ошибка добавления submitted-поста: {e}")
return False
async def find_similar_posts(
self,
text: str,
threshold: float = 0.9,
hours: int = 24,
) -> list[dict[str, Any]]:
"""
Ищет похожие submitted-посты за последние N часов.
Args:
text: Текст для поиска
threshold: Минимальный порог similarity (0.0 - 1.0)
hours: Количество часов для фильтрации
Returns:
Список dict с полями: similarity, created_at, post_id, text, rag_score
"""
try:
embedding = await self.get_embedding(text)
return self.vector_store.find_similar_submitted(
vector=embedding,
threshold=threshold,
hours=hours,
)
except Exception as e:
logger.error(f"RAGService: Ошибка поиска похожих постов: {e}")
return []
async def warmup(self) -> bool:
"""
Прогревает модель (загружает если не загружена).
Returns:
True если модель загружена успешно
"""
@@ -407,13 +486,15 @@ class RAGService:
except Exception as e:
logger.error(f"RAGService: Ошибка прогрева модели: {e}")
return False
def save_vectors(self) -> None:
"""Сохраняет векторы на диск."""
"""Сохраняет векторы на диск (включая submitted)."""
if self.vector_store.storage_path:
self.vector_store.save_to_disk()
def get_stats(self) -> Dict[str, Any]:
if self.vector_store.submitted_path:
self.vector_store.save_submitted_to_disk()
def get_stats(self) -> dict[str, Any]:
"""Возвращает статистику сервиса."""
return {
"model_name": self.model_name,
@@ -424,13 +505,13 @@ class RAGService:
# Глобальный экземпляр сервиса (singleton)
_rag_service: Optional[RAGService] = None
_rag_service: RAGService | None = None
def get_rag_service() -> RAGService:
"""
Возвращает глобальный экземпляр RAG сервиса.
Returns:
RAGService: Экземпляр сервиса
"""

View File

@@ -9,8 +9,9 @@ import hashlib
import logging
import os
import threading
import time
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Any
import numpy as np
@@ -22,89 +23,109 @@ logger = logging.getLogger(__name__)
class VectorStore:
"""
In-memory хранилище векторов для RAG.
Хранит отдельно положительные (опубликованные) и отрицательные (отклоненные)
примеры. Использует косинусное сходство для расчета скора.
Attributes:
vector_dim: Размерность векторов (384 для all-MiniLM-L12-v2)
max_examples: Максимальное количество примеров каждого типа
"""
def __init__(
self,
vector_dim: int = 384,
max_examples: int = 10000,
storage_path: Optional[str] = None,
max_submitted: int = 5000,
storage_path: str | None = None,
submitted_path: str | None = None,
score_multiplier: float = 5.0,
k: int = 3,
):
"""
Инициализация хранилища.
Args:
vector_dim: Размерность векторов
max_examples: Максимальное количество примеров каждого типа
max_submitted: Максимальное количество submitted-постов
storage_path: Путь для сохранения/загрузки векторов (опционально)
submitted_path: Путь для сохранения/загрузки submitted-постов (опционально)
score_multiplier: Множитель для масштабирования разницы в скорах
k: Количество ближайших примеров для расчета среднего сходства
"""
self.vector_dim = vector_dim
self.max_examples = max_examples
self.max_submitted = max_submitted
self.storage_path = storage_path
self.submitted_path = submitted_path
self.score_multiplier = score_multiplier
self.k = k
# Инициализируем пустые массивы
# Используем список для динамического добавления, потом конвертируем в numpy
self._positive_vectors: list = []
self._negative_vectors: list = []
self._positive_hashes: list = [] # Хеши текстов для дедупликации
self._negative_hashes: list = []
# Submitted-посты (третья коллекция)
self._submitted_vectors: list = []
self._submitted_hashes: list = []
self._submitted_created_at: list = [] # Unix timestamps
self._submitted_post_ids: list = []
self._submitted_texts: list = []
self._submitted_rag_scores: list = []
# Lock для потокобезопасности
self._lock = threading.Lock()
# Пытаемся загрузить сохраненные векторы
# Всегда вызываем _load_from_disk если есть storage_path - он сам решит что загружать
if storage_path:
self._load_from_disk()
if submitted_path:
self._load_submitted_from_disk()
@property
def positive_count(self) -> int:
"""Количество положительных примеров."""
return len(self._positive_vectors)
@property
def negative_count(self) -> int:
"""Количество отрицательных примеров."""
return len(self._negative_vectors)
@property
def total_count(self) -> int:
"""Общее количество примеров."""
return self.positive_count + self.negative_count
@property
def submitted_count(self) -> int:
"""Количество submitted-постов."""
return len(self._submitted_vectors)
@staticmethod
def compute_text_hash(text: str) -> str:
"""Вычисляет хеш текста для дедупликации."""
return hashlib.md5(text.encode('utf-8')).hexdigest()
return hashlib.md5(text.encode("utf-8")).hexdigest()
def _normalize_vector(self, vector: np.ndarray) -> np.ndarray:
"""Нормализует вектор для косинусного сходства."""
norm = np.linalg.norm(vector)
if norm == 0:
return vector
return vector / norm
def add_positive(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
def add_positive(self, vector: np.ndarray, text_hash: str | None = None) -> bool:
"""
Добавляет положительный пример (опубликованный пост).
Args:
vector: Векторное представление текста
text_hash: Хеш текста для дедупликации (опционально)
Returns:
True если добавлен, False если дубликат или превышен лимит
"""
@@ -113,71 +134,73 @@ class VectorStore:
if text_hash and text_hash in self._positive_hashes:
logger.debug("VectorStore: Пропуск дубликата положительного примера")
return False
# Проверяем лимит
if len(self._positive_vectors) >= self.max_examples:
# Удаляем самый старый пример (FIFO)
self._positive_vectors.pop(0)
self._positive_hashes.pop(0)
logger.debug("VectorStore: Удален старый положительный пример (лимит)")
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._positive_vectors.append(normalized)
if text_hash:
self._positive_hashes.append(text_hash)
logger.info(f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})")
logger.info(
f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})"
)
return True
def add_positive_batch(
self,
vectors: List[np.ndarray],
text_hashes: Optional[List[str]] = None
self, vectors: list[np.ndarray], text_hashes: list[str] | None = None
) -> int:
"""
Добавляет батч положительных примеров.
Args:
vectors: Список векторов
text_hashes: Список хешей текстов для дедупликации
Returns:
Количество добавленных примеров
"""
if text_hashes is None:
text_hashes = [None] * len(vectors)
added = 0
with self._lock:
for vector, text_hash in zip(vectors, text_hashes):
# Проверяем дубликат по хешу
if text_hash and text_hash in self._positive_hashes:
continue
# Проверяем лимит
if len(self._positive_vectors) >= self.max_examples:
self._positive_vectors.pop(0)
self._positive_hashes.pop(0)
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._positive_vectors.append(normalized)
if text_hash:
self._positive_hashes.append(text_hash)
added += 1
logger.info(f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})")
logger.info(
f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})"
)
return added
def add_negative(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
def add_negative(self, vector: np.ndarray, text_hash: str | None = None) -> bool:
"""
Добавляет отрицательный пример (отклоненный пост).
Args:
vector: Векторное представление текста
text_hash: Хеш текста для дедупликации (опционально)
Returns:
True если добавлен, False если дубликат или превышен лимит
"""
@@ -186,112 +209,208 @@ class VectorStore:
if text_hash and text_hash in self._negative_hashes:
logger.debug("VectorStore: Пропуск дубликата отрицательного примера")
return False
# Проверяем лимит
if len(self._negative_vectors) >= self.max_examples:
# Удаляем самый старый пример (FIFO)
self._negative_vectors.pop(0)
self._negative_hashes.pop(0)
logger.debug("VectorStore: Удален старый отрицательный пример (лимит)")
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._negative_vectors.append(normalized)
if text_hash:
self._negative_hashes.append(text_hash)
logger.info(f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})")
logger.info(
f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})"
)
return True
def add_negative_batch(
self,
vectors: List[np.ndarray],
text_hashes: Optional[List[str]] = None
self, vectors: list[np.ndarray], text_hashes: list[str] | None = None
) -> int:
"""
Добавляет батч отрицательных примеров.
Args:
vectors: Список векторов
text_hashes: Список хешей текстов для дедупликации
Returns:
Количество добавленных примеров
"""
if text_hashes is None:
text_hashes = [None] * len(vectors)
added = 0
with self._lock:
for vector, text_hash in zip(vectors, text_hashes):
# Проверяем дубликат по хешу
if text_hash and text_hash in self._negative_hashes:
continue
# Проверяем лимит
if len(self._negative_vectors) >= self.max_examples:
self._negative_vectors.pop(0)
self._negative_hashes.pop(0)
# Нормализуем и добавляем
normalized = self._normalize_vector(vector)
self._negative_vectors.append(normalized)
if text_hash:
self._negative_hashes.append(text_hash)
added += 1
logger.info(f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})")
logger.info(
f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})"
)
return added
def calculate_similarity_score(self, vector: np.ndarray) -> Tuple[float, float, float]:
def add_submitted(
self,
vector: np.ndarray,
text_hash: str,
created_at: int,
post_id: int | None = None,
text: str = "",
rag_score: float | None = None,
) -> bool:
"""
Добавляет submitted-пост в коллекцию.
Args:
vector: Векторное представление текста
text_hash: Хеш текста для дедупликации
created_at: Unix timestamp создания
post_id: ID поста (опционально)
text: Текст поста
rag_score: RAG скор поста (опционально)
Returns:
True если добавлен, False если дубликат
"""
with self._lock:
if text_hash in self._submitted_hashes:
logger.debug("VectorStore: Пропуск дубликата submitted-поста")
return False
if len(self._submitted_vectors) >= self.max_submitted:
self._submitted_vectors.pop(0)
self._submitted_hashes.pop(0)
self._submitted_created_at.pop(0)
self._submitted_post_ids.pop(0)
self._submitted_texts.pop(0)
self._submitted_rag_scores.pop(0)
logger.debug("VectorStore: Удален старый submitted-пост (лимит)")
normalized = self._normalize_vector(vector)
self._submitted_vectors.append(normalized)
self._submitted_hashes.append(text_hash)
self._submitted_created_at.append(created_at)
self._submitted_post_ids.append(post_id)
self._submitted_texts.append(text)
self._submitted_rag_scores.append(rag_score)
logger.info(f"VectorStore: Добавлен submitted-пост (всего: {self.submitted_count})")
return True
def find_similar_submitted(
self,
vector: np.ndarray,
threshold: float,
hours: int,
) -> list[dict[str, Any]]:
"""
Ищет похожие submitted-посты за последние N часов.
Args:
vector: Векторное представление запроса
threshold: Минимальный порог similarity (0.0 - 1.0)
hours: Количество часов для фильтрации (created_at >= now - hours*3600)
Returns:
Список dict с полями: similarity, created_at, post_id, text, rag_score
"""
with self._lock:
if self.submitted_count == 0:
return []
now = int(time.time())
cutoff = now - hours * 3600
normalized = self._normalize_vector(vector)
submitted_matrix = np.array(self._submitted_vectors)
similarities = np.dot(submitted_matrix, normalized)
results: list[dict[str, Any]] = []
for i, sim in enumerate(similarities):
if float(sim) < threshold:
continue
created_at = self._submitted_created_at[i]
if created_at < cutoff:
continue
results.append(
{
"similarity": float(sim),
"created_at": created_at,
"post_id": self._submitted_post_ids[i],
"text": self._submitted_texts[i],
"rag_score": self._submitted_rag_scores[i],
}
)
return sorted(results, key=lambda x: x["similarity"], reverse=True)
def calculate_similarity_score(self, vector: np.ndarray) -> tuple[float, float, float]:
"""
Рассчитывает скор на основе сходства с примерами.
Алгоритм:
1. Вычисляем косинусное сходство со всеми примерами
2. Используем топ-k ближайших примеров для более чувствительной оценки
3. Сравниваем топ-k положительных с топ-k отрицательными
Args:
vector: Векторное представление нового поста
Returns:
Tuple (score, confidence, score_pos_only):
- score: Оценка от 0.0 до 1.0 (neg/pos формула)
- confidence: Уверенность (зависит от количества примеров)
- score_pos_only: Оценка только по положительным примерам
Raises:
InsufficientExamplesError: Если недостаточно примеров
"""
with self._lock:
if self.positive_count == 0:
raise InsufficientExamplesError(
"Нет положительных примеров для сравнения"
)
raise InsufficientExamplesError("Нет положительных примеров для сравнения")
# Нормализуем входной вектор
normalized = self._normalize_vector(vector)
normalized = self._normalize_vector(np.asarray(vector).flatten())
# Конвертируем в numpy массивы для быстрых вычислений
pos_matrix = np.array(self._positive_vectors)
# Используем vstack для гарантии одинаковой формы (совместимость со старым npz)
pos_matrix = np.vstack([np.asarray(v).flatten() for v in self._positive_vectors])
# Косинусное сходство с положительными примерами
# Для нормализованных векторов это просто скалярное произведение
pos_similarities = np.dot(pos_matrix, normalized)
# Косинусное сходство с отрицательными примерами
if self.negative_count > 0:
neg_matrix = np.array(self._negative_vectors)
neg_matrix = np.vstack([np.asarray(v).flatten() for v in self._negative_vectors])
neg_similarities = np.dot(neg_matrix, normalized)
else:
neg_similarities = np.array([])
# Используем топ-k ближайших примеров для расчета среднего сходства
k_pos = min(self.k, len(pos_similarities))
top_k_pos = np.sort(pos_similarities)[-k_pos:]
avg_pos = float(np.mean(top_k_pos))
# Для отрицательных: если их меньше k, берем все, иначе топ-k
if len(neg_similarities) > 0:
k_neg = min(self.k, len(neg_similarities))
@@ -300,11 +419,11 @@ class VectorStore:
else:
# Если нет отрицательных примеров, используем нейтральное значение
avg_neg = avg_pos # Нейтральный скор = 0.5
# Формула расчета score: (diff * scale + 1) / 2, переводим из [-1, 1] в [0, 1]
diff = avg_pos - avg_neg
score_neg_pos = np.clip((diff * self.score_multiplier + 1) / 2, 0.0, 1.0)
# === Вариант 2: pos only (только положительные, топ-k ближайших) ===
# Берём топ-5 ближайших положительных примеров
top_5_k = min(5, len(pos_similarities))
@@ -312,20 +431,20 @@ class VectorStore:
# Нормализуем: 0.85 -> 0.0, 0.95 -> 1.0 (типичный диапазон для BERT)
score_pos_only = (top_5_sim - 0.85) / 0.10
score_pos_only = max(0.0, min(1.0, score_pos_only))
# Основной скор — neg/pos
score = score_neg_pos
# Confidence зависит от количества примеров (100% при 1000 примерах)
total_examples = self.positive_count + self.negative_count
confidence = min(1.0, total_examples / 1000)
# Дополнительная диагностическая информация
pos_mean = float(np.mean(pos_similarities))
pos_std = float(np.std(pos_similarities))
pos_min = float(np.min(pos_similarities))
pos_max = float(np.max(pos_similarities))
if len(neg_similarities) > 0:
neg_mean = float(np.mean(neg_similarities))
neg_std = float(np.std(neg_similarities))
@@ -333,7 +452,7 @@ class VectorStore:
neg_max = float(np.max(neg_similarities))
else:
neg_mean = neg_std = neg_min = neg_max = 0.0
logger.info(
f"VectorStore: k={self.k}, k_pos={k_pos}, k_neg={k_neg if len(neg_similarities) > 0 else 0}, "
f"avg_pos={avg_pos:.4f}, avg_neg={avg_neg:.4f}, "
@@ -342,58 +461,145 @@ class VectorStore:
f"pos_mean={pos_mean:.4f}±{pos_std:.4f}[{pos_min:.4f}-{pos_max:.4f}], "
f"neg_mean={neg_mean:.4f}±{neg_std:.4f}[{neg_min:.4f}-{neg_max:.4f}]"
)
return score, confidence, score_pos_only
def save_to_disk(self, path: Optional[str] = None) -> None:
def save_to_disk(self, path: str | None = None) -> None:
"""
Сохраняет векторы на диск.
Args:
path: Путь для сохранения (если не указан, используется storage_path)
"""
save_path = path or self.storage_path
if not save_path:
raise VectorStoreError("Путь для сохранения не указан")
with self._lock:
# Создаем директорию если нужно
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# Сохраняем в npz формате
np.savez_compressed(
save_path,
positive_vectors=np.array(self._positive_vectors) if self._positive_vectors else np.array([]),
negative_vectors=np.array(self._negative_vectors) if self._negative_vectors else np.array([]),
positive_vectors=np.array(self._positive_vectors)
if self._positive_vectors
else np.array([]),
negative_vectors=np.array(self._negative_vectors)
if self._negative_vectors
else np.array([]),
positive_hashes=np.array(self._positive_hashes, dtype=object),
negative_hashes=np.array(self._negative_hashes, dtype=object),
vector_dim=self.vector_dim,
max_examples=self.max_examples,
)
logger.info(
f"VectorStore: Сохранено на диск ({self.positive_count} pos, "
f"{self.negative_count} neg): {save_path}"
)
def save_submitted_to_disk(self, path: str | None = None) -> None:
"""
Сохраняет submitted-коллекцию на диск.
Args:
path: Путь для сохранения (если не указан, используется submitted_path)
"""
save_path = path or self.submitted_path
if not save_path:
raise VectorStoreError("Путь для сохранения submitted не указан")
with self._lock:
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(
save_path,
vectors=np.array(self._submitted_vectors)
if self._submitted_vectors
else np.array([]),
hashes=np.array(self._submitted_hashes, dtype=object),
created_at=np.array(self._submitted_created_at)
if self._submitted_created_at
else np.array([]),
post_ids=np.array(self._submitted_post_ids, dtype=object),
texts=np.array(self._submitted_texts, dtype=object),
rag_scores=np.array(self._submitted_rag_scores, dtype=object),
)
logger.info(f"VectorStore: Сохранено submitted ({self.submitted_count}): {save_path}")
def _load_submitted_from_disk(self) -> None:
"""Загружает submitted-коллекцию с диска."""
if not self.submitted_path or not os.path.exists(self.submitted_path):
return
try:
with self._lock:
data = np.load(self.submitted_path, allow_pickle=True)
vectors = data.get("vectors", np.array([]))
if vectors.size > 0:
if len(vectors.shape) == 2:
self._submitted_vectors = [
self._normalize_vector(np.array(v)) for v in vectors
]
elif len(vectors.shape) == 1:
self._submitted_vectors = [self._normalize_vector(np.array(vectors))]
else:
self._submitted_vectors = []
else:
self._submitted_vectors = []
hashes = data.get("hashes", np.array([]))
self._submitted_hashes = list(hashes) if hashes.size > 0 else []
created_at = data.get("created_at", np.array([]))
self._submitted_created_at = list(created_at) if created_at.size > 0 else []
post_ids = data.get("post_ids", np.array([]))
self._submitted_post_ids = list(post_ids) if post_ids.size > 0 else []
texts = data.get("texts", np.array([]))
self._submitted_texts = list(texts) if texts.size > 0 else []
rag_scores = data.get("rag_scores", np.array([]))
self._submitted_rag_scores = list(rag_scores) if rag_scores.size > 0 else []
# Выравниваем длины (на случай поврежденных данных)
n = len(self._submitted_vectors)
self._submitted_hashes = self._submitted_hashes[:n]
self._submitted_created_at = self._submitted_created_at[:n]
self._submitted_post_ids = self._submitted_post_ids[:n]
self._submitted_texts = self._submitted_texts[:n]
self._submitted_rag_scores = self._submitted_rag_scores[:n]
logger.info(
f"VectorStore: Загружено submitted ({self.submitted_count}): {self.submitted_path}"
)
except Exception as e:
logger.error(f"VectorStore: Ошибка загрузки submitted с диска: {e}")
def _load_from_disk(self) -> None:
"""Загружает векторы с диска."""
if not self.storage_path:
return
try:
with self._lock:
storage_dir = Path(self.storage_path).parent
positive_npy = storage_dir / "positive_embeddings.npy"
negative_npy = storage_dir / "negative_embeddings.npy"
# Отладочное логирование
logger.info(f"VectorStore: Проверка путей - storage_dir={storage_dir}, positive_npy={positive_npy}, exists={positive_npy.exists()}, negative_npy={negative_npy}, exists={negative_npy.exists()}")
logger.info(
f"VectorStore: Проверка путей - storage_dir={storage_dir}, positive_npy={positive_npy}, exists={positive_npy.exists()}, negative_npy={negative_npy}, exists={negative_npy.exists()}"
)
# Проверяем наличие отдельных .npy файлов
if positive_npy.exists() or negative_npy.exists():
logger.info("VectorStore: Обнаружены отдельные .npy файлы, загружаем их...")
# Загружаем положительные векторы
if positive_npy.exists():
pos_vectors = np.load(positive_npy, allow_pickle=False)
@@ -406,10 +612,14 @@ class VectorStore:
# Один вектор [dim]
self._positive_vectors = [pos_vectors]
else:
logger.warning(f"VectorStore: Неожиданная размерность positive_embeddings.npy: {pos_vectors.shape}")
logger.warning(
f"VectorStore: Неожиданная размерность positive_embeddings.npy: {pos_vectors.shape}"
)
self._positive_vectors = []
logger.info(f"VectorStore: Загружено {len(self._positive_vectors)} положительных векторов из {positive_npy}")
logger.info(
f"VectorStore: Загружено {len(self._positive_vectors)} положительных векторов из {positive_npy}"
)
# Загружаем отрицательные векторы
if negative_npy.exists():
neg_vectors = np.load(negative_npy, allow_pickle=False)
@@ -422,52 +632,62 @@ class VectorStore:
# Один вектор [dim]
self._negative_vectors = [neg_vectors]
else:
logger.warning(f"VectorStore: Неожиданная размерность negative_embeddings.npy: {neg_vectors.shape}")
logger.warning(
f"VectorStore: Неожиданная размерность negative_embeddings.npy: {neg_vectors.shape}"
)
self._negative_vectors = []
logger.info(f"VectorStore: Загружено {len(self._negative_vectors)} отрицательных векторов из {negative_npy}")
logger.info(
f"VectorStore: Загружено {len(self._negative_vectors)} отрицательных векторов из {negative_npy}"
)
# Нормализуем загруженные векторы
self._positive_vectors = [self._normalize_vector(np.array(v)) for v in self._positive_vectors]
self._negative_vectors = [self._normalize_vector(np.array(v)) for v in self._negative_vectors]
self._positive_vectors = [
self._normalize_vector(np.array(v)) for v in self._positive_vectors
]
self._negative_vectors = [
self._normalize_vector(np.array(v)) for v in self._negative_vectors
]
logger.info(
f"VectorStore: Загружено с диска из .npy файлов ({self.positive_count} pos, "
f"{self.negative_count} neg)"
)
return
# Если отдельных .npy файлов нет, пытаемся загрузить из старого формата .npz
if os.path.exists(self.storage_path):
logger.info(f"VectorStore: Загружаем из старого формата .npz: {self.storage_path}")
logger.info(
f"VectorStore: Загружаем из старого формата .npz: {self.storage_path}"
)
data = np.load(self.storage_path, allow_pickle=True)
# Загружаем векторы
pos_vectors = data.get('positive_vectors', np.array([]))
neg_vectors = data.get('negative_vectors', np.array([]))
pos_vectors = data.get("positive_vectors", np.array([]))
neg_vectors = data.get("negative_vectors", np.array([]))
if pos_vectors.size > 0:
self._positive_vectors = list(pos_vectors)
if neg_vectors.size > 0:
self._negative_vectors = list(neg_vectors)
# Загружаем хеши
pos_hashes = data.get('positive_hashes', np.array([]))
neg_hashes = data.get('negative_hashes', np.array([]))
pos_hashes = data.get("positive_hashes", np.array([]))
neg_hashes = data.get("negative_hashes", np.array([]))
if pos_hashes.size > 0:
self._positive_hashes = list(pos_hashes)
if neg_hashes.size > 0:
self._negative_hashes = list(neg_hashes)
logger.info(
f"VectorStore: Загружено с диска ({self.positive_count} pos, "
f"{self.negative_count} neg): {self.storage_path}"
)
except Exception as e:
logger.error(f"VectorStore: Ошибка загрузки с диска: {e}")
# Продолжаем с пустым хранилищем
def clear(self) -> None:
"""Очищает все векторы."""
with self._lock:
@@ -475,40 +695,48 @@ class VectorStore:
self._negative_vectors.clear()
self._positive_hashes.clear()
self._negative_hashes.clear()
self._submitted_vectors.clear()
self._submitted_hashes.clear()
self._submitted_created_at.clear()
self._submitted_post_ids.clear()
self._submitted_texts.clear()
self._submitted_rag_scores.clear()
logger.info("VectorStore: Хранилище очищено")
def get_stats(self) -> dict:
"""Возвращает статистику хранилища."""
return {
"positive_count": self.positive_count,
"negative_count": self.negative_count,
"total_count": self.total_count,
"submitted_count": self.submitted_count,
"vector_dim": self.vector_dim,
"max_examples": self.max_examples,
"max_submitted": self.max_submitted,
}
def get_scoring_params(self) -> dict:
"""Возвращает текущие параметры формулы расчета score."""
return {
"score_multiplier": self.score_multiplier,
"k": self.k,
}
def update_scoring_params(
self,
score_multiplier: Optional[float] = None,
k: Optional[int] = None,
score_multiplier: float | None = None,
k: int | None = None,
) -> dict:
"""
Обновляет параметры формулы расчета score.
Args:
score_multiplier: Множитель для масштабирования разницы (должен быть > 0)
k: Количество ближайших примеров для расчета среднего (должно быть >= 1)
Returns:
dict: Обновленные параметры
Raises:
ValueError: При невалидных значениях
"""
@@ -517,15 +745,15 @@ class VectorStore:
if score_multiplier <= 0:
raise ValueError("score_multiplier должен быть > 0")
self.score_multiplier = score_multiplier
if k is not None:
if k < 1:
raise ValueError("k должен быть >= 1")
self.k = k
logger.info(
f"VectorStore: Параметры формулы обновлены: "
f"score_multiplier={self.score_multiplier}, k={self.k}"
)
return self.get_scoring_params()

View File

@@ -20,6 +20,8 @@ services:
- RAG_CACHE_DIR=/app/data/models
- RAG_VECTORS_PATH=/app/data/vectors/vectors.npz
- RAG_MAX_EXAMPLES=${RAG_MAX_EXAMPLES:-10000}
- RAG_MAX_SUBMITTED=${RAG_MAX_SUBMITTED:-5000}
- RAG_SUBMITTED_PATH=/app/data/vectors/submitted.npz
- RAG_SCORE_MULTIPLIER=${RAG_SCORE_MULTIPLIER:-5.0}
- RAG_BATCH_SIZE=${RAG_BATCH_SIZE:-16}
- RAG_MIN_TEXT_LENGTH=${RAG_MIN_TEXT_LENGTH:-3}

View File

@@ -7,6 +7,8 @@ RAG_CACHE_DIR=data/models
# VectorStore
RAG_VECTORS_PATH=data/vectors/vectors.npz
RAG_MAX_EXAMPLES=10000
RAG_MAX_SUBMITTED=5000
RAG_SUBMITTED_PATH=data/vectors/submitted.npz
RAG_SCORE_MULTIPLIER=5.0
# Батч-обработка

16
tests/conftest.py Normal file
View File

@@ -0,0 +1,16 @@
"""
Pytest fixtures для RAG сервиса.
"""
import os
import pytest
@pytest.fixture(autouse=True)
def allow_no_auth():
"""Разрешает запросы без API ключа в тестах."""
os.environ["RAG_ALLOW_NO_AUTH"] = "true"
yield
if "RAG_ALLOW_NO_AUTH" in os.environ:
del os.environ["RAG_ALLOW_NO_AUTH"]

View File

@@ -0,0 +1,169 @@
"""
Тесты API endpoints /similar и /submitted.
"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
from app.main import app
@pytest.fixture
def mock_rag_service():
"""Mock RAGService для тестов API без загрузки модели."""
service = MagicMock()
service.is_model_loaded = True
service.vector_store.submitted_count = 0
return service
@pytest.fixture
def client(mock_rag_service):
"""TestClient с переопределённым RAG сервисом."""
from app.api.routes import get_service
def override_get_service():
return mock_rag_service
app.dependency_overrides[get_service] = override_get_service
# get_rag_service используется при создании сервиса - get_service вызывает get_rag_service
# Смотрю routes: get_service возвращает get_rag_service(). Значит override get_service достаточно.
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()
def test_similar_endpoint(client, mock_rag_service):
"""POST /api/v1/similar возвращает похожие посты."""
mock_rag_service.find_similar_posts = AsyncMock(
return_value=[
{
"similarity": 0.95,
"created_at": 1706270000,
"post_id": 123,
"text": "Похожий пост",
"rag_score": 0.85,
}
]
)
response = client.post(
"/api/v1/similar",
json={"text": "Текст для поиска", "threshold": 0.9, "hours": 24},
)
assert response.status_code == 200
data = response.json()
assert data["similar_count"] == 1
assert len(data["similar_posts"]) == 1
assert data["similar_posts"][0]["similarity"] == 0.95
assert data["similar_posts"][0]["post_id"] == 123
assert data["similar_posts"][0]["text"] == "Похожий пост"
assert data["similar_posts"][0]["rag_score"] == 0.85
def test_similar_endpoint_empty(client, mock_rag_service):
"""POST /api/v1/similar с пустым результатом."""
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
response = client.post(
"/api/v1/similar",
json={"text": "Уникальный текст", "threshold": 0.99, "hours": 1},
)
assert response.status_code == 200
assert response.json()["similar_count"] == 0
assert response.json()["similar_posts"] == []
def test_similar_endpoint_default_params(client, mock_rag_service):
"""POST /api/v1/similar с дефолтными параметрами."""
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
response = client.post(
"/api/v1/similar",
json={"text": "Текст"},
)
assert response.status_code == 200
mock_rag_service.find_similar_posts.assert_called_once_with(
text="Текст",
threshold=0.9,
hours=24,
)
def test_submitted_endpoint_success(client, mock_rag_service):
"""POST /api/v1/submitted успешно добавляет пост."""
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
mock_rag_service.vector_store.submitted_count = 5
response = client.post(
"/api/v1/submitted",
json={"text": "Новый пост", "post_id": 42, "rag_score": 0.8},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "добавлен" in data["message"].lower()
assert data["submitted_count"] == 5
mock_rag_service.add_submitted_post.assert_called_once_with(
text="Новый пост",
post_id=42,
rag_score=0.8,
)
def test_submitted_endpoint_duplicate(client, mock_rag_service):
"""POST /api/v1/submitted при дубликате возвращает success=False."""
mock_rag_service.add_submitted_post = AsyncMock(return_value=False)
mock_rag_service.vector_store.submitted_count = 10
response = client.post(
"/api/v1/submitted",
json={"text": "Дубликат поста"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["submitted_count"] == 10
def test_submitted_endpoint_optional_fields(client, mock_rag_service):
"""POST /api/v1/submitted с только текстом (post_id, rag_score опциональны)."""
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
mock_rag_service.vector_store.submitted_count = 1
response = client.post(
"/api/v1/submitted",
json={"text": "Только текст"},
)
assert response.status_code == 200
mock_rag_service.add_submitted_post.assert_called_once_with(
text="Только текст",
post_id=None,
rag_score=None,
)
def test_similar_validation_empty_text(client):
"""POST /api/v1/similar с пустым текстом возвращает 422."""
response = client.post(
"/api/v1/similar",
json={"text": "", "threshold": 0.9, "hours": 24},
)
assert response.status_code == 422
def test_submitted_validation_empty_text(client):
"""POST /api/v1/submitted с пустым текстом возвращает 422."""
response = client.post(
"/api/v1/submitted",
json={"text": ""},
)
assert response.status_code == 422

View File

@@ -0,0 +1,212 @@
"""
Тесты для submitted-коллекции VectorStore.
"""
import numpy as np
import pytest
from app.storage.vector_store import VectorStore
@pytest.fixture
def vector_store(tmp_path):
"""VectorStore с временным путём для submitted."""
return VectorStore(
vector_dim=4,
max_examples=10,
max_submitted=5,
storage_path=None,
submitted_path=str(tmp_path / "submitted.npz"),
)
@pytest.fixture
def sample_vector():
"""Нормализованный вектор для тестов."""
v = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
return v / np.linalg.norm(v)
def test_add_submitted(vector_store, sample_vector):
"""Добавление submitted-поста."""
added = vector_store.add_submitted(
vector=sample_vector,
text_hash="abc123",
created_at=1000,
post_id=42,
text="Test post",
rag_score=0.85,
)
assert added is True
assert vector_store.submitted_count == 1
def test_add_submitted_duplicate(vector_store, sample_vector):
"""Дубликат по хешу не добавляется."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="same_hash",
created_at=1000,
text="First",
)
added = vector_store.add_submitted(
vector=sample_vector,
text_hash="same_hash",
created_at=2000,
text="Second",
)
assert added is False
assert vector_store.submitted_count == 1
def test_add_submitted_fifo(vector_store, sample_vector):
"""При превышении max_submitted удаляется самый старый (FIFO)."""
for i in range(7):
v = np.array(
[float(i + 1), 0.0, 0.0, 0.0], dtype=np.float32
) # i+1 чтобы избежать нулевого вектора
v = v / np.linalg.norm(v)
vector_store.add_submitted(
vector=v,
text_hash=f"hash_{i}",
created_at=1000 + i,
post_id=i,
text=f"Post {i}",
)
assert vector_store.submitted_count == 5 # max_submitted
# Должны остаться посты 2, 3, 4, 5, 6 (удалены 0, 1)
post_ids = vector_store._submitted_post_ids
assert 0 not in post_ids
assert 1 not in post_ids
assert 2 in post_ids
def test_find_similar_submitted_empty(vector_store, sample_vector):
"""Поиск в пустой коллекции возвращает пустой список."""
result = vector_store.find_similar_submitted(
vector=sample_vector,
threshold=0.5,
hours=24,
)
assert result == []
def test_find_similar_submitted(vector_store, sample_vector):
"""Поиск похожих постов с фильтром по времени и threshold."""
import time
now = int(time.time())
# Похожий вектор
similar_v = np.array([0.99, 0.01, 0.0, 0.0], dtype=np.float32)
similar_v = similar_v / np.linalg.norm(similar_v)
# Непохожий вектор
different_v = np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)
different_v = different_v / np.linalg.norm(different_v)
vector_store.add_submitted(
vector=similar_v,
text_hash="similar",
created_at=now - 3600, # 1 час назад
post_id=1,
text="Similar post",
rag_score=0.9,
)
vector_store.add_submitted(
vector=different_v,
text_hash="different",
created_at=now - 3600,
post_id=2,
text="Different post",
rag_score=0.5,
)
result = vector_store.find_similar_submitted(
vector=sample_vector,
threshold=0.9,
hours=24,
)
assert len(result) == 1
assert result[0]["post_id"] == 1
assert result[0]["text"] == "Similar post"
assert result[0]["similarity"] >= 0.9
def test_find_similar_submitted_time_filter(vector_store, sample_vector):
"""Фильтр по hours исключает старые посты."""
import time
now = int(time.time())
vector_store.add_submitted(
vector=sample_vector,
text_hash="old",
created_at=now - 48 * 3600, # 48 часов назад
post_id=1,
text="Old post",
)
vector_store.add_submitted(
vector=sample_vector,
text_hash="recent",
created_at=now - 3600, # 1 час назад
post_id=2,
text="Recent post",
)
result = vector_store.find_similar_submitted(
vector=sample_vector,
threshold=0.5,
hours=24,
)
assert len(result) == 1
assert result[0]["post_id"] == 2
def test_submitted_persistence(vector_store, sample_vector, tmp_path):
"""Сохранение и загрузка submitted-коллекции."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="persist",
created_at=12345,
post_id=999,
text="Persisted post",
rag_score=0.77,
)
vector_store.save_submitted_to_disk()
# Новый store загружает данные
store2 = VectorStore(
vector_dim=4,
max_submitted=5,
storage_path=None,
submitted_path=str(tmp_path / "submitted.npz"),
)
assert store2.submitted_count == 1
assert store2._submitted_post_ids[0] == 999
assert store2._submitted_texts[0] == "Persisted post"
assert store2._submitted_rag_scores[0] == 0.77
def test_get_stats_includes_submitted(vector_store, sample_vector):
"""get_stats включает submitted_count и max_submitted."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="stat",
created_at=1000,
text="For stats",
)
stats = vector_store.get_stats()
assert "submitted_count" in stats
assert stats["submitted_count"] == 1
assert "max_submitted" in stats
assert stats["max_submitted"] == 5
def test_clear_submitted(vector_store, sample_vector):
"""clear() очищает submitted-коллекцию."""
vector_store.add_submitted(
vector=sample_vector,
text_hash="clear",
created_at=1000,
text="To clear",
)
vector_store.clear()
assert vector_store.submitted_count == 0