feat: add submitted collection, /similar and /submitted endpoints (Stage 4)
Made-with: Cursor
This commit is contained in:
16
tests/conftest.py
Normal file
16
tests/conftest.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Pytest fixtures для RAG сервиса.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def allow_no_auth():
|
||||
"""Разрешает запросы без API ключа в тестах."""
|
||||
os.environ["RAG_ALLOW_NO_AUTH"] = "true"
|
||||
yield
|
||||
if "RAG_ALLOW_NO_AUTH" in os.environ:
|
||||
del os.environ["RAG_ALLOW_NO_AUTH"]
|
||||
169
tests/test_api_similar_submitted.py
Normal file
169
tests/test_api_similar_submitted.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Тесты API endpoints /similar и /submitted.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_service():
|
||||
"""Mock RAGService для тестов API без загрузки модели."""
|
||||
service = MagicMock()
|
||||
service.is_model_loaded = True
|
||||
service.vector_store.submitted_count = 0
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_rag_service):
|
||||
"""TestClient с переопределённым RAG сервисом."""
|
||||
from app.api.routes import get_service
|
||||
|
||||
def override_get_service():
|
||||
return mock_rag_service
|
||||
|
||||
app.dependency_overrides[get_service] = override_get_service
|
||||
# get_rag_service используется при создании сервиса - get_service вызывает get_rag_service
|
||||
# Смотрю routes: get_service возвращает get_rag_service(). Значит override get_service достаточно.
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def test_similar_endpoint(client, mock_rag_service):
|
||||
"""POST /api/v1/similar возвращает похожие посты."""
|
||||
mock_rag_service.find_similar_posts = AsyncMock(
|
||||
return_value=[
|
||||
{
|
||||
"similarity": 0.95,
|
||||
"created_at": 1706270000,
|
||||
"post_id": 123,
|
||||
"text": "Похожий пост",
|
||||
"rag_score": 0.85,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/similar",
|
||||
json={"text": "Текст для поиска", "threshold": 0.9, "hours": 24},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["similar_count"] == 1
|
||||
assert len(data["similar_posts"]) == 1
|
||||
assert data["similar_posts"][0]["similarity"] == 0.95
|
||||
assert data["similar_posts"][0]["post_id"] == 123
|
||||
assert data["similar_posts"][0]["text"] == "Похожий пост"
|
||||
assert data["similar_posts"][0]["rag_score"] == 0.85
|
||||
|
||||
|
||||
def test_similar_endpoint_empty(client, mock_rag_service):
|
||||
"""POST /api/v1/similar с пустым результатом."""
|
||||
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/similar",
|
||||
json={"text": "Уникальный текст", "threshold": 0.99, "hours": 1},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["similar_count"] == 0
|
||||
assert response.json()["similar_posts"] == []
|
||||
|
||||
|
||||
def test_similar_endpoint_default_params(client, mock_rag_service):
|
||||
"""POST /api/v1/similar с дефолтными параметрами."""
|
||||
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/similar",
|
||||
json={"text": "Текст"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_rag_service.find_similar_posts.assert_called_once_with(
|
||||
text="Текст",
|
||||
threshold=0.9,
|
||||
hours=24,
|
||||
)
|
||||
|
||||
|
||||
def test_submitted_endpoint_success(client, mock_rag_service):
|
||||
"""POST /api/v1/submitted успешно добавляет пост."""
|
||||
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
|
||||
mock_rag_service.vector_store.submitted_count = 5
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/submitted",
|
||||
json={"text": "Новый пост", "post_id": 42, "rag_score": 0.8},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "добавлен" in data["message"].lower()
|
||||
assert data["submitted_count"] == 5
|
||||
mock_rag_service.add_submitted_post.assert_called_once_with(
|
||||
text="Новый пост",
|
||||
post_id=42,
|
||||
rag_score=0.8,
|
||||
)
|
||||
|
||||
|
||||
def test_submitted_endpoint_duplicate(client, mock_rag_service):
|
||||
"""POST /api/v1/submitted при дубликате возвращает success=False."""
|
||||
mock_rag_service.add_submitted_post = AsyncMock(return_value=False)
|
||||
mock_rag_service.vector_store.submitted_count = 10
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/submitted",
|
||||
json={"text": "Дубликат поста"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["submitted_count"] == 10
|
||||
|
||||
|
||||
def test_submitted_endpoint_optional_fields(client, mock_rag_service):
|
||||
"""POST /api/v1/submitted с только текстом (post_id, rag_score опциональны)."""
|
||||
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
|
||||
mock_rag_service.vector_store.submitted_count = 1
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/submitted",
|
||||
json={"text": "Только текст"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_rag_service.add_submitted_post.assert_called_once_with(
|
||||
text="Только текст",
|
||||
post_id=None,
|
||||
rag_score=None,
|
||||
)
|
||||
|
||||
|
||||
def test_similar_validation_empty_text(client):
|
||||
"""POST /api/v1/similar с пустым текстом возвращает 422."""
|
||||
response = client.post(
|
||||
"/api/v1/similar",
|
||||
json={"text": "", "threshold": 0.9, "hours": 24},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
def test_submitted_validation_empty_text(client):
|
||||
"""POST /api/v1/submitted с пустым текстом возвращает 422."""
|
||||
response = client.post(
|
||||
"/api/v1/submitted",
|
||||
json={"text": ""},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
212
tests/test_vector_store_submitted.py
Normal file
212
tests/test_vector_store_submitted.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Тесты для submitted-коллекции VectorStore.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.storage.vector_store import VectorStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_store(tmp_path):
|
||||
"""VectorStore с временным путём для submitted."""
|
||||
return VectorStore(
|
||||
vector_dim=4,
|
||||
max_examples=10,
|
||||
max_submitted=5,
|
||||
storage_path=None,
|
||||
submitted_path=str(tmp_path / "submitted.npz"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_vector():
|
||||
"""Нормализованный вектор для тестов."""
|
||||
v = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
|
||||
return v / np.linalg.norm(v)
|
||||
|
||||
|
||||
def test_add_submitted(vector_store, sample_vector):
|
||||
"""Добавление submitted-поста."""
|
||||
added = vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="abc123",
|
||||
created_at=1000,
|
||||
post_id=42,
|
||||
text="Test post",
|
||||
rag_score=0.85,
|
||||
)
|
||||
assert added is True
|
||||
assert vector_store.submitted_count == 1
|
||||
|
||||
|
||||
def test_add_submitted_duplicate(vector_store, sample_vector):
|
||||
"""Дубликат по хешу не добавляется."""
|
||||
vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="same_hash",
|
||||
created_at=1000,
|
||||
text="First",
|
||||
)
|
||||
added = vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="same_hash",
|
||||
created_at=2000,
|
||||
text="Second",
|
||||
)
|
||||
assert added is False
|
||||
assert vector_store.submitted_count == 1
|
||||
|
||||
|
||||
def test_add_submitted_fifo(vector_store, sample_vector):
|
||||
"""При превышении max_submitted удаляется самый старый (FIFO)."""
|
||||
for i in range(7):
|
||||
v = np.array(
|
||||
[float(i + 1), 0.0, 0.0, 0.0], dtype=np.float32
|
||||
) # i+1 чтобы избежать нулевого вектора
|
||||
v = v / np.linalg.norm(v)
|
||||
vector_store.add_submitted(
|
||||
vector=v,
|
||||
text_hash=f"hash_{i}",
|
||||
created_at=1000 + i,
|
||||
post_id=i,
|
||||
text=f"Post {i}",
|
||||
)
|
||||
assert vector_store.submitted_count == 5 # max_submitted
|
||||
# Должны остаться посты 2, 3, 4, 5, 6 (удалены 0, 1)
|
||||
post_ids = vector_store._submitted_post_ids
|
||||
assert 0 not in post_ids
|
||||
assert 1 not in post_ids
|
||||
assert 2 in post_ids
|
||||
|
||||
|
||||
def test_find_similar_submitted_empty(vector_store, sample_vector):
|
||||
"""Поиск в пустой коллекции возвращает пустой список."""
|
||||
result = vector_store.find_similar_submitted(
|
||||
vector=sample_vector,
|
||||
threshold=0.5,
|
||||
hours=24,
|
||||
)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_find_similar_submitted(vector_store, sample_vector):
|
||||
"""Поиск похожих постов с фильтром по времени и threshold."""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
# Похожий вектор
|
||||
similar_v = np.array([0.99, 0.01, 0.0, 0.0], dtype=np.float32)
|
||||
similar_v = similar_v / np.linalg.norm(similar_v)
|
||||
# Непохожий вектор
|
||||
different_v = np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)
|
||||
different_v = different_v / np.linalg.norm(different_v)
|
||||
|
||||
vector_store.add_submitted(
|
||||
vector=similar_v,
|
||||
text_hash="similar",
|
||||
created_at=now - 3600, # 1 час назад
|
||||
post_id=1,
|
||||
text="Similar post",
|
||||
rag_score=0.9,
|
||||
)
|
||||
vector_store.add_submitted(
|
||||
vector=different_v,
|
||||
text_hash="different",
|
||||
created_at=now - 3600,
|
||||
post_id=2,
|
||||
text="Different post",
|
||||
rag_score=0.5,
|
||||
)
|
||||
|
||||
result = vector_store.find_similar_submitted(
|
||||
vector=sample_vector,
|
||||
threshold=0.9,
|
||||
hours=24,
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0]["post_id"] == 1
|
||||
assert result[0]["text"] == "Similar post"
|
||||
assert result[0]["similarity"] >= 0.9
|
||||
|
||||
|
||||
def test_find_similar_submitted_time_filter(vector_store, sample_vector):
|
||||
"""Фильтр по hours исключает старые посты."""
|
||||
import time
|
||||
|
||||
now = int(time.time())
|
||||
vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="old",
|
||||
created_at=now - 48 * 3600, # 48 часов назад
|
||||
post_id=1,
|
||||
text="Old post",
|
||||
)
|
||||
vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="recent",
|
||||
created_at=now - 3600, # 1 час назад
|
||||
post_id=2,
|
||||
text="Recent post",
|
||||
)
|
||||
|
||||
result = vector_store.find_similar_submitted(
|
||||
vector=sample_vector,
|
||||
threshold=0.5,
|
||||
hours=24,
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0]["post_id"] == 2
|
||||
|
||||
|
||||
def test_submitted_persistence(vector_store, sample_vector, tmp_path):
|
||||
"""Сохранение и загрузка submitted-коллекции."""
|
||||
vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="persist",
|
||||
created_at=12345,
|
||||
post_id=999,
|
||||
text="Persisted post",
|
||||
rag_score=0.77,
|
||||
)
|
||||
vector_store.save_submitted_to_disk()
|
||||
|
||||
# Новый store загружает данные
|
||||
store2 = VectorStore(
|
||||
vector_dim=4,
|
||||
max_submitted=5,
|
||||
storage_path=None,
|
||||
submitted_path=str(tmp_path / "submitted.npz"),
|
||||
)
|
||||
assert store2.submitted_count == 1
|
||||
assert store2._submitted_post_ids[0] == 999
|
||||
assert store2._submitted_texts[0] == "Persisted post"
|
||||
assert store2._submitted_rag_scores[0] == 0.77
|
||||
|
||||
|
||||
def test_get_stats_includes_submitted(vector_store, sample_vector):
|
||||
"""get_stats включает submitted_count и max_submitted."""
|
||||
vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="stat",
|
||||
created_at=1000,
|
||||
text="For stats",
|
||||
)
|
||||
stats = vector_store.get_stats()
|
||||
assert "submitted_count" in stats
|
||||
assert stats["submitted_count"] == 1
|
||||
assert "max_submitted" in stats
|
||||
assert stats["max_submitted"] == 5
|
||||
|
||||
|
||||
def test_clear_submitted(vector_store, sample_vector):
|
||||
"""clear() очищает submitted-коллекцию."""
|
||||
vector_store.add_submitted(
|
||||
vector=sample_vector,
|
||||
text_hash="clear",
|
||||
created_at=1000,
|
||||
text="To clear",
|
||||
)
|
||||
vector_store.clear()
|
||||
assert vector_store.submitted_count == 0
|
||||
Reference in New Issue
Block a user