feat: add submitted collection, /similar and /submitted endpoints (Stage 4) #1
2
.gitignore
vendored
2
.gitignore
vendored
@@ -133,6 +133,8 @@ Thumbs.db
|
|||||||
# Project specific
|
# Project specific
|
||||||
data/models/
|
data/models/
|
||||||
data/vectors/*.npz
|
data/vectors/*.npz
|
||||||
|
*.bak
|
||||||
|
*.tar.gz
|
||||||
|
|
||||||
# Keep data directories
|
# Keep data directories
|
||||||
!data/models/.gitkeep
|
!data/models/.gitkeep
|
||||||
|
|||||||
@@ -14,9 +14,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
# Копируем зависимости
|
# Копируем зависимости
|
||||||
COPY requirements.txt .
|
COPY requirements.txt .
|
||||||
|
|
||||||
# Устанавливаем зависимости
|
# Устанавливаем зависимости (CPU-only torch для контейнеров без GPU)
|
||||||
# --no-cache-dir для уменьшения размера образа
|
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu \
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
&& pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
# Копируем код приложения
|
# Копируем код приложения
|
||||||
COPY app/ ./app/
|
COPY app/ ./app/
|
||||||
|
|||||||
@@ -24,14 +24,14 @@ async def verify_api_key(
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Проверяет API ключ из заголовка запроса.
|
Проверяет API ключ из заголовка запроса.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: Ключ из заголовка X-API-Key
|
api_key: Ключ из заголовка X-API-Key
|
||||||
settings: Настройки приложения
|
settings: Настройки приложения
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True если авторизация успешна
|
True если авторизация успешна
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: Если ключ неверный или отсутствует
|
HTTPException: Если ключ неверный или отсутствует
|
||||||
"""
|
"""
|
||||||
@@ -47,7 +47,7 @@ async def verify_api_key(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="API ключ не настроен на сервере",
|
detail="API ключ не настроен на сервере",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Проверяем ключ
|
# Проверяем ключ
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
logger.warning("Запрос без API ключа")
|
logger.warning("Запрос без API ключа")
|
||||||
@@ -56,14 +56,14 @@ async def verify_api_key(
|
|||||||
detail="API ключ не предоставлен. Используйте заголовок X-API-Key",
|
detail="API ключ не предоставлен. Используйте заголовок X-API-Key",
|
||||||
headers={"WWW-Authenticate": "ApiKey"},
|
headers={"WWW-Authenticate": "ApiKey"},
|
||||||
)
|
)
|
||||||
|
|
||||||
if api_key != settings.api_key:
|
if api_key != settings.api_key:
|
||||||
logger.warning("Неверный API ключ")
|
logger.warning("Неверный API ключ")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="Неверный API ключ",
|
detail="Неверный API ключ",
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,12 @@ from app.schemas import (
|
|||||||
ScoreRequest,
|
ScoreRequest,
|
||||||
ScoreResponse,
|
ScoreResponse,
|
||||||
ScoringParamsResponse,
|
ScoringParamsResponse,
|
||||||
|
SimilarPostItem,
|
||||||
|
SimilarRequest,
|
||||||
|
SimilarResponse,
|
||||||
StatsResponse,
|
StatsResponse,
|
||||||
|
SubmittedRequest,
|
||||||
|
SubmittedResponse,
|
||||||
UpdateScoringParamsRequest,
|
UpdateScoringParamsRequest,
|
||||||
VectorStoreStats,
|
VectorStoreStats,
|
||||||
WarmupResponse,
|
WarmupResponse,
|
||||||
@@ -49,6 +54,7 @@ RAGServiceDep = Annotated[RAGService, Depends(get_service)]
|
|||||||
# Health Check
|
# Health Check
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/health",
|
"/health",
|
||||||
response_model=HealthResponse,
|
response_model=HealthResponse,
|
||||||
@@ -58,7 +64,7 @@ RAGServiceDep = Annotated[RAGService, Depends(get_service)]
|
|||||||
async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse:
|
async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse:
|
||||||
"""
|
"""
|
||||||
Проверяет состояние сервиса.
|
Проверяет состояние сервиса.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
HealthResponse: Статус сервиса
|
HealthResponse: Статус сервиса
|
||||||
"""
|
"""
|
||||||
@@ -73,6 +79,7 @@ async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse
|
|||||||
# Scoring
|
# Scoring
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/score",
|
"/score",
|
||||||
response_model=ScoreResponse,
|
response_model=ScoreResponse,
|
||||||
@@ -86,55 +93,55 @@ async def health_check(service: RAGServiceDep, _auth: AuthDep) -> HealthResponse
|
|||||||
tags=["scoring"],
|
tags=["scoring"],
|
||||||
)
|
)
|
||||||
async def calculate_score(
|
async def calculate_score(
|
||||||
request: ScoreRequest,
|
request: ScoreRequest,
|
||||||
service: RAGServiceDep,
|
service: RAGServiceDep,
|
||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
) -> ScoreResponse:
|
) -> ScoreResponse:
|
||||||
"""
|
"""
|
||||||
Рассчитывает скор для текста поста.
|
Рассчитывает скор для текста поста.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Запрос с текстом
|
request: Запрос с текстом
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ScoreResponse: Результат скоринга
|
ScoreResponse: Результат скоринга
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: При ошибке расчета
|
HTTPException: При ошибке расчета
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = await service.calculate_score(request.text)
|
result = await service.calculate_score(request.text)
|
||||||
response_dict = result.to_dict()
|
response_dict = result.to_dict()
|
||||||
|
|
||||||
return ScoreResponse(
|
return ScoreResponse(
|
||||||
rag_score=response_dict["rag_score"],
|
rag_score=response_dict["rag_score"],
|
||||||
rag_confidence=response_dict["rag_confidence"],
|
rag_confidence=response_dict["rag_confidence"],
|
||||||
rag_score_pos_only=response_dict["rag_score_pos_only"],
|
rag_score_pos_only=response_dict["rag_score_pos_only"],
|
||||||
meta=ScoreMetadata(**response_dict["meta"]),
|
meta=ScoreMetadata(**response_dict["meta"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
except TextTooShortError as e:
|
except TextTooShortError as e:
|
||||||
logger.warning(f"Текст слишком короткий: {e}")
|
logger.warning(f"Текст слишком короткий: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail={"detail": str(e), "error_type": "TextTooShortError"},
|
detail={"detail": str(e), "error_type": "TextTooShortError"},
|
||||||
)
|
)
|
||||||
|
|
||||||
except InsufficientExamplesError as e:
|
except InsufficientExamplesError as e:
|
||||||
logger.warning(f"Недостаточно примеров: {e}")
|
logger.warning(f"Недостаточно примеров: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail={"detail": str(e), "error_type": "InsufficientExamplesError"},
|
detail={"detail": str(e), "error_type": "InsufficientExamplesError"},
|
||||||
)
|
)
|
||||||
|
|
||||||
except ModelNotLoadedError as e:
|
except ModelNotLoadedError as e:
|
||||||
logger.error(f"Модель не загружена: {e}")
|
logger.error(f"Модель не загружена: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
|
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
|
||||||
)
|
)
|
||||||
|
|
||||||
except ScoringError as e:
|
except ScoringError as e:
|
||||||
logger.error(f"Ошибка скоринга: {e}")
|
logger.error(f"Ошибка скоринга: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -147,6 +154,7 @@ async def calculate_score(
|
|||||||
# Examples
|
# Examples
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/examples/positive",
|
"/examples/positive",
|
||||||
response_model=ExampleResponse,
|
response_model=ExampleResponse,
|
||||||
@@ -159,27 +167,27 @@ async def calculate_score(
|
|||||||
tags=["examples"],
|
tags=["examples"],
|
||||||
)
|
)
|
||||||
async def add_positive_example(
|
async def add_positive_example(
|
||||||
request: ExampleRequest,
|
request: ExampleRequest,
|
||||||
service: RAGServiceDep,
|
service: RAGServiceDep,
|
||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
x_test_mode: str | None = Header(default=None, alias="X-Test-Mode"),
|
x_test_mode: str | None = Header(default=None, alias="X-Test-Mode"),
|
||||||
) -> ExampleResponse:
|
) -> ExampleResponse:
|
||||||
"""
|
"""
|
||||||
Добавляет текст как положительный пример (опубликованный пост).
|
Добавляет текст как положительный пример (опубликованный пост).
|
||||||
|
|
||||||
При наличии заголовка X-Test-Mode: true пример НЕ сохраняется (тестовый режим).
|
При наличии заголовка X-Test-Mode: true пример НЕ сохраняется (тестовый режим).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Запрос с текстом
|
request: Запрос с текстом
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
x_test_mode: Заголовок тестового режима
|
x_test_mode: Заголовок тестового режима
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ExampleResponse: Результат добавления
|
ExampleResponse: Результат добавления
|
||||||
"""
|
"""
|
||||||
# Тестовый режим — не сохраняем примеры
|
# Тестовый режим — не сохраняем примеры
|
||||||
is_test = x_test_mode and x_test_mode.lower() == "true"
|
is_test = x_test_mode and x_test_mode.lower() == "true"
|
||||||
|
|
||||||
if is_test:
|
if is_test:
|
||||||
logger.info("Тестовый режим: положительный пример НЕ сохранён")
|
logger.info("Тестовый режим: положительный пример НЕ сохранён")
|
||||||
return ExampleResponse(
|
return ExampleResponse(
|
||||||
@@ -188,22 +196,22 @@ async def add_positive_example(
|
|||||||
positive_count=service.vector_store.positive_count,
|
positive_count=service.vector_store.positive_count,
|
||||||
negative_count=service.vector_store.negative_count,
|
negative_count=service.vector_store.negative_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
added = await service.add_positive_example(request.text)
|
added = await service.add_positive_example(request.text)
|
||||||
|
|
||||||
if added:
|
if added:
|
||||||
message = "Положительный пример добавлен"
|
message = "Положительный пример добавлен"
|
||||||
else:
|
else:
|
||||||
message = "Пример не добавлен (дубликат или слишком короткий текст)"
|
message = "Пример не добавлен (дубликат или слишком короткий текст)"
|
||||||
|
|
||||||
return ExampleResponse(
|
return ExampleResponse(
|
||||||
success=added,
|
success=added,
|
||||||
message=message,
|
message=message,
|
||||||
positive_count=service.vector_store.positive_count,
|
positive_count=service.vector_store.positive_count,
|
||||||
negative_count=service.vector_store.negative_count,
|
negative_count=service.vector_store.negative_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
except ModelNotLoadedError as e:
|
except ModelNotLoadedError as e:
|
||||||
logger.error(f"Модель не загружена: {e}")
|
logger.error(f"Модель не загружена: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -224,27 +232,27 @@ async def add_positive_example(
|
|||||||
tags=["examples"],
|
tags=["examples"],
|
||||||
)
|
)
|
||||||
async def add_negative_example(
|
async def add_negative_example(
|
||||||
request: ExampleRequest,
|
request: ExampleRequest,
|
||||||
service: RAGServiceDep,
|
service: RAGServiceDep,
|
||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
x_test_mode: str | None = Header(default=None, alias="X-Test-Mode"),
|
x_test_mode: str | None = Header(default=None, alias="X-Test-Mode"),
|
||||||
) -> ExampleResponse:
|
) -> ExampleResponse:
|
||||||
"""
|
"""
|
||||||
Добавляет текст как отрицательный пример (отклоненный пост).
|
Добавляет текст как отрицательный пример (отклоненный пост).
|
||||||
|
|
||||||
При наличии заголовка X-Test-Mode: true пример НЕ сохраняется (тестовый режим).
|
При наличии заголовка X-Test-Mode: true пример НЕ сохраняется (тестовый режим).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Запрос с текстом
|
request: Запрос с текстом
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
x_test_mode: Заголовок тестового режима
|
x_test_mode: Заголовок тестового режима
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ExampleResponse: Результат добавления
|
ExampleResponse: Результат добавления
|
||||||
"""
|
"""
|
||||||
# Тестовый режим — не сохраняем примеры
|
# Тестовый режим — не сохраняем примеры
|
||||||
is_test = x_test_mode and x_test_mode.lower() == "true"
|
is_test = x_test_mode and x_test_mode.lower() == "true"
|
||||||
|
|
||||||
if is_test:
|
if is_test:
|
||||||
logger.info("Тестовый режим: отрицательный пример НЕ сохранён")
|
logger.info("Тестовый режим: отрицательный пример НЕ сохранён")
|
||||||
return ExampleResponse(
|
return ExampleResponse(
|
||||||
@@ -253,22 +261,128 @@ async def add_negative_example(
|
|||||||
positive_count=service.vector_store.positive_count,
|
positive_count=service.vector_store.positive_count,
|
||||||
negative_count=service.vector_store.negative_count,
|
negative_count=service.vector_store.negative_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
added = await service.add_negative_example(request.text)
|
added = await service.add_negative_example(request.text)
|
||||||
|
|
||||||
if added:
|
if added:
|
||||||
message = "Отрицательный пример добавлен"
|
message = "Отрицательный пример добавлен"
|
||||||
else:
|
else:
|
||||||
message = "Пример не добавлен (дубликат или слишком короткий текст)"
|
message = "Пример не добавлен (дубликат или слишком короткий текст)"
|
||||||
|
|
||||||
return ExampleResponse(
|
return ExampleResponse(
|
||||||
success=added,
|
success=added,
|
||||||
message=message,
|
message=message,
|
||||||
positive_count=service.vector_store.positive_count,
|
positive_count=service.vector_store.positive_count,
|
||||||
negative_count=service.vector_store.negative_count,
|
negative_count=service.vector_store.negative_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except ModelNotLoadedError as e:
|
||||||
|
logger.error(f"Модель не загружена: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Similar & Submitted
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/similar",
|
||||||
|
response_model=SimilarResponse,
|
||||||
|
responses={
|
||||||
|
400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
|
||||||
|
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||||
|
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||||
|
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
|
||||||
|
},
|
||||||
|
summary="Поиск похожих постов",
|
||||||
|
tags=["similar"],
|
||||||
|
)
|
||||||
|
async def find_similar_posts(
|
||||||
|
request: SimilarRequest,
|
||||||
|
service: RAGServiceDep,
|
||||||
|
_auth: AuthDep,
|
||||||
|
) -> SimilarResponse:
|
||||||
|
"""
|
||||||
|
Ищет похожие submitted-посты за последние N часов.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Запрос с текстом, threshold и hours
|
||||||
|
service: RAG сервис
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SimilarResponse: Список похожих постов
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
similar = await service.find_similar_posts(
|
||||||
|
text=request.text,
|
||||||
|
threshold=request.threshold,
|
||||||
|
hours=request.hours,
|
||||||
|
)
|
||||||
|
return SimilarResponse(
|
||||||
|
similar_count=len(similar),
|
||||||
|
similar_posts=[SimilarPostItem(**item) for item in similar],
|
||||||
|
)
|
||||||
|
except TextTooShortError as e:
|
||||||
|
logger.warning(f"Текст слишком короткий: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={"detail": str(e), "error_type": "TextTooShortError"},
|
||||||
|
)
|
||||||
|
except ModelNotLoadedError as e:
|
||||||
|
logger.error(f"Модель не загружена: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/submitted",
|
||||||
|
response_model=SubmittedResponse,
|
||||||
|
responses={
|
||||||
|
400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
|
||||||
|
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||||
|
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||||
|
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
|
||||||
|
},
|
||||||
|
summary="Добавить submitted-пост",
|
||||||
|
tags=["submitted"],
|
||||||
|
)
|
||||||
|
async def add_submitted_post(
|
||||||
|
request: SubmittedRequest,
|
||||||
|
service: RAGServiceDep,
|
||||||
|
_auth: AuthDep,
|
||||||
|
) -> SubmittedResponse:
|
||||||
|
"""
|
||||||
|
Добавляет submitted-пост в коллекцию для индексации ботом.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Запрос с текстом, post_id и rag_score
|
||||||
|
service: RAG сервис
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SubmittedResponse: Результат добавления
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
added = await service.add_submitted_post(
|
||||||
|
text=request.text,
|
||||||
|
post_id=request.post_id,
|
||||||
|
rag_score=request.rag_score,
|
||||||
|
)
|
||||||
|
if added:
|
||||||
|
message = "Submitted-пост добавлен"
|
||||||
|
else:
|
||||||
|
message = "Пост не добавлен (дубликат или слишком короткий текст)"
|
||||||
|
return SubmittedResponse(
|
||||||
|
success=added,
|
||||||
|
message=message,
|
||||||
|
submitted_count=service.vector_store.submitted_count,
|
||||||
|
)
|
||||||
except ModelNotLoadedError as e:
|
except ModelNotLoadedError as e:
|
||||||
logger.error(f"Модель не загружена: {e}")
|
logger.error(f"Модель не загружена: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -281,6 +395,7 @@ async def add_negative_example(
|
|||||||
# Stats & Warmup
|
# Stats & Warmup
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/stats",
|
"/stats",
|
||||||
response_model=StatsResponse,
|
response_model=StatsResponse,
|
||||||
@@ -294,15 +409,15 @@ async def add_negative_example(
|
|||||||
async def get_stats(service: RAGServiceDep, _auth: AuthDep) -> StatsResponse:
|
async def get_stats(service: RAGServiceDep, _auth: AuthDep) -> StatsResponse:
|
||||||
"""
|
"""
|
||||||
Возвращает статистику сервиса.
|
Возвращает статистику сервиса.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StatsResponse: Статистика
|
StatsResponse: Статистика
|
||||||
"""
|
"""
|
||||||
stats = service.get_stats()
|
stats = service.get_stats()
|
||||||
|
|
||||||
return StatsResponse(
|
return StatsResponse(
|
||||||
model_name=stats["model_name"],
|
model_name=stats["model_name"],
|
||||||
model_loaded=stats["model_loaded"],
|
model_loaded=stats["model_loaded"],
|
||||||
@@ -325,15 +440,15 @@ async def get_stats(service: RAGServiceDep, _auth: AuthDep) -> StatsResponse:
|
|||||||
async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
|
async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
|
||||||
"""
|
"""
|
||||||
Прогревает модель (загружает если не загружена).
|
Прогревает модель (загружает если не загружена).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
WarmupResponse: Результат прогрева
|
WarmupResponse: Результат прогрева
|
||||||
"""
|
"""
|
||||||
success = await service.warmup()
|
success = await service.warmup()
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
message = "Модель успешно загружена"
|
message = "Модель успешно загружена"
|
||||||
else:
|
else:
|
||||||
@@ -342,7 +457,7 @@ async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
|
|||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail={"detail": message, "error_type": "ModelNotLoadedError"},
|
detail={"detail": message, "error_type": "ModelNotLoadedError"},
|
||||||
)
|
)
|
||||||
|
|
||||||
return WarmupResponse(
|
return WarmupResponse(
|
||||||
success=success,
|
success=success,
|
||||||
model_loaded=service.is_model_loaded,
|
model_loaded=service.is_model_loaded,
|
||||||
@@ -354,6 +469,7 @@ async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
|
|||||||
# Scoring Parameters
|
# Scoring Parameters
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/scoring/params",
|
"/scoring/params",
|
||||||
response_model=ScoringParamsResponse,
|
response_model=ScoringParamsResponse,
|
||||||
@@ -370,10 +486,10 @@ async def get_scoring_params(
|
|||||||
) -> ScoringParamsResponse:
|
) -> ScoringParamsResponse:
|
||||||
"""
|
"""
|
||||||
Возвращает текущие параметры формулы расчета score.
|
Возвращает текущие параметры формулы расчета score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ScoringParamsResponse: Текущие параметры формулы
|
ScoringParamsResponse: Текущие параметры формулы
|
||||||
"""
|
"""
|
||||||
@@ -399,17 +515,17 @@ async def update_scoring_params(
|
|||||||
) -> ScoringParamsResponse:
|
) -> ScoringParamsResponse:
|
||||||
"""
|
"""
|
||||||
Обновляет параметры формулы расчета score.
|
Обновляет параметры формулы расчета score.
|
||||||
|
|
||||||
Можно обновить один или несколько параметров одновременно.
|
Можно обновить один или несколько параметров одновременно.
|
||||||
Параметры, которые не указаны, остаются без изменений.
|
Параметры, которые не указаны, остаются без изменений.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Запрос с новыми параметрами
|
request: Запрос с новыми параметрами
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ScoringParamsResponse: Обновленные параметры формулы
|
ScoringParamsResponse: Обновленные параметры формулы
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException: При невалидных значениях параметров
|
HTTPException: При невалидных значениях параметров
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,82 +5,69 @@
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Settings:
|
class Settings:
|
||||||
"""
|
"""
|
||||||
Настройки RAG сервиса.
|
Настройки RAG сервиса.
|
||||||
|
|
||||||
Все параметры загружаются из переменных окружения.
|
Все параметры загружаются из переменных окружения.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Модель
|
# Модель
|
||||||
model_name: str = field(
|
model_name: str = field(
|
||||||
default_factory=lambda: os.getenv("RAG_MODEL", "sentence-transformers/all-MiniLM-L12-v2")
|
default_factory=lambda: os.getenv("RAG_MODEL", "sentence-transformers/all-MiniLM-L12-v2")
|
||||||
)
|
)
|
||||||
cache_dir: str = field(
|
cache_dir: str = field(default_factory=lambda: os.getenv("RAG_CACHE_DIR", "data/models"))
|
||||||
default_factory=lambda: os.getenv("RAG_CACHE_DIR", "data/models")
|
|
||||||
)
|
|
||||||
|
|
||||||
# VectorStore
|
# VectorStore
|
||||||
vectors_path: str = field(
|
vectors_path: str = field(
|
||||||
default_factory=lambda: os.getenv("RAG_VECTORS_PATH", "data/vectors/vectors.npz")
|
default_factory=lambda: os.getenv("RAG_VECTORS_PATH", "data/vectors/vectors.npz")
|
||||||
)
|
)
|
||||||
max_examples: int = field(
|
max_examples: int = field(default_factory=lambda: int(os.getenv("RAG_MAX_EXAMPLES", "10000")))
|
||||||
default_factory=lambda: int(os.getenv("RAG_MAX_EXAMPLES", "10000"))
|
max_submitted: int = field(default_factory=lambda: int(os.getenv("RAG_MAX_SUBMITTED", "5000")))
|
||||||
|
submitted_path: str = field(
|
||||||
|
default_factory=lambda: os.getenv("RAG_SUBMITTED_PATH", "data/vectors/submitted.npz")
|
||||||
)
|
)
|
||||||
score_multiplier: float = field(
|
score_multiplier: float = field(
|
||||||
default_factory=lambda: float(os.getenv("RAG_SCORE_MULTIPLIER", "5.0"))
|
default_factory=lambda: float(os.getenv("RAG_SCORE_MULTIPLIER", "5.0"))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Батч-обработка
|
# Батч-обработка
|
||||||
batch_size: int = field(
|
batch_size: int = field(default_factory=lambda: int(os.getenv("RAG_BATCH_SIZE", "16")))
|
||||||
default_factory=lambda: int(os.getenv("RAG_BATCH_SIZE", "16"))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Минимальная длина текста
|
# Минимальная длина текста
|
||||||
min_text_length: int = field(
|
min_text_length: int = field(default_factory=lambda: int(os.getenv("RAG_MIN_TEXT_LENGTH", "3")))
|
||||||
default_factory=lambda: int(os.getenv("RAG_MIN_TEXT_LENGTH", "3"))
|
|
||||||
)
|
|
||||||
|
|
||||||
# API настройки
|
# API настройки
|
||||||
api_host: str = field(
|
api_host: str = field(default_factory=lambda: os.getenv("RAG_API_HOST", "0.0.0.0"))
|
||||||
default_factory=lambda: os.getenv("RAG_API_HOST", "0.0.0.0")
|
api_port: int = field(default_factory=lambda: int(os.getenv("RAG_API_PORT", "8000")))
|
||||||
)
|
|
||||||
api_port: int = field(
|
|
||||||
default_factory=lambda: int(os.getenv("RAG_API_PORT", "8000"))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Безопасность
|
# Безопасность
|
||||||
# API ключ для авторизации (обязателен в продакшене!)
|
# API ключ для авторизации (обязателен в продакшене!)
|
||||||
api_key: Optional[str] = field(
|
api_key: str | None = field(default_factory=lambda: os.getenv("RAG_API_KEY"))
|
||||||
default_factory=lambda: os.getenv("RAG_API_KEY")
|
|
||||||
)
|
|
||||||
# Разрешить запросы без ключа (только для разработки)
|
# Разрешить запросы без ключа (только для разработки)
|
||||||
allow_no_auth: bool = field(
|
allow_no_auth: bool = field(
|
||||||
default_factory=lambda: os.getenv("RAG_ALLOW_NO_AUTH", "false").lower() == "true"
|
default_factory=lambda: os.getenv("RAG_ALLOW_NO_AUTH", "false").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Логирование
|
# Логирование
|
||||||
log_level: str = field(
|
log_level: str = field(default_factory=lambda: os.getenv("LOG_LEVEL", "INFO"))
|
||||||
default_factory=lambda: os.getenv("LOG_LEVEL", "INFO")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Автосохранение (интервал в секундах, 0 = отключено)
|
# Автосохранение (интервал в секундах, 0 = отключено)
|
||||||
autosave_interval: int = field(
|
autosave_interval: int = field(
|
||||||
default_factory=lambda: int(os.getenv("RAG_AUTOSAVE_INTERVAL", "600")) # 10 минут
|
default_factory=lambda: int(os.getenv("RAG_AUTOSAVE_INTERVAL", "600")) # 10 минут
|
||||||
)
|
)
|
||||||
|
|
||||||
# Размерность векторов (384 для all-MiniLM-L12-v2)
|
# Размерность векторов (384 для all-MiniLM-L12-v2)
|
||||||
vector_dim: int = 384
|
vector_dim: int = 384
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_auth_required(self) -> bool:
|
def is_auth_required(self) -> bool:
|
||||||
"""Проверяет, требуется ли авторизация."""
|
"""Проверяет, требуется ли авторизация."""
|
||||||
return self.api_key is not None and not self.allow_no_auth
|
return self.api_key is not None and not self.allow_no_auth
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_api_key() -> str:
|
def generate_api_key() -> str:
|
||||||
"""Генерирует случайный API ключ."""
|
"""Генерирует случайный API ключ."""
|
||||||
@@ -88,13 +75,13 @@ class Settings:
|
|||||||
|
|
||||||
|
|
||||||
# Глобальный экземпляр настроек
|
# Глобальный экземпляр настроек
|
||||||
_settings: Optional[Settings] = None
|
_settings: Settings | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
"""
|
"""
|
||||||
Возвращает глобальный экземпляр настроек.
|
Возвращает глобальный экземпляр настроек.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Settings: Настройки приложения
|
Settings: Настройки приложения
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,29 +5,35 @@
|
|||||||
|
|
||||||
class RAGServiceError(Exception):
|
class RAGServiceError(Exception):
|
||||||
"""Базовое исключение для ошибок RAG сервиса."""
|
"""Базовое исключение для ошибок RAG сервиса."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelNotLoadedError(RAGServiceError):
|
class ModelNotLoadedError(RAGServiceError):
|
||||||
"""Модель не загружена или недоступна."""
|
"""Модель не загружена или недоступна."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class VectorStoreError(RAGServiceError):
|
class VectorStoreError(RAGServiceError):
|
||||||
"""Ошибка при работе с хранилищем векторов."""
|
"""Ошибка при работе с хранилищем векторов."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InsufficientExamplesError(RAGServiceError):
|
class InsufficientExamplesError(RAGServiceError):
|
||||||
"""Недостаточно примеров для расчета скора."""
|
"""Недостаточно примеров для расчета скора."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TextTooShortError(RAGServiceError):
|
class TextTooShortError(RAGServiceError):
|
||||||
"""Текст слишком короткий для векторизации."""
|
"""Текст слишком короткий для векторизации."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ScoringError(RAGServiceError):
|
class ScoringError(RAGServiceError):
|
||||||
"""Ошибка при расчете скора."""
|
"""Ошибка при расчете скора."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
66
app/main.py
66
app/main.py
@@ -7,8 +7,8 @@ FastAPI приложение Embedding сервиса.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncGenerator, Optional
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -18,11 +18,12 @@ from app.api.routes import router
|
|||||||
from app.config import get_settings
|
from app.config import get_settings
|
||||||
from app.services.rag_service import RAGService, get_rag_service
|
from app.services.rag_service import RAGService, get_rag_service
|
||||||
|
|
||||||
|
|
||||||
# Настройка логирования
|
# Настройка логирования
|
||||||
def setup_logging() -> None:
|
def setup_logging() -> None:
|
||||||
"""Настраивает логирование для приложения."""
|
"""Настраивает логирование для приложения."""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=getattr(logging, settings.log_level.upper()),
|
level=getattr(logging, settings.log_level.upper()),
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
@@ -30,7 +31,7 @@ def setup_logging() -> None:
|
|||||||
logging.StreamHandler(sys.stdout),
|
logging.StreamHandler(sys.stdout),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Уменьшаем логи от библиотек
|
# Уменьшаем логи от библиотек
|
||||||
logging.getLogger("transformers").setLevel(logging.WARNING)
|
logging.getLogger("transformers").setLevel(logging.WARNING)
|
||||||
logging.getLogger("torch").setLevel(logging.WARNING)
|
logging.getLogger("torch").setLevel(logging.WARNING)
|
||||||
@@ -40,33 +41,40 @@ def setup_logging() -> None:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Глобальная задача автосохранения
|
# Глобальная задача автосохранения
|
||||||
_autosave_task: Optional[asyncio.Task] = None
|
_autosave_task: asyncio.Task | None = None
|
||||||
|
|
||||||
|
|
||||||
async def autosave_loop(service: RAGService, interval: int) -> None:
|
async def autosave_loop(service: RAGService, interval: int) -> None:
|
||||||
"""
|
"""
|
||||||
Фоновая задача для периодического сохранения векторов.
|
Фоновая задача для периодического сохранения векторов.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service: RAG сервис
|
service: RAG сервис
|
||||||
interval: Интервал сохранения в секундах
|
interval: Интервал сохранения в секундах
|
||||||
"""
|
"""
|
||||||
logger.info(f"Автосохранение запущено (интервал: {interval} сек)")
|
logger.info(f"Автосохранение запущено (интервал: {interval} сек)")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(interval)
|
await asyncio.sleep(interval)
|
||||||
|
|
||||||
# Сохраняем только если есть данные
|
# Сохраняем только если есть данные
|
||||||
if service.vector_store.total_count > 0:
|
has_examples = service.vector_store.total_count > 0
|
||||||
|
has_submitted = service.vector_store.submitted_count > 0
|
||||||
|
if has_examples or has_submitted:
|
||||||
service.save_vectors()
|
service.save_vectors()
|
||||||
logger.info(
|
parts = []
|
||||||
f"Автосохранение: сохранено {service.vector_store.positive_count} pos, "
|
if has_examples:
|
||||||
f"{service.vector_store.negative_count} neg"
|
parts.append(
|
||||||
)
|
f"{service.vector_store.positive_count} pos, "
|
||||||
|
f"{service.vector_store.negative_count} neg"
|
||||||
|
)
|
||||||
|
if has_submitted:
|
||||||
|
parts.append(f"{service.vector_store.submitted_count} submitted")
|
||||||
|
logger.info(f"Автосохранение: сохранено {', '.join(parts)}")
|
||||||
else:
|
else:
|
||||||
logger.debug("Автосохранение: нет данных для сохранения")
|
logger.debug("Автосохранение: нет данных для сохранения")
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Автосохранение остановлено")
|
logger.info("Автосохранение остановлено")
|
||||||
break
|
break
|
||||||
@@ -79,43 +87,41 @@ async def autosave_loop(service: RAGService, interval: int) -> None:
|
|||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
"""
|
"""
|
||||||
Lifespan контекст для FastAPI.
|
Lifespan контекст для FastAPI.
|
||||||
|
|
||||||
При запуске:
|
При запуске:
|
||||||
- Настраивает логирование
|
- Настраивает логирование
|
||||||
- Прогревает модель (опционально)
|
- Прогревает модель (опционально)
|
||||||
|
|
||||||
При остановке:
|
При остановке:
|
||||||
- Сохраняет векторы на диск
|
- Сохраняет векторы на диск
|
||||||
"""
|
"""
|
||||||
global _autosave_task
|
global _autosave_task
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
logger.info(f"Embedding Service v{__version__} запускается...")
|
logger.info(f"Embedding Service v{__version__} запускается...")
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
logger.info(f"Настройки: model={settings.model_name}, vectors_path={settings.vectors_path}")
|
logger.info(f"Настройки: model={settings.model_name}, vectors_path={settings.vectors_path}")
|
||||||
|
|
||||||
# Получаем сервис (создается singleton)
|
# Получаем сервис (создается singleton)
|
||||||
service = get_rag_service()
|
service = get_rag_service()
|
||||||
|
|
||||||
# Запускаем автосохранение если включено
|
# Запускаем автосохранение если включено
|
||||||
if settings.autosave_interval > 0:
|
if settings.autosave_interval > 0:
|
||||||
_autosave_task = asyncio.create_task(
|
_autosave_task = asyncio.create_task(autosave_loop(service, settings.autosave_interval))
|
||||||
autosave_loop(service, settings.autosave_interval)
|
|
||||||
)
|
|
||||||
logger.info(f"Автосохранение включено: каждые {settings.autosave_interval} сек")
|
logger.info(f"Автосохранение включено: каждые {settings.autosave_interval} сек")
|
||||||
else:
|
else:
|
||||||
logger.info("Автосохранение отключено")
|
logger.info("Автосохранение отключено")
|
||||||
|
|
||||||
# Прогреваем модель при запуске (опционально)
|
# Прогреваем модель при запуске (опционально)
|
||||||
# Можно раскомментировать если нужен автопрогрев
|
# Можно раскомментировать если нужен автопрогрев
|
||||||
# logger.info("Прогрев модели при запуске...")
|
# logger.info("Прогрев модели при запуске...")
|
||||||
# await service.warmup()
|
# await service.warmup()
|
||||||
|
|
||||||
logger.info("Embedding Service готов к работе")
|
logger.info("Embedding Service готов к работе")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Останавливаем автосохранение
|
# Останавливаем автосохранение
|
||||||
if _autosave_task and not _autosave_task.done():
|
if _autosave_task and not _autosave_task.done():
|
||||||
_autosave_task.cancel()
|
_autosave_task.cancel()
|
||||||
@@ -123,7 +129,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
await _autosave_task
|
await _autosave_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# При остановке сохраняем векторы
|
# При остановке сохраняем векторы
|
||||||
logger.info("Embedding Service останавливается, финальное сохранение векторов...")
|
logger.info("Embedding Service останавливается, финальное сохранение векторов...")
|
||||||
try:
|
try:
|
||||||
@@ -131,7 +137,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
logger.info("Векторы сохранены")
|
logger.info("Векторы сохранены")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Ошибка сохранения векторов: {e}")
|
logger.error(f"Ошибка сохранения векторов: {e}")
|
||||||
|
|
||||||
logger.info("Embedding Service остановлен")
|
logger.info("Embedding Service остановлен")
|
||||||
|
|
||||||
|
|
||||||
@@ -176,19 +182,21 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Простой healthcheck endpoint без авторизации (для Docker healthcheck)
|
# Простой healthcheck endpoint без авторизации (для Docker healthcheck)
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def simple_health_check():
|
async def simple_health_check():
|
||||||
"""Простая проверка здоровья без авторизации (для Docker healthcheck)."""
|
"""Простая проверка здоровья без авторизации (для Docker healthcheck)."""
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
# Подключение роутов
|
# Подключение роутов
|
||||||
app.include_router(router, prefix="/api/v1")
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"app.main:app",
|
"app.main:app",
|
||||||
|
|||||||
189
app/schemas.py
189
app/schemas.py
@@ -2,36 +2,66 @@
|
|||||||
Pydantic схемы для API Embedding сервиса.
|
Pydantic схемы для API Embedding сервиса.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Запросы
|
# Запросы
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ScoreRequest(BaseModel):
|
class ScoreRequest(BaseModel):
|
||||||
"""Запрос на расчет скора."""
|
"""Запрос на расчет скора."""
|
||||||
|
|
||||||
text: str = Field(..., min_length=1, description="Текст поста для оценки")
|
text: str = Field(..., min_length=1, description="Текст поста для оценки")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {"example": {"text": "Это пример текста поста для оценки скоринга"}}
|
||||||
"example": {
|
|
||||||
"text": "Это пример текста поста для оценки скоринга"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ExampleRequest(BaseModel):
|
class ExampleRequest(BaseModel):
|
||||||
"""Запрос на добавление примера."""
|
"""Запрос на добавление примера."""
|
||||||
|
|
||||||
text: str = Field(..., min_length=1, description="Текст примера")
|
text: str = Field(..., min_length=1, description="Текст примера")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"json_schema_extra": {"example": {"text": "Это пример опубликованного/отклоненного поста"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarRequest(BaseModel):
|
||||||
|
"""Запрос на поиск похожих постов."""
|
||||||
|
|
||||||
|
text: str = Field(..., min_length=1, description="Текст для поиска похожих")
|
||||||
|
threshold: float = Field(
|
||||||
|
default=0.9, ge=0.0, le=1.0, description="Минимальный порог similarity"
|
||||||
|
)
|
||||||
|
hours: int = Field(default=24, ge=1, le=168, description="Количество часов для фильтрации")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"text": "Это пример опубликованного/отклоненного поста"
|
"text": "Текст поста для поиска похожих",
|
||||||
|
"threshold": 0.9,
|
||||||
|
"hours": 24,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SubmittedRequest(BaseModel):
|
||||||
|
"""Запрос на добавление submitted-поста."""
|
||||||
|
|
||||||
|
text: str = Field(..., min_length=1, description="Текст поста")
|
||||||
|
post_id: int | None = None
|
||||||
|
rag_score: float | None = None
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"example": {
|
||||||
|
"text": "Текст submitted-поста",
|
||||||
|
"post_id": 12345,
|
||||||
|
"rag_score": 0.85,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -41,8 +71,10 @@ class ExampleRequest(BaseModel):
|
|||||||
# Ответы
|
# Ответы
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ScoreMetadata(BaseModel):
|
class ScoreMetadata(BaseModel):
|
||||||
"""Метаданные результата скоринга."""
|
"""Метаданные результата скоринга."""
|
||||||
|
|
||||||
positive_examples: int = Field(..., description="Количество положительных примеров")
|
positive_examples: int = Field(..., description="Количество положительных примеров")
|
||||||
negative_examples: int = Field(..., description="Количество отрицательных примеров")
|
negative_examples: int = Field(..., description="Количество отрицательных примеров")
|
||||||
model: str = Field(..., description="Название модели")
|
model: str = Field(..., description="Название модели")
|
||||||
@@ -51,11 +83,14 @@ class ScoreMetadata(BaseModel):
|
|||||||
|
|
||||||
class ScoreResponse(BaseModel):
|
class ScoreResponse(BaseModel):
|
||||||
"""Ответ с результатом скоринга."""
|
"""Ответ с результатом скоринга."""
|
||||||
|
|
||||||
rag_score: float = Field(..., ge=0.0, le=1.0, description="Основной скор (neg/pos формула)")
|
rag_score: float = Field(..., ge=0.0, le=1.0, description="Основной скор (neg/pos формула)")
|
||||||
rag_confidence: float = Field(..., ge=0.0, le=1.0, description="Уверенность в оценке")
|
rag_confidence: float = Field(..., ge=0.0, le=1.0, description="Уверенность в оценке")
|
||||||
rag_score_pos_only: float = Field(..., ge=0.0, le=1.0, description="Скор только по положительным примерам")
|
rag_score_pos_only: float = Field(
|
||||||
|
..., ge=0.0, le=1.0, description="Скор только по положительным примерам"
|
||||||
|
)
|
||||||
meta: ScoreMetadata = Field(..., description="Метаданные")
|
meta: ScoreMetadata = Field(..., description="Метаданные")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
@@ -66,8 +101,8 @@ class ScoreResponse(BaseModel):
|
|||||||
"positive_examples": 500,
|
"positive_examples": 500,
|
||||||
"negative_examples": 350,
|
"negative_examples": 350,
|
||||||
"model": "sentence-transformers/all-MiniLM-L12-v2",
|
"model": "sentence-transformers/all-MiniLM-L12-v2",
|
||||||
"timestamp": 1706270000
|
"timestamp": 1706270000,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,18 +110,71 @@ class ScoreResponse(BaseModel):
|
|||||||
|
|
||||||
class ExampleResponse(BaseModel):
|
class ExampleResponse(BaseModel):
|
||||||
"""Ответ на добавление примера."""
|
"""Ответ на добавление примера."""
|
||||||
|
|
||||||
success: bool = Field(..., description="Успешность добавления")
|
success: bool = Field(..., description="Успешность добавления")
|
||||||
message: str = Field(..., description="Сообщение о результате")
|
message: str = Field(..., description="Сообщение о результате")
|
||||||
positive_count: int = Field(..., description="Текущее количество положительных примеров")
|
positive_count: int = Field(..., description="Текущее количество положительных примеров")
|
||||||
negative_count: int = Field(..., description="Текущее количество отрицательных примеров")
|
negative_count: int = Field(..., description="Текущее количество отрицательных примеров")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "Положительный пример добавлен",
|
"message": "Положительный пример добавлен",
|
||||||
"positive_count": 501,
|
"positive_count": 501,
|
||||||
"negative_count": 350
|
"negative_count": 350,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarPostItem(BaseModel):
|
||||||
|
"""Элемент похожего поста."""
|
||||||
|
|
||||||
|
similarity: float = Field(..., description="Косинусное сходство")
|
||||||
|
created_at: int = Field(..., description="Unix timestamp создания")
|
||||||
|
post_id: int | None = None
|
||||||
|
text: str = Field(..., description="Текст поста")
|
||||||
|
rag_score: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SimilarResponse(BaseModel):
|
||||||
|
"""Ответ с похожими постами."""
|
||||||
|
|
||||||
|
similar_count: int = Field(..., description="Количество найденных похожих постов")
|
||||||
|
similar_posts: list[SimilarPostItem] = Field(..., description="Список похожих постов")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"example": {
|
||||||
|
"similar_count": 2,
|
||||||
|
"similar_posts": [
|
||||||
|
{
|
||||||
|
"similarity": 0.95,
|
||||||
|
"created_at": 1706270000,
|
||||||
|
"post_id": 123,
|
||||||
|
"text": "Похожий пост",
|
||||||
|
"rag_score": 0.85,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SubmittedResponse(BaseModel):
|
||||||
|
"""Ответ на добавление submitted-поста."""
|
||||||
|
|
||||||
|
success: bool = Field(..., description="Успешность добавления")
|
||||||
|
message: str = Field(..., description="Сообщение о результате")
|
||||||
|
submitted_count: int = Field(..., description="Текущее количество submitted-постов")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"example": {
|
||||||
|
"success": True,
|
||||||
|
"message": "Submitted-пост добавлен",
|
||||||
|
"submitted_count": 42,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -94,20 +182,24 @@ class ExampleResponse(BaseModel):
|
|||||||
|
|
||||||
class VectorStoreStats(BaseModel):
|
class VectorStoreStats(BaseModel):
|
||||||
"""Статистика хранилища векторов."""
|
"""Статистика хранилища векторов."""
|
||||||
|
|
||||||
positive_count: int = Field(..., description="Количество положительных примеров")
|
positive_count: int = Field(..., description="Количество положительных примеров")
|
||||||
negative_count: int = Field(..., description="Количество отрицательных примеров")
|
negative_count: int = Field(..., description="Количество отрицательных примеров")
|
||||||
total_count: int = Field(..., description="Общее количество примеров")
|
total_count: int = Field(..., description="Общее количество примеров")
|
||||||
|
submitted_count: int = Field(default=0, description="Количество submitted-постов")
|
||||||
vector_dim: int = Field(..., description="Размерность векторов")
|
vector_dim: int = Field(..., description="Размерность векторов")
|
||||||
max_examples: int = Field(..., description="Максимальное количество примеров")
|
max_examples: int = Field(..., description="Максимальное количество примеров")
|
||||||
|
max_submitted: int = Field(default=5000, description="Максимальное количество submitted-постов")
|
||||||
|
|
||||||
|
|
||||||
class StatsResponse(BaseModel):
|
class StatsResponse(BaseModel):
|
||||||
"""Ответ со статистикой сервиса."""
|
"""Ответ со статистикой сервиса."""
|
||||||
|
|
||||||
model_name: str = Field(..., description="Название модели")
|
model_name: str = Field(..., description="Название модели")
|
||||||
model_loaded: bool = Field(..., description="Загружена ли модель")
|
model_loaded: bool = Field(..., description="Загружена ли модель")
|
||||||
device: Optional[str] = Field(None, description="Устройство (cpu/cuda)")
|
device: str | None = Field(None, description="Устройство (cpu/cuda)")
|
||||||
vector_store: VectorStoreStats = Field(..., description="Статистика хранилища векторов")
|
vector_store: VectorStoreStats = Field(..., description="Статистика хранилища векторов")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
@@ -119,8 +211,8 @@ class StatsResponse(BaseModel):
|
|||||||
"negative_count": 350,
|
"negative_count": 350,
|
||||||
"total_count": 850,
|
"total_count": 850,
|
||||||
"vector_dim": 384,
|
"vector_dim": 384,
|
||||||
"max_examples": 10000
|
"max_examples": 10000,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -128,16 +220,17 @@ class StatsResponse(BaseModel):
|
|||||||
|
|
||||||
class WarmupResponse(BaseModel):
|
class WarmupResponse(BaseModel):
|
||||||
"""Ответ на прогрев модели."""
|
"""Ответ на прогрев модели."""
|
||||||
|
|
||||||
success: bool = Field(..., description="Успешность загрузки")
|
success: bool = Field(..., description="Успешность загрузки")
|
||||||
model_loaded: bool = Field(..., description="Загружена ли модель")
|
model_loaded: bool = Field(..., description="Загружена ли модель")
|
||||||
message: str = Field(..., description="Сообщение о результате")
|
message: str = Field(..., description="Сообщение о результате")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"success": True,
|
"success": True,
|
||||||
"model_loaded": True,
|
"model_loaded": True,
|
||||||
"message": "Модель успешно загружена"
|
"message": "Модель успешно загружена",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -145,14 +238,15 @@ class WarmupResponse(BaseModel):
|
|||||||
|
|
||||||
class ErrorResponse(BaseModel):
|
class ErrorResponse(BaseModel):
|
||||||
"""Ответ с ошибкой."""
|
"""Ответ с ошибкой."""
|
||||||
|
|
||||||
detail: str = Field(..., description="Описание ошибки")
|
detail: str = Field(..., description="Описание ошибки")
|
||||||
error_type: str = Field(..., description="Тип ошибки")
|
error_type: str = Field(..., description="Тип ошибки")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {
|
||||||
"detail": "Недостаточно примеров для расчета скора",
|
"detail": "Недостаточно примеров для расчета скора",
|
||||||
"error_type": "InsufficientExamplesError"
|
"error_type": "InsufficientExamplesError",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,23 +254,21 @@ class ErrorResponse(BaseModel):
|
|||||||
|
|
||||||
class HealthResponse(BaseModel):
|
class HealthResponse(BaseModel):
|
||||||
"""Ответ проверки здоровья сервиса."""
|
"""Ответ проверки здоровья сервиса."""
|
||||||
|
|
||||||
status: str = Field(..., description="Статус сервиса")
|
status: str = Field(..., description="Статус сервиса")
|
||||||
model_loaded: bool = Field(..., description="Загружена ли модель")
|
model_loaded: bool = Field(..., description="Загружена ли модель")
|
||||||
version: str = Field(..., description="Версия сервиса")
|
version: str = Field(..., description="Версия сервиса")
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"json_schema_extra": {
|
"json_schema_extra": {
|
||||||
"example": {
|
"example": {"status": "healthy", "model_loaded": True, "version": "0.1.0"}
|
||||||
"status": "healthy",
|
|
||||||
"model_loaded": True,
|
|
||||||
"version": "0.1.0"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ScoringParamsResponse(BaseModel):
|
class ScoringParamsResponse(BaseModel):
|
||||||
"""Ответ с текущими параметрами формулы расчета score."""
|
"""Ответ с текущими параметрами формулы расчета score."""
|
||||||
|
|
||||||
score_multiplier: float = Field(
|
score_multiplier: float = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
@@ -185,7 +277,7 @@ class ScoringParamsResponse(BaseModel):
|
|||||||
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
|
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
|
||||||
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
|
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
|
||||||
"Рекомендуемое значение: 5.0"
|
"Рекомендуемое значение: 5.0"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
k: int = Field(
|
k: int = Field(
|
||||||
...,
|
...,
|
||||||
@@ -195,22 +287,16 @@ class ScoringParamsResponse(BaseModel):
|
|||||||
"и вычисляет среднее косинусное сходство. "
|
"и вычисляет среднее косинусное сходство. "
|
||||||
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
|
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
|
||||||
"Рекомендуемое значение: 3"
|
"Рекомендуемое значение: 3"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = {
|
model_config = {"json_schema_extra": {"example": {"score_multiplier": 5.0, "k": 3}}}
|
||||||
"json_schema_extra": {
|
|
||||||
"example": {
|
|
||||||
"score_multiplier": 5.0,
|
|
||||||
"k": 3
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateScoringParamsRequest(BaseModel):
|
class UpdateScoringParamsRequest(BaseModel):
|
||||||
"""Запрос на обновление параметров формулы расчета score."""
|
"""Запрос на обновление параметров формулы расчета score."""
|
||||||
score_multiplier: Optional[float] = Field(
|
|
||||||
|
score_multiplier: float | None = Field(
|
||||||
None,
|
None,
|
||||||
gt=0,
|
gt=0,
|
||||||
description=(
|
description=(
|
||||||
@@ -219,9 +305,9 @@ class UpdateScoringParamsRequest(BaseModel):
|
|||||||
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
|
"где diff = avg_pos - avg_neg (разница средних сходств топ-k примеров). "
|
||||||
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
|
"Чем больше значение, тем сильнее влияние разницы между положительными и отрицательными примерами на итоговый score. "
|
||||||
"Должен быть > 0. Рекомендуемое значение: 5.0"
|
"Должен быть > 0. Рекомендуемое значение: 5.0"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
k: Optional[int] = Field(
|
k: int | None = Field(
|
||||||
None,
|
None,
|
||||||
ge=1,
|
ge=1,
|
||||||
description=(
|
description=(
|
||||||
@@ -230,14 +316,7 @@ class UpdateScoringParamsRequest(BaseModel):
|
|||||||
"и вычисляет среднее косинусное сходство. "
|
"и вычисляет среднее косинусное сходство. "
|
||||||
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
|
"Меньшее значение k делает алгоритм более чувствительным к различиям, но может быть менее стабильным. "
|
||||||
"Должно быть >= 1. Рекомендуемое значение: 3"
|
"Должно быть >= 1. Рекомендуемое значение: 3"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
model_config = {
|
model_config = {"json_schema_extra": {"example": {"score_multiplier": 5.0, "k": 3}}}
|
||||||
"json_schema_extra": {
|
|
||||||
"example": {
|
|
||||||
"score_multiplier": 5.0,
|
|
||||||
"k": 3
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class ScoringResult:
|
class ScoringResult:
|
||||||
"""
|
"""
|
||||||
Результат оценки поста.
|
Результат оценки поста.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
score: Оценка от 0.0 до 1.0 (вероятность публикации)
|
score: Оценка от 0.0 до 1.0 (вероятность публикации)
|
||||||
confidence: Уверенность в оценке
|
confidence: Уверенность в оценке
|
||||||
@@ -39,6 +39,7 @@ class ScoringResult:
|
|||||||
model: Название используемой модели
|
model: Название используемой модели
|
||||||
timestamp: Время получения оценки
|
timestamp: Время получения оценки
|
||||||
"""
|
"""
|
||||||
|
|
||||||
score: float
|
score: float
|
||||||
confidence: float
|
confidence: float
|
||||||
score_pos_only: float
|
score_pos_only: float
|
||||||
@@ -46,8 +47,8 @@ class ScoringResult:
|
|||||||
negative_examples: int
|
negative_examples: int
|
||||||
model: str
|
model: str
|
||||||
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
|
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Преобразует результат в словарь."""
|
"""Преобразует результат в словарь."""
|
||||||
return {
|
return {
|
||||||
"rag_score": round(self.score, 4),
|
"rag_score": round(self.score, 4),
|
||||||
@@ -58,31 +59,31 @@ class ScoringResult:
|
|||||||
"negative_examples": self.negative_examples,
|
"negative_examples": self.negative_examples,
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"timestamp": self.timestamp,
|
"timestamp": self.timestamp,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class RAGService:
|
class RAGService:
|
||||||
"""
|
"""
|
||||||
RAG сервис для оценки постов на основе векторного сходства.
|
RAG сервис для оценки постов на основе векторного сходства.
|
||||||
|
|
||||||
Использует sentence-transformers для создания эмбеддингов текста и сравнивает
|
Использует sentence-transformers для создания эмбеддингов текста и сравнивает
|
||||||
их с эталонными примерами (опубликованные vs отклоненные посты).
|
их с эталонными примерами (опубликованные vs отклоненные посты).
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
model_name: Название модели HuggingFace
|
model_name: Название модели HuggingFace
|
||||||
vector_store: Хранилище векторов
|
vector_store: Хранилище векторов
|
||||||
min_text_length: Минимальная длина текста для обработки
|
min_text_length: Минимальная длина текста для обработки
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
settings: Optional[Settings] = None,
|
settings: Settings | None = None,
|
||||||
vector_store: Optional[VectorStore] = None,
|
vector_store: VectorStore | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Инициализация RAG сервиса.
|
Инициализация RAG сервиса.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
settings: Настройки сервиса (берутся из get_settings() если не переданы)
|
settings: Настройки сервиса (берутся из get_settings() если не переданы)
|
||||||
vector_store: Хранилище векторов (создается автоматически если не передано)
|
vector_store: Хранилище векторов (создается автоматически если не передано)
|
||||||
@@ -91,96 +92,102 @@ class RAGService:
|
|||||||
self.model_name = self._settings.model_name
|
self.model_name = self._settings.model_name
|
||||||
self.cache_dir = self._settings.cache_dir
|
self.cache_dir = self._settings.cache_dir
|
||||||
self.min_text_length = self._settings.min_text_length
|
self.min_text_length = self._settings.min_text_length
|
||||||
|
|
||||||
# Модель загружается лениво
|
# Модель загружается лениво
|
||||||
self._model = None
|
self._model = None
|
||||||
self._device = None
|
self._device = None
|
||||||
self._model_loaded = False
|
self._model_loaded = False
|
||||||
|
|
||||||
# Хранилище векторов
|
# Хранилище векторов
|
||||||
self.vector_store = vector_store or VectorStore(
|
self.vector_store = vector_store or VectorStore(
|
||||||
vector_dim=self._settings.vector_dim,
|
vector_dim=self._settings.vector_dim,
|
||||||
max_examples=self._settings.max_examples,
|
max_examples=self._settings.max_examples,
|
||||||
|
max_submitted=self._settings.max_submitted,
|
||||||
storage_path=self._settings.vectors_path,
|
storage_path=self._settings.vectors_path,
|
||||||
|
submitted_path=self._settings.submitted_path,
|
||||||
score_multiplier=self._settings.score_multiplier,
|
score_multiplier=self._settings.score_multiplier,
|
||||||
k=3, # Фиксированное значение k для топ-k ближайших примеров
|
k=3, # Фиксированное значение k для топ-k ближайших примеров
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"RAGService инициализирован (model={self.model_name})")
|
logger.info(f"RAGService инициализирован (model={self.model_name})")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_model_loaded(self) -> bool:
|
def is_model_loaded(self) -> bool:
|
||||||
"""Проверяет, загружена ли модель."""
|
"""Проверяет, загружена ли модель."""
|
||||||
return self._model_loaded
|
return self._model_loaded
|
||||||
|
|
||||||
async def load_model(self) -> None:
|
async def load_model(self) -> None:
|
||||||
"""
|
"""
|
||||||
Загружает модель и токенизатор.
|
Загружает модель и токенизатор.
|
||||||
|
|
||||||
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
|
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
|
||||||
"""
|
"""
|
||||||
if self._model_loaded:
|
if self._model_loaded:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
|
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Загрузка в отдельном потоке
|
# Загрузка в отдельном потоке
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
await loop.run_in_executor(None, self._load_model_sync)
|
await loop.run_in_executor(None, self._load_model_sync)
|
||||||
|
|
||||||
self._model_loaded = True
|
self._model_loaded = True
|
||||||
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
|
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
|
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
|
||||||
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
|
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
|
||||||
|
|
||||||
def _load_model_sync(self) -> None:
|
def _load_model_sync(self) -> None:
|
||||||
"""Синхронная загрузка модели (вызывается в executor)."""
|
"""Синхронная загрузка модели (вызывается в executor)."""
|
||||||
logger.info("RAGService: Начало _load_model_sync, импорт sentence_transformers...")
|
logger.info("RAGService: Начало _load_model_sync, импорт sentence_transformers...")
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
import torch
|
import torch
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
# Определяем устройство
|
# Определяем устройство
|
||||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
logger.info(f"RAGService: Устройство определено: {self._device}")
|
logger.info(f"RAGService: Устройство определено: {self._device}")
|
||||||
|
|
||||||
# Загружаем модель SentenceTransformer
|
# Загружаем модель SentenceTransformer
|
||||||
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
|
logger.info(
|
||||||
|
f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)..."
|
||||||
|
)
|
||||||
self._model = SentenceTransformer(
|
self._model = SentenceTransformer(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
cache_folder=self.cache_dir,
|
cache_folder=self.cache_dir,
|
||||||
device=self._device,
|
device=self._device,
|
||||||
)
|
)
|
||||||
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
|
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
|
||||||
|
|
||||||
def _get_embedding_sync(self, text: str) -> np.ndarray:
|
def _get_embedding_sync(self, text: str) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Получает эмбеддинг текста (синхронно).
|
Получает эмбеддинг текста (синхронно).
|
||||||
|
|
||||||
Использует SentenceTransformer для получения нормализованного эмбеддинга.
|
Использует SentenceTransformer для получения нормализованного эмбеддинга.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Текст для векторизации
|
text: Текст для векторизации
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Numpy массив с эмбеддингом (384 измерений для all-MiniLM-L12-v2)
|
Numpy массив с эмбеддингом (384 измерений для all-MiniLM-L12-v2)
|
||||||
"""
|
"""
|
||||||
# SentenceTransformer автоматически нормализует эмбеддинги
|
# SentenceTransformer автоматически нормализует эмбеддинги
|
||||||
embedding = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
|
embedding = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
|
||||||
return embedding.flatten()
|
return embedding.flatten()
|
||||||
|
|
||||||
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
|
def _get_embeddings_batch_sync(
|
||||||
|
self, texts: list[str], batch_size: int = 16
|
||||||
|
) -> list[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Получает эмбеддинги для батча текстов (синхронно).
|
Получает эмбеддинги для батча текстов (синхронно).
|
||||||
|
|
||||||
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
|
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: Список текстов для векторизации
|
texts: Список текстов для векторизации
|
||||||
batch_size: Размер батча
|
batch_size: Размер батча
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Список numpy массивов с эмбеддингами
|
Список numpy массивов с эмбеддингами
|
||||||
"""
|
"""
|
||||||
@@ -192,32 +199,34 @@ class RAGService:
|
|||||||
normalize_embeddings=True,
|
normalize_embeddings=True,
|
||||||
show_progress_bar=False,
|
show_progress_bar=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Преобразуем в список отдельных массивов
|
# Преобразуем в список отдельных массивов
|
||||||
return [emb.flatten() for emb in embeddings]
|
return [emb.flatten() for emb in embeddings]
|
||||||
|
|
||||||
async def get_embeddings_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[np.ndarray]:
|
async def get_embeddings_batch(
|
||||||
|
self, texts: list[str], batch_size: int | None = None
|
||||||
|
) -> list[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Получает эмбеддинги для батча текстов (асинхронно).
|
Получает эмбеддинги для батча текстов (асинхронно).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: Список текстов для векторизации
|
texts: Список текстов для векторизации
|
||||||
batch_size: Размер батча (берется из настроек если не указан)
|
batch_size: Размер батча (берется из настроек если не указан)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Список numpy массивов с эмбеддингами
|
Список numpy массивов с эмбеддингами
|
||||||
"""
|
"""
|
||||||
if not self._model_loaded:
|
if not self._model_loaded:
|
||||||
await self.load_model()
|
await self.load_model()
|
||||||
|
|
||||||
if not self._model_loaded:
|
if not self._model_loaded:
|
||||||
raise ModelNotLoadedError("Модель не загружена")
|
raise ModelNotLoadedError("Модель не загружена")
|
||||||
|
|
||||||
batch_size = batch_size or self._settings.batch_size
|
batch_size = batch_size or self._settings.batch_size
|
||||||
|
|
||||||
# Очищаем тексты
|
# Очищаем тексты
|
||||||
clean_texts = [self._clean_text(text) for text in texts]
|
clean_texts = [self._clean_text(text) for text in texts]
|
||||||
|
|
||||||
# Выполняем батч-обработку в thread pool
|
# Выполняем батч-обработку в thread pool
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
embeddings = await loop.run_in_executor(
|
embeddings = await loop.run_in_executor(
|
||||||
@@ -226,71 +235,67 @@ class RAGService:
|
|||||||
clean_texts,
|
clean_texts,
|
||||||
batch_size,
|
batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
async def get_embedding(self, text: str) -> np.ndarray:
|
async def get_embedding(self, text: str) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Получает эмбеддинг текста (асинхронно).
|
Получает эмбеддинг текста (асинхронно).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Текст для векторизации
|
text: Текст для векторизации
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Numpy массив с эмбеддингом
|
Numpy массив с эмбеддингом
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ModelNotLoadedError: Если модель не загружена
|
ModelNotLoadedError: Если модель не загружена
|
||||||
TextTooShortError: Если текст слишком короткий
|
TextTooShortError: Если текст слишком короткий
|
||||||
"""
|
"""
|
||||||
if not self._model_loaded:
|
if not self._model_loaded:
|
||||||
await self.load_model()
|
await self.load_model()
|
||||||
|
|
||||||
if not self._model_loaded:
|
if not self._model_loaded:
|
||||||
raise ModelNotLoadedError("Модель не загружена")
|
raise ModelNotLoadedError("Модель не загружена")
|
||||||
|
|
||||||
# Очищаем текст
|
# Очищаем текст
|
||||||
clean_text = self._clean_text(text)
|
clean_text = self._clean_text(text)
|
||||||
|
|
||||||
if len(clean_text) < self.min_text_length:
|
if len(clean_text) < self.min_text_length:
|
||||||
raise TextTooShortError(
|
raise TextTooShortError(
|
||||||
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
|
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Выполняем в отдельном потоке
|
# Выполняем в отдельном потоке
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
embedding = await loop.run_in_executor(
|
embedding = await loop.run_in_executor(None, self._get_embedding_sync, clean_text)
|
||||||
None,
|
|
||||||
self._get_embedding_sync,
|
|
||||||
clean_text
|
|
||||||
)
|
|
||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def _clean_text(self, text: str) -> str:
|
def _clean_text(self, text: str) -> str:
|
||||||
"""Очищает текст от лишних символов."""
|
"""Очищает текст от лишних символов."""
|
||||||
if not text:
|
if not text:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Удаляем лишние пробелы и переносы строк
|
# Удаляем лишние пробелы и переносы строк
|
||||||
clean = " ".join(text.split())
|
clean = " ".join(text.split())
|
||||||
|
|
||||||
# Удаляем служебные символы (например "^" для helper сообщений)
|
# Удаляем служебные символы (например "^" для helper сообщений)
|
||||||
if clean == "^":
|
if clean == "^":
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return clean.strip()
|
return clean.strip()
|
||||||
|
|
||||||
async def calculate_score(self, text: str) -> ScoringResult:
|
async def calculate_score(self, text: str) -> ScoringResult:
|
||||||
"""
|
"""
|
||||||
Рассчитывает скор для текста поста.
|
Рассчитывает скор для текста поста.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Текст поста для оценки
|
text: Текст поста для оценки
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ScoringResult с оценкой
|
ScoringResult с оценкой
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ScoringError: При ошибке расчета
|
ScoringError: При ошибке расчета
|
||||||
InsufficientExamplesError: Если недостаточно примеров
|
InsufficientExamplesError: Если недостаточно примеров
|
||||||
@@ -299,16 +304,17 @@ class RAGService:
|
|||||||
try:
|
try:
|
||||||
# Получаем эмбеддинг текста
|
# Получаем эмбеддинг текста
|
||||||
embedding = await self.get_embedding(text)
|
embedding = await self.get_embedding(text)
|
||||||
|
|
||||||
# Логируем первые элементы вектора для отладки
|
# Логируем первые элементы вектора для отладки
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
|
f"RAGService: embedding[:3]={embedding[:3].tolist()}, text_preview='{text[:30]}'"
|
||||||
f"text_preview='{text[:30]}'"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Рассчитываем скор через VectorStore
|
# Рассчитываем скор через VectorStore
|
||||||
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
|
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(
|
||||||
|
embedding
|
||||||
|
)
|
||||||
|
|
||||||
return ScoringResult(
|
return ScoringResult(
|
||||||
score=score,
|
score=score,
|
||||||
confidence=confidence,
|
confidence=confidence,
|
||||||
@@ -317,22 +323,22 @@ class RAGService:
|
|||||||
negative_examples=self.vector_store.negative_count,
|
negative_examples=self.vector_store.negative_count,
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
except (InsufficientExamplesError, TextTooShortError):
|
except (InsufficientExamplesError, TextTooShortError):
|
||||||
# Пробрасываем ожидаемые исключения
|
# Пробрасываем ожидаемые исключения
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RAGService: Ошибка расчета скора: {e}")
|
logger.error(f"RAGService: Ошибка расчета скора: {e}")
|
||||||
raise ScoringError(f"Ошибка расчета скора: {e}")
|
raise ScoringError(f"Ошибка расчета скора: {e}")
|
||||||
|
|
||||||
async def add_positive_example(self, text: str) -> bool:
|
async def add_positive_example(self, text: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Добавляет текст как положительный пример (опубликованный пост).
|
Добавляет текст как положительный пример (опубликованный пост).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Текст опубликованного поста
|
text: Текст опубликованного поста
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True если пример добавлен, False если дубликат/короткий текст
|
True если пример добавлен, False если дубликат/короткий текст
|
||||||
"""
|
"""
|
||||||
@@ -341,32 +347,32 @@ class RAGService:
|
|||||||
if len(clean_text) < self.min_text_length:
|
if len(clean_text) < self.min_text_length:
|
||||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Получаем эмбеддинг
|
# Получаем эмбеддинг
|
||||||
embedding = await self.get_embedding(clean_text)
|
embedding = await self.get_embedding(clean_text)
|
||||||
|
|
||||||
# Вычисляем хеш для дедупликации
|
# Вычисляем хеш для дедупликации
|
||||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||||
|
|
||||||
# Добавляем в хранилище
|
# Добавляем в хранилище
|
||||||
added = self.vector_store.add_positive(embedding, text_hash)
|
added = self.vector_store.add_positive(embedding, text_hash)
|
||||||
|
|
||||||
if added:
|
if added:
|
||||||
logger.info("RAGService: Добавлен положительный пример")
|
logger.info("RAGService: Добавлен положительный пример")
|
||||||
|
|
||||||
return added
|
return added
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
|
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def add_negative_example(self, text: str) -> bool:
|
async def add_negative_example(self, text: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Добавляет текст как отрицательный пример (отклоненный пост).
|
Добавляет текст как отрицательный пример (отклоненный пост).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Текст отклоненного поста
|
text: Текст отклоненного поста
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True если пример добавлен, False если дубликат/короткий текст
|
True если пример добавлен, False если дубликат/короткий текст
|
||||||
"""
|
"""
|
||||||
@@ -375,29 +381,102 @@ class RAGService:
|
|||||||
if len(clean_text) < self.min_text_length:
|
if len(clean_text) < self.min_text_length:
|
||||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Получаем эмбеддинг
|
# Получаем эмбеддинг
|
||||||
embedding = await self.get_embedding(clean_text)
|
embedding = await self.get_embedding(clean_text)
|
||||||
|
|
||||||
# Вычисляем хеш для дедупликации
|
# Вычисляем хеш для дедупликации
|
||||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||||
|
|
||||||
# Добавляем в хранилище
|
# Добавляем в хранилище
|
||||||
added = self.vector_store.add_negative(embedding, text_hash)
|
added = self.vector_store.add_negative(embedding, text_hash)
|
||||||
|
|
||||||
if added:
|
if added:
|
||||||
logger.info("RAGService: Добавлен отрицательный пример")
|
logger.info("RAGService: Добавлен отрицательный пример")
|
||||||
|
|
||||||
return added
|
return added
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
|
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def add_submitted_post(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
post_id: int | None = None,
|
||||||
|
rag_score: float | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Добавляет submitted-пост в коллекцию для индексации.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Текст поста
|
||||||
|
post_id: ID поста (опционально)
|
||||||
|
rag_score: RAG скор поста (опционально)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True если добавлен, False если дубликат/короткий текст
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
clean_text = self._clean_text(text)
|
||||||
|
if len(clean_text) < self.min_text_length:
|
||||||
|
logger.debug("RAGService: Текст слишком короткий для submitted, пропускаем")
|
||||||
|
return False
|
||||||
|
|
||||||
|
embedding = await self.get_embedding(clean_text)
|
||||||
|
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||||
|
created_at = int(datetime.now().timestamp())
|
||||||
|
|
||||||
|
added = self.vector_store.add_submitted(
|
||||||
|
vector=embedding,
|
||||||
|
text_hash=text_hash,
|
||||||
|
created_at=created_at,
|
||||||
|
post_id=post_id,
|
||||||
|
text=clean_text,
|
||||||
|
rag_score=rag_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
if added:
|
||||||
|
logger.info("RAGService: Добавлен submitted-пост")
|
||||||
|
|
||||||
|
return added
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"RAGService: Ошибка добавления submitted-поста: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def find_similar_posts(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
threshold: float = 0.9,
|
||||||
|
hours: int = 24,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Ищет похожие submitted-посты за последние N часов.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Текст для поиска
|
||||||
|
threshold: Минимальный порог similarity (0.0 - 1.0)
|
||||||
|
hours: Количество часов для фильтрации
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Список dict с полями: similarity, created_at, post_id, text, rag_score
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
embedding = await self.get_embedding(text)
|
||||||
|
return self.vector_store.find_similar_submitted(
|
||||||
|
vector=embedding,
|
||||||
|
threshold=threshold,
|
||||||
|
hours=hours,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"RAGService: Ошибка поиска похожих постов: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
async def warmup(self) -> bool:
|
async def warmup(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Прогревает модель (загружает если не загружена).
|
Прогревает модель (загружает если не загружена).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True если модель загружена успешно
|
True если модель загружена успешно
|
||||||
"""
|
"""
|
||||||
@@ -407,13 +486,15 @@ class RAGService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RAGService: Ошибка прогрева модели: {e}")
|
logger.error(f"RAGService: Ошибка прогрева модели: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def save_vectors(self) -> None:
|
def save_vectors(self) -> None:
|
||||||
"""Сохраняет векторы на диск."""
|
"""Сохраняет векторы на диск (включая submitted)."""
|
||||||
if self.vector_store.storage_path:
|
if self.vector_store.storage_path:
|
||||||
self.vector_store.save_to_disk()
|
self.vector_store.save_to_disk()
|
||||||
|
if self.vector_store.submitted_path:
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
self.vector_store.save_submitted_to_disk()
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
|
||||||
"""Возвращает статистику сервиса."""
|
"""Возвращает статистику сервиса."""
|
||||||
return {
|
return {
|
||||||
"model_name": self.model_name,
|
"model_name": self.model_name,
|
||||||
@@ -424,13 +505,13 @@ class RAGService:
|
|||||||
|
|
||||||
|
|
||||||
# Глобальный экземпляр сервиса (singleton)
|
# Глобальный экземпляр сервиса (singleton)
|
||||||
_rag_service: Optional[RAGService] = None
|
_rag_service: RAGService | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_rag_service() -> RAGService:
|
def get_rag_service() -> RAGService:
|
||||||
"""
|
"""
|
||||||
Возвращает глобальный экземпляр RAG сервиса.
|
Возвращает глобальный экземпляр RAG сервиса.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
RAGService: Экземпляр сервиса
|
RAGService: Экземпляр сервиса
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,8 +9,9 @@ import hashlib
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -22,89 +23,109 @@ logger = logging.getLogger(__name__)
|
|||||||
class VectorStore:
|
class VectorStore:
|
||||||
"""
|
"""
|
||||||
In-memory хранилище векторов для RAG.
|
In-memory хранилище векторов для RAG.
|
||||||
|
|
||||||
Хранит отдельно положительные (опубликованные) и отрицательные (отклоненные)
|
Хранит отдельно положительные (опубликованные) и отрицательные (отклоненные)
|
||||||
примеры. Использует косинусное сходство для расчета скора.
|
примеры. Использует косинусное сходство для расчета скора.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
vector_dim: Размерность векторов (384 для all-MiniLM-L12-v2)
|
vector_dim: Размерность векторов (384 для all-MiniLM-L12-v2)
|
||||||
max_examples: Максимальное количество примеров каждого типа
|
max_examples: Максимальное количество примеров каждого типа
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vector_dim: int = 384,
|
vector_dim: int = 384,
|
||||||
max_examples: int = 10000,
|
max_examples: int = 10000,
|
||||||
storage_path: Optional[str] = None,
|
max_submitted: int = 5000,
|
||||||
|
storage_path: str | None = None,
|
||||||
|
submitted_path: str | None = None,
|
||||||
score_multiplier: float = 5.0,
|
score_multiplier: float = 5.0,
|
||||||
k: int = 3,
|
k: int = 3,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Инициализация хранилища.
|
Инициализация хранилища.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector_dim: Размерность векторов
|
vector_dim: Размерность векторов
|
||||||
max_examples: Максимальное количество примеров каждого типа
|
max_examples: Максимальное количество примеров каждого типа
|
||||||
|
max_submitted: Максимальное количество submitted-постов
|
||||||
storage_path: Путь для сохранения/загрузки векторов (опционально)
|
storage_path: Путь для сохранения/загрузки векторов (опционально)
|
||||||
|
submitted_path: Путь для сохранения/загрузки submitted-постов (опционально)
|
||||||
score_multiplier: Множитель для масштабирования разницы в скорах
|
score_multiplier: Множитель для масштабирования разницы в скорах
|
||||||
k: Количество ближайших примеров для расчета среднего сходства
|
k: Количество ближайших примеров для расчета среднего сходства
|
||||||
"""
|
"""
|
||||||
self.vector_dim = vector_dim
|
self.vector_dim = vector_dim
|
||||||
self.max_examples = max_examples
|
self.max_examples = max_examples
|
||||||
|
self.max_submitted = max_submitted
|
||||||
self.storage_path = storage_path
|
self.storage_path = storage_path
|
||||||
|
self.submitted_path = submitted_path
|
||||||
self.score_multiplier = score_multiplier
|
self.score_multiplier = score_multiplier
|
||||||
self.k = k
|
self.k = k
|
||||||
|
|
||||||
# Инициализируем пустые массивы
|
# Инициализируем пустые массивы
|
||||||
# Используем список для динамического добавления, потом конвертируем в numpy
|
# Используем список для динамического добавления, потом конвертируем в numpy
|
||||||
self._positive_vectors: list = []
|
self._positive_vectors: list = []
|
||||||
self._negative_vectors: list = []
|
self._negative_vectors: list = []
|
||||||
self._positive_hashes: list = [] # Хеши текстов для дедупликации
|
self._positive_hashes: list = [] # Хеши текстов для дедупликации
|
||||||
self._negative_hashes: list = []
|
self._negative_hashes: list = []
|
||||||
|
|
||||||
|
# Submitted-посты (третья коллекция)
|
||||||
|
self._submitted_vectors: list = []
|
||||||
|
self._submitted_hashes: list = []
|
||||||
|
self._submitted_created_at: list = [] # Unix timestamps
|
||||||
|
self._submitted_post_ids: list = []
|
||||||
|
self._submitted_texts: list = []
|
||||||
|
self._submitted_rag_scores: list = []
|
||||||
|
|
||||||
# Lock для потокобезопасности
|
# Lock для потокобезопасности
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
# Пытаемся загрузить сохраненные векторы
|
# Пытаемся загрузить сохраненные векторы
|
||||||
# Всегда вызываем _load_from_disk если есть storage_path - он сам решит что загружать
|
|
||||||
if storage_path:
|
if storage_path:
|
||||||
self._load_from_disk()
|
self._load_from_disk()
|
||||||
|
if submitted_path:
|
||||||
|
self._load_submitted_from_disk()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def positive_count(self) -> int:
|
def positive_count(self) -> int:
|
||||||
"""Количество положительных примеров."""
|
"""Количество положительных примеров."""
|
||||||
return len(self._positive_vectors)
|
return len(self._positive_vectors)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def negative_count(self) -> int:
|
def negative_count(self) -> int:
|
||||||
"""Количество отрицательных примеров."""
|
"""Количество отрицательных примеров."""
|
||||||
return len(self._negative_vectors)
|
return len(self._negative_vectors)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_count(self) -> int:
|
def total_count(self) -> int:
|
||||||
"""Общее количество примеров."""
|
"""Общее количество примеров."""
|
||||||
return self.positive_count + self.negative_count
|
return self.positive_count + self.negative_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def submitted_count(self) -> int:
|
||||||
|
"""Количество submitted-постов."""
|
||||||
|
return len(self._submitted_vectors)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_text_hash(text: str) -> str:
|
def compute_text_hash(text: str) -> str:
|
||||||
"""Вычисляет хеш текста для дедупликации."""
|
"""Вычисляет хеш текста для дедупликации."""
|
||||||
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
return hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
def _normalize_vector(self, vector: np.ndarray) -> np.ndarray:
|
def _normalize_vector(self, vector: np.ndarray) -> np.ndarray:
|
||||||
"""Нормализует вектор для косинусного сходства."""
|
"""Нормализует вектор для косинусного сходства."""
|
||||||
norm = np.linalg.norm(vector)
|
norm = np.linalg.norm(vector)
|
||||||
if norm == 0:
|
if norm == 0:
|
||||||
return vector
|
return vector
|
||||||
return vector / norm
|
return vector / norm
|
||||||
|
|
||||||
def add_positive(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
|
def add_positive(self, vector: np.ndarray, text_hash: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Добавляет положительный пример (опубликованный пост).
|
Добавляет положительный пример (опубликованный пост).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector: Векторное представление текста
|
vector: Векторное представление текста
|
||||||
text_hash: Хеш текста для дедупликации (опционально)
|
text_hash: Хеш текста для дедупликации (опционально)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True если добавлен, False если дубликат или превышен лимит
|
True если добавлен, False если дубликат или превышен лимит
|
||||||
"""
|
"""
|
||||||
@@ -113,71 +134,73 @@ class VectorStore:
|
|||||||
if text_hash and text_hash in self._positive_hashes:
|
if text_hash and text_hash in self._positive_hashes:
|
||||||
logger.debug("VectorStore: Пропуск дубликата положительного примера")
|
logger.debug("VectorStore: Пропуск дубликата положительного примера")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Проверяем лимит
|
# Проверяем лимит
|
||||||
if len(self._positive_vectors) >= self.max_examples:
|
if len(self._positive_vectors) >= self.max_examples:
|
||||||
# Удаляем самый старый пример (FIFO)
|
# Удаляем самый старый пример (FIFO)
|
||||||
self._positive_vectors.pop(0)
|
self._positive_vectors.pop(0)
|
||||||
self._positive_hashes.pop(0)
|
self._positive_hashes.pop(0)
|
||||||
logger.debug("VectorStore: Удален старый положительный пример (лимит)")
|
logger.debug("VectorStore: Удален старый положительный пример (лимит)")
|
||||||
|
|
||||||
# Нормализуем и добавляем
|
# Нормализуем и добавляем
|
||||||
normalized = self._normalize_vector(vector)
|
normalized = self._normalize_vector(vector)
|
||||||
self._positive_vectors.append(normalized)
|
self._positive_vectors.append(normalized)
|
||||||
if text_hash:
|
if text_hash:
|
||||||
self._positive_hashes.append(text_hash)
|
self._positive_hashes.append(text_hash)
|
||||||
|
|
||||||
logger.info(f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})")
|
logger.info(
|
||||||
|
f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_positive_batch(
|
def add_positive_batch(
|
||||||
self,
|
self, vectors: list[np.ndarray], text_hashes: list[str] | None = None
|
||||||
vectors: List[np.ndarray],
|
|
||||||
text_hashes: Optional[List[str]] = None
|
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Добавляет батч положительных примеров.
|
Добавляет батч положительных примеров.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vectors: Список векторов
|
vectors: Список векторов
|
||||||
text_hashes: Список хешей текстов для дедупликации
|
text_hashes: Список хешей текстов для дедупликации
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Количество добавленных примеров
|
Количество добавленных примеров
|
||||||
"""
|
"""
|
||||||
if text_hashes is None:
|
if text_hashes is None:
|
||||||
text_hashes = [None] * len(vectors)
|
text_hashes = [None] * len(vectors)
|
||||||
|
|
||||||
added = 0
|
added = 0
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for vector, text_hash in zip(vectors, text_hashes):
|
for vector, text_hash in zip(vectors, text_hashes):
|
||||||
# Проверяем дубликат по хешу
|
# Проверяем дубликат по хешу
|
||||||
if text_hash and text_hash in self._positive_hashes:
|
if text_hash and text_hash in self._positive_hashes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Проверяем лимит
|
# Проверяем лимит
|
||||||
if len(self._positive_vectors) >= self.max_examples:
|
if len(self._positive_vectors) >= self.max_examples:
|
||||||
self._positive_vectors.pop(0)
|
self._positive_vectors.pop(0)
|
||||||
self._positive_hashes.pop(0)
|
self._positive_hashes.pop(0)
|
||||||
|
|
||||||
# Нормализуем и добавляем
|
# Нормализуем и добавляем
|
||||||
normalized = self._normalize_vector(vector)
|
normalized = self._normalize_vector(vector)
|
||||||
self._positive_vectors.append(normalized)
|
self._positive_vectors.append(normalized)
|
||||||
if text_hash:
|
if text_hash:
|
||||||
self._positive_hashes.append(text_hash)
|
self._positive_hashes.append(text_hash)
|
||||||
added += 1
|
added += 1
|
||||||
|
|
||||||
logger.info(f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})")
|
logger.info(
|
||||||
|
f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})"
|
||||||
|
)
|
||||||
return added
|
return added
|
||||||
|
|
||||||
def add_negative(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
|
def add_negative(self, vector: np.ndarray, text_hash: str | None = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Добавляет отрицательный пример (отклоненный пост).
|
Добавляет отрицательный пример (отклоненный пост).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector: Векторное представление текста
|
vector: Векторное представление текста
|
||||||
text_hash: Хеш текста для дедупликации (опционально)
|
text_hash: Хеш текста для дедупликации (опционально)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True если добавлен, False если дубликат или превышен лимит
|
True если добавлен, False если дубликат или превышен лимит
|
||||||
"""
|
"""
|
||||||
@@ -186,112 +209,208 @@ class VectorStore:
|
|||||||
if text_hash and text_hash in self._negative_hashes:
|
if text_hash and text_hash in self._negative_hashes:
|
||||||
logger.debug("VectorStore: Пропуск дубликата отрицательного примера")
|
logger.debug("VectorStore: Пропуск дубликата отрицательного примера")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Проверяем лимит
|
# Проверяем лимит
|
||||||
if len(self._negative_vectors) >= self.max_examples:
|
if len(self._negative_vectors) >= self.max_examples:
|
||||||
# Удаляем самый старый пример (FIFO)
|
# Удаляем самый старый пример (FIFO)
|
||||||
self._negative_vectors.pop(0)
|
self._negative_vectors.pop(0)
|
||||||
self._negative_hashes.pop(0)
|
self._negative_hashes.pop(0)
|
||||||
logger.debug("VectorStore: Удален старый отрицательный пример (лимит)")
|
logger.debug("VectorStore: Удален старый отрицательный пример (лимит)")
|
||||||
|
|
||||||
# Нормализуем и добавляем
|
# Нормализуем и добавляем
|
||||||
normalized = self._normalize_vector(vector)
|
normalized = self._normalize_vector(vector)
|
||||||
self._negative_vectors.append(normalized)
|
self._negative_vectors.append(normalized)
|
||||||
if text_hash:
|
if text_hash:
|
||||||
self._negative_hashes.append(text_hash)
|
self._negative_hashes.append(text_hash)
|
||||||
|
|
||||||
logger.info(f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})")
|
logger.info(
|
||||||
|
f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def add_negative_batch(
|
def add_negative_batch(
|
||||||
self,
|
self, vectors: list[np.ndarray], text_hashes: list[str] | None = None
|
||||||
vectors: List[np.ndarray],
|
|
||||||
text_hashes: Optional[List[str]] = None
|
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Добавляет батч отрицательных примеров.
|
Добавляет батч отрицательных примеров.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vectors: Список векторов
|
vectors: Список векторов
|
||||||
text_hashes: Список хешей текстов для дедупликации
|
text_hashes: Список хешей текстов для дедупликации
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Количество добавленных примеров
|
Количество добавленных примеров
|
||||||
"""
|
"""
|
||||||
if text_hashes is None:
|
if text_hashes is None:
|
||||||
text_hashes = [None] * len(vectors)
|
text_hashes = [None] * len(vectors)
|
||||||
|
|
||||||
added = 0
|
added = 0
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for vector, text_hash in zip(vectors, text_hashes):
|
for vector, text_hash in zip(vectors, text_hashes):
|
||||||
# Проверяем дубликат по хешу
|
# Проверяем дубликат по хешу
|
||||||
if text_hash and text_hash in self._negative_hashes:
|
if text_hash and text_hash in self._negative_hashes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Проверяем лимит
|
# Проверяем лимит
|
||||||
if len(self._negative_vectors) >= self.max_examples:
|
if len(self._negative_vectors) >= self.max_examples:
|
||||||
self._negative_vectors.pop(0)
|
self._negative_vectors.pop(0)
|
||||||
self._negative_hashes.pop(0)
|
self._negative_hashes.pop(0)
|
||||||
|
|
||||||
# Нормализуем и добавляем
|
# Нормализуем и добавляем
|
||||||
normalized = self._normalize_vector(vector)
|
normalized = self._normalize_vector(vector)
|
||||||
self._negative_vectors.append(normalized)
|
self._negative_vectors.append(normalized)
|
||||||
if text_hash:
|
if text_hash:
|
||||||
self._negative_hashes.append(text_hash)
|
self._negative_hashes.append(text_hash)
|
||||||
added += 1
|
added += 1
|
||||||
|
|
||||||
logger.info(f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})")
|
logger.info(
|
||||||
|
f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})"
|
||||||
|
)
|
||||||
return added
|
return added
|
||||||
|
|
||||||
def calculate_similarity_score(self, vector: np.ndarray) -> Tuple[float, float, float]:
|
def add_submitted(
|
||||||
|
self,
|
||||||
|
vector: np.ndarray,
|
||||||
|
text_hash: str,
|
||||||
|
created_at: int,
|
||||||
|
post_id: int | None = None,
|
||||||
|
text: str = "",
|
||||||
|
rag_score: float | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Добавляет submitted-пост в коллекцию.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector: Векторное представление текста
|
||||||
|
text_hash: Хеш текста для дедупликации
|
||||||
|
created_at: Unix timestamp создания
|
||||||
|
post_id: ID поста (опционально)
|
||||||
|
text: Текст поста
|
||||||
|
rag_score: RAG скор поста (опционально)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True если добавлен, False если дубликат
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if text_hash in self._submitted_hashes:
|
||||||
|
logger.debug("VectorStore: Пропуск дубликата submitted-поста")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if len(self._submitted_vectors) >= self.max_submitted:
|
||||||
|
self._submitted_vectors.pop(0)
|
||||||
|
self._submitted_hashes.pop(0)
|
||||||
|
self._submitted_created_at.pop(0)
|
||||||
|
self._submitted_post_ids.pop(0)
|
||||||
|
self._submitted_texts.pop(0)
|
||||||
|
self._submitted_rag_scores.pop(0)
|
||||||
|
logger.debug("VectorStore: Удален старый submitted-пост (лимит)")
|
||||||
|
|
||||||
|
normalized = self._normalize_vector(vector)
|
||||||
|
self._submitted_vectors.append(normalized)
|
||||||
|
self._submitted_hashes.append(text_hash)
|
||||||
|
self._submitted_created_at.append(created_at)
|
||||||
|
self._submitted_post_ids.append(post_id)
|
||||||
|
self._submitted_texts.append(text)
|
||||||
|
self._submitted_rag_scores.append(rag_score)
|
||||||
|
|
||||||
|
logger.info(f"VectorStore: Добавлен submitted-пост (всего: {self.submitted_count})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def find_similar_submitted(
|
||||||
|
self,
|
||||||
|
vector: np.ndarray,
|
||||||
|
threshold: float,
|
||||||
|
hours: int,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Ищет похожие submitted-посты за последние N часов.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector: Векторное представление запроса
|
||||||
|
threshold: Минимальный порог similarity (0.0 - 1.0)
|
||||||
|
hours: Количество часов для фильтрации (created_at >= now - hours*3600)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Список dict с полями: similarity, created_at, post_id, text, rag_score
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self.submitted_count == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
cutoff = now - hours * 3600
|
||||||
|
|
||||||
|
normalized = self._normalize_vector(vector)
|
||||||
|
submitted_matrix = np.array(self._submitted_vectors)
|
||||||
|
similarities = np.dot(submitted_matrix, normalized)
|
||||||
|
|
||||||
|
results: list[dict[str, Any]] = []
|
||||||
|
for i, sim in enumerate(similarities):
|
||||||
|
if float(sim) < threshold:
|
||||||
|
continue
|
||||||
|
created_at = self._submitted_created_at[i]
|
||||||
|
if created_at < cutoff:
|
||||||
|
continue
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"similarity": float(sim),
|
||||||
|
"created_at": created_at,
|
||||||
|
"post_id": self._submitted_post_ids[i],
|
||||||
|
"text": self._submitted_texts[i],
|
||||||
|
"rag_score": self._submitted_rag_scores[i],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return sorted(results, key=lambda x: x["similarity"], reverse=True)
|
||||||
|
|
||||||
|
def calculate_similarity_score(self, vector: np.ndarray) -> tuple[float, float, float]:
|
||||||
"""
|
"""
|
||||||
Рассчитывает скор на основе сходства с примерами.
|
Рассчитывает скор на основе сходства с примерами.
|
||||||
|
|
||||||
Алгоритм:
|
Алгоритм:
|
||||||
1. Вычисляем косинусное сходство со всеми примерами
|
1. Вычисляем косинусное сходство со всеми примерами
|
||||||
2. Используем топ-k ближайших примеров для более чувствительной оценки
|
2. Используем топ-k ближайших примеров для более чувствительной оценки
|
||||||
3. Сравниваем топ-k положительных с топ-k отрицательными
|
3. Сравниваем топ-k положительных с топ-k отрицательными
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector: Векторное представление нового поста
|
vector: Векторное представление нового поста
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple (score, confidence, score_pos_only):
|
Tuple (score, confidence, score_pos_only):
|
||||||
- score: Оценка от 0.0 до 1.0 (neg/pos формула)
|
- score: Оценка от 0.0 до 1.0 (neg/pos формула)
|
||||||
- confidence: Уверенность (зависит от количества примеров)
|
- confidence: Уверенность (зависит от количества примеров)
|
||||||
- score_pos_only: Оценка только по положительным примерам
|
- score_pos_only: Оценка только по положительным примерам
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
InsufficientExamplesError: Если недостаточно примеров
|
InsufficientExamplesError: Если недостаточно примеров
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self.positive_count == 0:
|
if self.positive_count == 0:
|
||||||
raise InsufficientExamplesError(
|
raise InsufficientExamplesError("Нет положительных примеров для сравнения")
|
||||||
"Нет положительных примеров для сравнения"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Нормализуем входной вектор
|
# Нормализуем входной вектор
|
||||||
normalized = self._normalize_vector(vector)
|
normalized = self._normalize_vector(np.asarray(vector).flatten())
|
||||||
|
|
||||||
# Конвертируем в numpy массивы для быстрых вычислений
|
# Конвертируем в numpy массивы для быстрых вычислений
|
||||||
pos_matrix = np.array(self._positive_vectors)
|
# Используем vstack для гарантии одинаковой формы (совместимость со старым npz)
|
||||||
|
pos_matrix = np.vstack([np.asarray(v).flatten() for v in self._positive_vectors])
|
||||||
|
|
||||||
# Косинусное сходство с положительными примерами
|
# Косинусное сходство с положительными примерами
|
||||||
# Для нормализованных векторов это просто скалярное произведение
|
# Для нормализованных векторов это просто скалярное произведение
|
||||||
pos_similarities = np.dot(pos_matrix, normalized)
|
pos_similarities = np.dot(pos_matrix, normalized)
|
||||||
|
|
||||||
# Косинусное сходство с отрицательными примерами
|
# Косинусное сходство с отрицательными примерами
|
||||||
if self.negative_count > 0:
|
if self.negative_count > 0:
|
||||||
neg_matrix = np.array(self._negative_vectors)
|
neg_matrix = np.vstack([np.asarray(v).flatten() for v in self._negative_vectors])
|
||||||
neg_similarities = np.dot(neg_matrix, normalized)
|
neg_similarities = np.dot(neg_matrix, normalized)
|
||||||
else:
|
else:
|
||||||
neg_similarities = np.array([])
|
neg_similarities = np.array([])
|
||||||
|
|
||||||
# Используем топ-k ближайших примеров для расчета среднего сходства
|
# Используем топ-k ближайших примеров для расчета среднего сходства
|
||||||
k_pos = min(self.k, len(pos_similarities))
|
k_pos = min(self.k, len(pos_similarities))
|
||||||
top_k_pos = np.sort(pos_similarities)[-k_pos:]
|
top_k_pos = np.sort(pos_similarities)[-k_pos:]
|
||||||
avg_pos = float(np.mean(top_k_pos))
|
avg_pos = float(np.mean(top_k_pos))
|
||||||
|
|
||||||
# Для отрицательных: если их меньше k, берем все, иначе топ-k
|
# Для отрицательных: если их меньше k, берем все, иначе топ-k
|
||||||
if len(neg_similarities) > 0:
|
if len(neg_similarities) > 0:
|
||||||
k_neg = min(self.k, len(neg_similarities))
|
k_neg = min(self.k, len(neg_similarities))
|
||||||
@@ -300,11 +419,11 @@ class VectorStore:
|
|||||||
else:
|
else:
|
||||||
# Если нет отрицательных примеров, используем нейтральное значение
|
# Если нет отрицательных примеров, используем нейтральное значение
|
||||||
avg_neg = avg_pos # Нейтральный скор = 0.5
|
avg_neg = avg_pos # Нейтральный скор = 0.5
|
||||||
|
|
||||||
# Формула расчета score: (diff * scale + 1) / 2, переводим из [-1, 1] в [0, 1]
|
# Формула расчета score: (diff * scale + 1) / 2, переводим из [-1, 1] в [0, 1]
|
||||||
diff = avg_pos - avg_neg
|
diff = avg_pos - avg_neg
|
||||||
score_neg_pos = np.clip((diff * self.score_multiplier + 1) / 2, 0.0, 1.0)
|
score_neg_pos = np.clip((diff * self.score_multiplier + 1) / 2, 0.0, 1.0)
|
||||||
|
|
||||||
# === Вариант 2: pos only (только положительные, топ-k ближайших) ===
|
# === Вариант 2: pos only (только положительные, топ-k ближайших) ===
|
||||||
# Берём топ-5 ближайших положительных примеров
|
# Берём топ-5 ближайших положительных примеров
|
||||||
top_5_k = min(5, len(pos_similarities))
|
top_5_k = min(5, len(pos_similarities))
|
||||||
@@ -312,20 +431,20 @@ class VectorStore:
|
|||||||
# Нормализуем: 0.85 -> 0.0, 0.95 -> 1.0 (типичный диапазон для BERT)
|
# Нормализуем: 0.85 -> 0.0, 0.95 -> 1.0 (типичный диапазон для BERT)
|
||||||
score_pos_only = (top_5_sim - 0.85) / 0.10
|
score_pos_only = (top_5_sim - 0.85) / 0.10
|
||||||
score_pos_only = max(0.0, min(1.0, score_pos_only))
|
score_pos_only = max(0.0, min(1.0, score_pos_only))
|
||||||
|
|
||||||
# Основной скор — neg/pos
|
# Основной скор — neg/pos
|
||||||
score = score_neg_pos
|
score = score_neg_pos
|
||||||
|
|
||||||
# Confidence зависит от количества примеров (100% при 1000 примерах)
|
# Confidence зависит от количества примеров (100% при 1000 примерах)
|
||||||
total_examples = self.positive_count + self.negative_count
|
total_examples = self.positive_count + self.negative_count
|
||||||
confidence = min(1.0, total_examples / 1000)
|
confidence = min(1.0, total_examples / 1000)
|
||||||
|
|
||||||
# Дополнительная диагностическая информация
|
# Дополнительная диагностическая информация
|
||||||
pos_mean = float(np.mean(pos_similarities))
|
pos_mean = float(np.mean(pos_similarities))
|
||||||
pos_std = float(np.std(pos_similarities))
|
pos_std = float(np.std(pos_similarities))
|
||||||
pos_min = float(np.min(pos_similarities))
|
pos_min = float(np.min(pos_similarities))
|
||||||
pos_max = float(np.max(pos_similarities))
|
pos_max = float(np.max(pos_similarities))
|
||||||
|
|
||||||
if len(neg_similarities) > 0:
|
if len(neg_similarities) > 0:
|
||||||
neg_mean = float(np.mean(neg_similarities))
|
neg_mean = float(np.mean(neg_similarities))
|
||||||
neg_std = float(np.std(neg_similarities))
|
neg_std = float(np.std(neg_similarities))
|
||||||
@@ -333,7 +452,7 @@ class VectorStore:
|
|||||||
neg_max = float(np.max(neg_similarities))
|
neg_max = float(np.max(neg_similarities))
|
||||||
else:
|
else:
|
||||||
neg_mean = neg_std = neg_min = neg_max = 0.0
|
neg_mean = neg_std = neg_min = neg_max = 0.0
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"VectorStore: k={self.k}, k_pos={k_pos}, k_neg={k_neg if len(neg_similarities) > 0 else 0}, "
|
f"VectorStore: k={self.k}, k_pos={k_pos}, k_neg={k_neg if len(neg_similarities) > 0 else 0}, "
|
||||||
f"avg_pos={avg_pos:.4f}, avg_neg={avg_neg:.4f}, "
|
f"avg_pos={avg_pos:.4f}, avg_neg={avg_neg:.4f}, "
|
||||||
@@ -342,58 +461,145 @@ class VectorStore:
|
|||||||
f"pos_mean={pos_mean:.4f}±{pos_std:.4f}[{pos_min:.4f}-{pos_max:.4f}], "
|
f"pos_mean={pos_mean:.4f}±{pos_std:.4f}[{pos_min:.4f}-{pos_max:.4f}], "
|
||||||
f"neg_mean={neg_mean:.4f}±{neg_std:.4f}[{neg_min:.4f}-{neg_max:.4f}]"
|
f"neg_mean={neg_mean:.4f}±{neg_std:.4f}[{neg_min:.4f}-{neg_max:.4f}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
return score, confidence, score_pos_only
|
return score, confidence, score_pos_only
|
||||||
|
|
||||||
def save_to_disk(self, path: Optional[str] = None) -> None:
|
def save_to_disk(self, path: str | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Сохраняет векторы на диск.
|
Сохраняет векторы на диск.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: Путь для сохранения (если не указан, используется storage_path)
|
path: Путь для сохранения (если не указан, используется storage_path)
|
||||||
"""
|
"""
|
||||||
save_path = path or self.storage_path
|
save_path = path or self.storage_path
|
||||||
if not save_path:
|
if not save_path:
|
||||||
raise VectorStoreError("Путь для сохранения не указан")
|
raise VectorStoreError("Путь для сохранения не указан")
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Создаем директорию если нужно
|
# Создаем директорию если нужно
|
||||||
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Сохраняем в npz формате
|
# Сохраняем в npz формате
|
||||||
np.savez_compressed(
|
np.savez_compressed(
|
||||||
save_path,
|
save_path,
|
||||||
positive_vectors=np.array(self._positive_vectors) if self._positive_vectors else np.array([]),
|
positive_vectors=np.array(self._positive_vectors)
|
||||||
negative_vectors=np.array(self._negative_vectors) if self._negative_vectors else np.array([]),
|
if self._positive_vectors
|
||||||
|
else np.array([]),
|
||||||
|
negative_vectors=np.array(self._negative_vectors)
|
||||||
|
if self._negative_vectors
|
||||||
|
else np.array([]),
|
||||||
positive_hashes=np.array(self._positive_hashes, dtype=object),
|
positive_hashes=np.array(self._positive_hashes, dtype=object),
|
||||||
negative_hashes=np.array(self._negative_hashes, dtype=object),
|
negative_hashes=np.array(self._negative_hashes, dtype=object),
|
||||||
vector_dim=self.vector_dim,
|
vector_dim=self.vector_dim,
|
||||||
max_examples=self.max_examples,
|
max_examples=self.max_examples,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"VectorStore: Сохранено на диск ({self.positive_count} pos, "
|
f"VectorStore: Сохранено на диск ({self.positive_count} pos, "
|
||||||
f"{self.negative_count} neg): {save_path}"
|
f"{self.negative_count} neg): {save_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def save_submitted_to_disk(self, path: str | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Сохраняет submitted-коллекцию на диск.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Путь для сохранения (если не указан, используется submitted_path)
|
||||||
|
"""
|
||||||
|
save_path = path or self.submitted_path
|
||||||
|
if not save_path:
|
||||||
|
raise VectorStoreError("Путь для сохранения submitted не указан")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
np.savez_compressed(
|
||||||
|
save_path,
|
||||||
|
vectors=np.array(self._submitted_vectors)
|
||||||
|
if self._submitted_vectors
|
||||||
|
else np.array([]),
|
||||||
|
hashes=np.array(self._submitted_hashes, dtype=object),
|
||||||
|
created_at=np.array(self._submitted_created_at)
|
||||||
|
if self._submitted_created_at
|
||||||
|
else np.array([]),
|
||||||
|
post_ids=np.array(self._submitted_post_ids, dtype=object),
|
||||||
|
texts=np.array(self._submitted_texts, dtype=object),
|
||||||
|
rag_scores=np.array(self._submitted_rag_scores, dtype=object),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"VectorStore: Сохранено submitted ({self.submitted_count}): {save_path}")
|
||||||
|
|
||||||
|
def _load_submitted_from_disk(self) -> None:
|
||||||
|
"""Загружает submitted-коллекцию с диска."""
|
||||||
|
if not self.submitted_path or not os.path.exists(self.submitted_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with self._lock:
|
||||||
|
data = np.load(self.submitted_path, allow_pickle=True)
|
||||||
|
|
||||||
|
vectors = data.get("vectors", np.array([]))
|
||||||
|
if vectors.size > 0:
|
||||||
|
if len(vectors.shape) == 2:
|
||||||
|
self._submitted_vectors = [
|
||||||
|
self._normalize_vector(np.array(v)) for v in vectors
|
||||||
|
]
|
||||||
|
elif len(vectors.shape) == 1:
|
||||||
|
self._submitted_vectors = [self._normalize_vector(np.array(vectors))]
|
||||||
|
else:
|
||||||
|
self._submitted_vectors = []
|
||||||
|
else:
|
||||||
|
self._submitted_vectors = []
|
||||||
|
|
||||||
|
hashes = data.get("hashes", np.array([]))
|
||||||
|
self._submitted_hashes = list(hashes) if hashes.size > 0 else []
|
||||||
|
|
||||||
|
created_at = data.get("created_at", np.array([]))
|
||||||
|
self._submitted_created_at = list(created_at) if created_at.size > 0 else []
|
||||||
|
|
||||||
|
post_ids = data.get("post_ids", np.array([]))
|
||||||
|
self._submitted_post_ids = list(post_ids) if post_ids.size > 0 else []
|
||||||
|
|
||||||
|
texts = data.get("texts", np.array([]))
|
||||||
|
self._submitted_texts = list(texts) if texts.size > 0 else []
|
||||||
|
|
||||||
|
rag_scores = data.get("rag_scores", np.array([]))
|
||||||
|
self._submitted_rag_scores = list(rag_scores) if rag_scores.size > 0 else []
|
||||||
|
|
||||||
|
# Выравниваем длины (на случай поврежденных данных)
|
||||||
|
n = len(self._submitted_vectors)
|
||||||
|
self._submitted_hashes = self._submitted_hashes[:n]
|
||||||
|
self._submitted_created_at = self._submitted_created_at[:n]
|
||||||
|
self._submitted_post_ids = self._submitted_post_ids[:n]
|
||||||
|
self._submitted_texts = self._submitted_texts[:n]
|
||||||
|
self._submitted_rag_scores = self._submitted_rag_scores[:n]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"VectorStore: Загружено submitted ({self.submitted_count}): {self.submitted_path}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"VectorStore: Ошибка загрузки submitted с диска: {e}")
|
||||||
|
|
||||||
def _load_from_disk(self) -> None:
|
def _load_from_disk(self) -> None:
|
||||||
"""Загружает векторы с диска."""
|
"""Загружает векторы с диска."""
|
||||||
if not self.storage_path:
|
if not self.storage_path:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
storage_dir = Path(self.storage_path).parent
|
storage_dir = Path(self.storage_path).parent
|
||||||
positive_npy = storage_dir / "positive_embeddings.npy"
|
positive_npy = storage_dir / "positive_embeddings.npy"
|
||||||
negative_npy = storage_dir / "negative_embeddings.npy"
|
negative_npy = storage_dir / "negative_embeddings.npy"
|
||||||
|
|
||||||
# Отладочное логирование
|
# Отладочное логирование
|
||||||
logger.info(f"VectorStore: Проверка путей - storage_dir={storage_dir}, positive_npy={positive_npy}, exists={positive_npy.exists()}, negative_npy={negative_npy}, exists={negative_npy.exists()}")
|
logger.info(
|
||||||
|
f"VectorStore: Проверка путей - storage_dir={storage_dir}, positive_npy={positive_npy}, exists={positive_npy.exists()}, negative_npy={negative_npy}, exists={negative_npy.exists()}"
|
||||||
|
)
|
||||||
|
|
||||||
# Проверяем наличие отдельных .npy файлов
|
# Проверяем наличие отдельных .npy файлов
|
||||||
if positive_npy.exists() or negative_npy.exists():
|
if positive_npy.exists() or negative_npy.exists():
|
||||||
logger.info("VectorStore: Обнаружены отдельные .npy файлы, загружаем их...")
|
logger.info("VectorStore: Обнаружены отдельные .npy файлы, загружаем их...")
|
||||||
|
|
||||||
# Загружаем положительные векторы
|
# Загружаем положительные векторы
|
||||||
if positive_npy.exists():
|
if positive_npy.exists():
|
||||||
pos_vectors = np.load(positive_npy, allow_pickle=False)
|
pos_vectors = np.load(positive_npy, allow_pickle=False)
|
||||||
@@ -406,10 +612,14 @@ class VectorStore:
|
|||||||
# Один вектор [dim]
|
# Один вектор [dim]
|
||||||
self._positive_vectors = [pos_vectors]
|
self._positive_vectors = [pos_vectors]
|
||||||
else:
|
else:
|
||||||
logger.warning(f"VectorStore: Неожиданная размерность positive_embeddings.npy: {pos_vectors.shape}")
|
logger.warning(
|
||||||
|
f"VectorStore: Неожиданная размерность positive_embeddings.npy: {pos_vectors.shape}"
|
||||||
|
)
|
||||||
self._positive_vectors = []
|
self._positive_vectors = []
|
||||||
logger.info(f"VectorStore: Загружено {len(self._positive_vectors)} положительных векторов из {positive_npy}")
|
logger.info(
|
||||||
|
f"VectorStore: Загружено {len(self._positive_vectors)} положительных векторов из {positive_npy}"
|
||||||
|
)
|
||||||
|
|
||||||
# Загружаем отрицательные векторы
|
# Загружаем отрицательные векторы
|
||||||
if negative_npy.exists():
|
if negative_npy.exists():
|
||||||
neg_vectors = np.load(negative_npy, allow_pickle=False)
|
neg_vectors = np.load(negative_npy, allow_pickle=False)
|
||||||
@@ -422,52 +632,62 @@ class VectorStore:
|
|||||||
# Один вектор [dim]
|
# Один вектор [dim]
|
||||||
self._negative_vectors = [neg_vectors]
|
self._negative_vectors = [neg_vectors]
|
||||||
else:
|
else:
|
||||||
logger.warning(f"VectorStore: Неожиданная размерность negative_embeddings.npy: {neg_vectors.shape}")
|
logger.warning(
|
||||||
|
f"VectorStore: Неожиданная размерность negative_embeddings.npy: {neg_vectors.shape}"
|
||||||
|
)
|
||||||
self._negative_vectors = []
|
self._negative_vectors = []
|
||||||
logger.info(f"VectorStore: Загружено {len(self._negative_vectors)} отрицательных векторов из {negative_npy}")
|
logger.info(
|
||||||
|
f"VectorStore: Загружено {len(self._negative_vectors)} отрицательных векторов из {negative_npy}"
|
||||||
|
)
|
||||||
|
|
||||||
# Нормализуем загруженные векторы
|
# Нормализуем загруженные векторы
|
||||||
self._positive_vectors = [self._normalize_vector(np.array(v)) for v in self._positive_vectors]
|
self._positive_vectors = [
|
||||||
self._negative_vectors = [self._normalize_vector(np.array(v)) for v in self._negative_vectors]
|
self._normalize_vector(np.array(v)) for v in self._positive_vectors
|
||||||
|
]
|
||||||
|
self._negative_vectors = [
|
||||||
|
self._normalize_vector(np.array(v)) for v in self._negative_vectors
|
||||||
|
]
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"VectorStore: Загружено с диска из .npy файлов ({self.positive_count} pos, "
|
f"VectorStore: Загружено с диска из .npy файлов ({self.positive_count} pos, "
|
||||||
f"{self.negative_count} neg)"
|
f"{self.negative_count} neg)"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Если отдельных .npy файлов нет, пытаемся загрузить из старого формата .npz
|
# Если отдельных .npy файлов нет, пытаемся загрузить из старого формата .npz
|
||||||
if os.path.exists(self.storage_path):
|
if os.path.exists(self.storage_path):
|
||||||
logger.info(f"VectorStore: Загружаем из старого формата .npz: {self.storage_path}")
|
logger.info(
|
||||||
|
f"VectorStore: Загружаем из старого формата .npz: {self.storage_path}"
|
||||||
|
)
|
||||||
data = np.load(self.storage_path, allow_pickle=True)
|
data = np.load(self.storage_path, allow_pickle=True)
|
||||||
|
|
||||||
# Загружаем векторы
|
# Загружаем векторы
|
||||||
pos_vectors = data.get('positive_vectors', np.array([]))
|
pos_vectors = data.get("positive_vectors", np.array([]))
|
||||||
neg_vectors = data.get('negative_vectors', np.array([]))
|
neg_vectors = data.get("negative_vectors", np.array([]))
|
||||||
|
|
||||||
if pos_vectors.size > 0:
|
if pos_vectors.size > 0:
|
||||||
self._positive_vectors = list(pos_vectors)
|
self._positive_vectors = list(pos_vectors)
|
||||||
if neg_vectors.size > 0:
|
if neg_vectors.size > 0:
|
||||||
self._negative_vectors = list(neg_vectors)
|
self._negative_vectors = list(neg_vectors)
|
||||||
|
|
||||||
# Загружаем хеши
|
# Загружаем хеши
|
||||||
pos_hashes = data.get('positive_hashes', np.array([]))
|
pos_hashes = data.get("positive_hashes", np.array([]))
|
||||||
neg_hashes = data.get('negative_hashes', np.array([]))
|
neg_hashes = data.get("negative_hashes", np.array([]))
|
||||||
|
|
||||||
if pos_hashes.size > 0:
|
if pos_hashes.size > 0:
|
||||||
self._positive_hashes = list(pos_hashes)
|
self._positive_hashes = list(pos_hashes)
|
||||||
if neg_hashes.size > 0:
|
if neg_hashes.size > 0:
|
||||||
self._negative_hashes = list(neg_hashes)
|
self._negative_hashes = list(neg_hashes)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"VectorStore: Загружено с диска ({self.positive_count} pos, "
|
f"VectorStore: Загружено с диска ({self.positive_count} pos, "
|
||||||
f"{self.negative_count} neg): {self.storage_path}"
|
f"{self.negative_count} neg): {self.storage_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"VectorStore: Ошибка загрузки с диска: {e}")
|
logger.error(f"VectorStore: Ошибка загрузки с диска: {e}")
|
||||||
# Продолжаем с пустым хранилищем
|
# Продолжаем с пустым хранилищем
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Очищает все векторы."""
|
"""Очищает все векторы."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@@ -475,40 +695,48 @@ class VectorStore:
|
|||||||
self._negative_vectors.clear()
|
self._negative_vectors.clear()
|
||||||
self._positive_hashes.clear()
|
self._positive_hashes.clear()
|
||||||
self._negative_hashes.clear()
|
self._negative_hashes.clear()
|
||||||
|
self._submitted_vectors.clear()
|
||||||
|
self._submitted_hashes.clear()
|
||||||
|
self._submitted_created_at.clear()
|
||||||
|
self._submitted_post_ids.clear()
|
||||||
|
self._submitted_texts.clear()
|
||||||
|
self._submitted_rag_scores.clear()
|
||||||
logger.info("VectorStore: Хранилище очищено")
|
logger.info("VectorStore: Хранилище очищено")
|
||||||
|
|
||||||
def get_stats(self) -> dict:
|
def get_stats(self) -> dict:
|
||||||
"""Возвращает статистику хранилища."""
|
"""Возвращает статистику хранилища."""
|
||||||
return {
|
return {
|
||||||
"positive_count": self.positive_count,
|
"positive_count": self.positive_count,
|
||||||
"negative_count": self.negative_count,
|
"negative_count": self.negative_count,
|
||||||
"total_count": self.total_count,
|
"total_count": self.total_count,
|
||||||
|
"submitted_count": self.submitted_count,
|
||||||
"vector_dim": self.vector_dim,
|
"vector_dim": self.vector_dim,
|
||||||
"max_examples": self.max_examples,
|
"max_examples": self.max_examples,
|
||||||
|
"max_submitted": self.max_submitted,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_scoring_params(self) -> dict:
|
def get_scoring_params(self) -> dict:
|
||||||
"""Возвращает текущие параметры формулы расчета score."""
|
"""Возвращает текущие параметры формулы расчета score."""
|
||||||
return {
|
return {
|
||||||
"score_multiplier": self.score_multiplier,
|
"score_multiplier": self.score_multiplier,
|
||||||
"k": self.k,
|
"k": self.k,
|
||||||
}
|
}
|
||||||
|
|
||||||
def update_scoring_params(
|
def update_scoring_params(
|
||||||
self,
|
self,
|
||||||
score_multiplier: Optional[float] = None,
|
score_multiplier: float | None = None,
|
||||||
k: Optional[int] = None,
|
k: int | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Обновляет параметры формулы расчета score.
|
Обновляет параметры формулы расчета score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
score_multiplier: Множитель для масштабирования разницы (должен быть > 0)
|
score_multiplier: Множитель для масштабирования разницы (должен быть > 0)
|
||||||
k: Количество ближайших примеров для расчета среднего (должно быть >= 1)
|
k: Количество ближайших примеров для расчета среднего (должно быть >= 1)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Обновленные параметры
|
dict: Обновленные параметры
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: При невалидных значениях
|
ValueError: При невалидных значениях
|
||||||
"""
|
"""
|
||||||
@@ -517,15 +745,15 @@ class VectorStore:
|
|||||||
if score_multiplier <= 0:
|
if score_multiplier <= 0:
|
||||||
raise ValueError("score_multiplier должен быть > 0")
|
raise ValueError("score_multiplier должен быть > 0")
|
||||||
self.score_multiplier = score_multiplier
|
self.score_multiplier = score_multiplier
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
if k < 1:
|
if k < 1:
|
||||||
raise ValueError("k должен быть >= 1")
|
raise ValueError("k должен быть >= 1")
|
||||||
self.k = k
|
self.k = k
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"VectorStore: Параметры формулы обновлены: "
|
f"VectorStore: Параметры формулы обновлены: "
|
||||||
f"score_multiplier={self.score_multiplier}, k={self.k}"
|
f"score_multiplier={self.score_multiplier}, k={self.k}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.get_scoring_params()
|
return self.get_scoring_params()
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ services:
|
|||||||
- RAG_CACHE_DIR=/app/data/models
|
- RAG_CACHE_DIR=/app/data/models
|
||||||
- RAG_VECTORS_PATH=/app/data/vectors/vectors.npz
|
- RAG_VECTORS_PATH=/app/data/vectors/vectors.npz
|
||||||
- RAG_MAX_EXAMPLES=${RAG_MAX_EXAMPLES:-10000}
|
- RAG_MAX_EXAMPLES=${RAG_MAX_EXAMPLES:-10000}
|
||||||
|
- RAG_MAX_SUBMITTED=${RAG_MAX_SUBMITTED:-5000}
|
||||||
|
- RAG_SUBMITTED_PATH=/app/data/vectors/submitted.npz
|
||||||
- RAG_SCORE_MULTIPLIER=${RAG_SCORE_MULTIPLIER:-5.0}
|
- RAG_SCORE_MULTIPLIER=${RAG_SCORE_MULTIPLIER:-5.0}
|
||||||
- RAG_BATCH_SIZE=${RAG_BATCH_SIZE:-16}
|
- RAG_BATCH_SIZE=${RAG_BATCH_SIZE:-16}
|
||||||
- RAG_MIN_TEXT_LENGTH=${RAG_MIN_TEXT_LENGTH:-3}
|
- RAG_MIN_TEXT_LENGTH=${RAG_MIN_TEXT_LENGTH:-3}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ RAG_CACHE_DIR=data/models
|
|||||||
# VectorStore
|
# VectorStore
|
||||||
RAG_VECTORS_PATH=data/vectors/vectors.npz
|
RAG_VECTORS_PATH=data/vectors/vectors.npz
|
||||||
RAG_MAX_EXAMPLES=10000
|
RAG_MAX_EXAMPLES=10000
|
||||||
|
RAG_MAX_SUBMITTED=5000
|
||||||
|
RAG_SUBMITTED_PATH=data/vectors/submitted.npz
|
||||||
RAG_SCORE_MULTIPLIER=5.0
|
RAG_SCORE_MULTIPLIER=5.0
|
||||||
|
|
||||||
# Батч-обработка
|
# Батч-обработка
|
||||||
|
|||||||
16
tests/conftest.py
Normal file
16
tests/conftest.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
Pytest fixtures для RAG сервиса.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def allow_no_auth():
|
||||||
|
"""Разрешает запросы без API ключа в тестах."""
|
||||||
|
os.environ["RAG_ALLOW_NO_AUTH"] = "true"
|
||||||
|
yield
|
||||||
|
if "RAG_ALLOW_NO_AUTH" in os.environ:
|
||||||
|
del os.environ["RAG_ALLOW_NO_AUTH"]
|
||||||
169
tests/test_api_similar_submitted.py
Normal file
169
tests/test_api_similar_submitted.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
Тесты API endpoints /similar и /submitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_rag_service():
|
||||||
|
"""Mock RAGService для тестов API без загрузки модели."""
|
||||||
|
service = MagicMock()
|
||||||
|
service.is_model_loaded = True
|
||||||
|
service.vector_store.submitted_count = 0
|
||||||
|
return service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(mock_rag_service):
|
||||||
|
"""TestClient с переопределённым RAG сервисом."""
|
||||||
|
from app.api.routes import get_service
|
||||||
|
|
||||||
|
def override_get_service():
|
||||||
|
return mock_rag_service
|
||||||
|
|
||||||
|
app.dependency_overrides[get_service] = override_get_service
|
||||||
|
# get_rag_service используется при создании сервиса - get_service вызывает get_rag_service
|
||||||
|
# Смотрю routes: get_service возвращает get_rag_service(). Значит override get_service достаточно.
|
||||||
|
with TestClient(app) as c:
|
||||||
|
yield c
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_similar_endpoint(client, mock_rag_service):
|
||||||
|
"""POST /api/v1/similar возвращает похожие посты."""
|
||||||
|
mock_rag_service.find_similar_posts = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{
|
||||||
|
"similarity": 0.95,
|
||||||
|
"created_at": 1706270000,
|
||||||
|
"post_id": 123,
|
||||||
|
"text": "Похожий пост",
|
||||||
|
"rag_score": 0.85,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/similar",
|
||||||
|
json={"text": "Текст для поиска", "threshold": 0.9, "hours": 24},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["similar_count"] == 1
|
||||||
|
assert len(data["similar_posts"]) == 1
|
||||||
|
assert data["similar_posts"][0]["similarity"] == 0.95
|
||||||
|
assert data["similar_posts"][0]["post_id"] == 123
|
||||||
|
assert data["similar_posts"][0]["text"] == "Похожий пост"
|
||||||
|
assert data["similar_posts"][0]["rag_score"] == 0.85
|
||||||
|
|
||||||
|
|
||||||
|
def test_similar_endpoint_empty(client, mock_rag_service):
|
||||||
|
"""POST /api/v1/similar с пустым результатом."""
|
||||||
|
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/similar",
|
||||||
|
json={"text": "Уникальный текст", "threshold": 0.99, "hours": 1},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["similar_count"] == 0
|
||||||
|
assert response.json()["similar_posts"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_similar_endpoint_default_params(client, mock_rag_service):
|
||||||
|
"""POST /api/v1/similar с дефолтными параметрами."""
|
||||||
|
mock_rag_service.find_similar_posts = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/similar",
|
||||||
|
json={"text": "Текст"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_rag_service.find_similar_posts.assert_called_once_with(
|
||||||
|
text="Текст",
|
||||||
|
threshold=0.9,
|
||||||
|
hours=24,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_submitted_endpoint_success(client, mock_rag_service):
|
||||||
|
"""POST /api/v1/submitted успешно добавляет пост."""
|
||||||
|
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
|
||||||
|
mock_rag_service.vector_store.submitted_count = 5
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/submitted",
|
||||||
|
json={"text": "Новый пост", "post_id": 42, "rag_score": 0.8},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is True
|
||||||
|
assert "добавлен" in data["message"].lower()
|
||||||
|
assert data["submitted_count"] == 5
|
||||||
|
mock_rag_service.add_submitted_post.assert_called_once_with(
|
||||||
|
text="Новый пост",
|
||||||
|
post_id=42,
|
||||||
|
rag_score=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_submitted_endpoint_duplicate(client, mock_rag_service):
|
||||||
|
"""POST /api/v1/submitted при дубликате возвращает success=False."""
|
||||||
|
mock_rag_service.add_submitted_post = AsyncMock(return_value=False)
|
||||||
|
mock_rag_service.vector_store.submitted_count = 10
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/submitted",
|
||||||
|
json={"text": "Дубликат поста"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["success"] is False
|
||||||
|
assert data["submitted_count"] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_submitted_endpoint_optional_fields(client, mock_rag_service):
|
||||||
|
"""POST /api/v1/submitted с только текстом (post_id, rag_score опциональны)."""
|
||||||
|
mock_rag_service.add_submitted_post = AsyncMock(return_value=True)
|
||||||
|
mock_rag_service.vector_store.submitted_count = 1
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/submitted",
|
||||||
|
json={"text": "Только текст"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
mock_rag_service.add_submitted_post.assert_called_once_with(
|
||||||
|
text="Только текст",
|
||||||
|
post_id=None,
|
||||||
|
rag_score=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_similar_validation_empty_text(client):
|
||||||
|
"""POST /api/v1/similar с пустым текстом возвращает 422."""
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/similar",
|
||||||
|
json={"text": "", "threshold": 0.9, "hours": 24},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
def test_submitted_validation_empty_text(client):
|
||||||
|
"""POST /api/v1/submitted с пустым текстом возвращает 422."""
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/submitted",
|
||||||
|
json={"text": ""},
|
||||||
|
)
|
||||||
|
assert response.status_code == 422
|
||||||
212
tests/test_vector_store_submitted.py
Normal file
212
tests/test_vector_store_submitted.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""
|
||||||
|
Тесты для submitted-коллекции VectorStore.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.storage.vector_store import VectorStore
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def vector_store(tmp_path):
|
||||||
|
"""VectorStore с временным путём для submitted."""
|
||||||
|
return VectorStore(
|
||||||
|
vector_dim=4,
|
||||||
|
max_examples=10,
|
||||||
|
max_submitted=5,
|
||||||
|
storage_path=None,
|
||||||
|
submitted_path=str(tmp_path / "submitted.npz"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_vector():
|
||||||
|
"""Нормализованный вектор для тестов."""
|
||||||
|
v = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
return v / np.linalg.norm(v)
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_submitted(vector_store, sample_vector):
|
||||||
|
"""Добавление submitted-поста."""
|
||||||
|
added = vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="abc123",
|
||||||
|
created_at=1000,
|
||||||
|
post_id=42,
|
||||||
|
text="Test post",
|
||||||
|
rag_score=0.85,
|
||||||
|
)
|
||||||
|
assert added is True
|
||||||
|
assert vector_store.submitted_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_submitted_duplicate(vector_store, sample_vector):
|
||||||
|
"""Дубликат по хешу не добавляется."""
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="same_hash",
|
||||||
|
created_at=1000,
|
||||||
|
text="First",
|
||||||
|
)
|
||||||
|
added = vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="same_hash",
|
||||||
|
created_at=2000,
|
||||||
|
text="Second",
|
||||||
|
)
|
||||||
|
assert added is False
|
||||||
|
assert vector_store.submitted_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_submitted_fifo(vector_store, sample_vector):
|
||||||
|
"""При превышении max_submitted удаляется самый старый (FIFO)."""
|
||||||
|
for i in range(7):
|
||||||
|
v = np.array(
|
||||||
|
[float(i + 1), 0.0, 0.0, 0.0], dtype=np.float32
|
||||||
|
) # i+1 чтобы избежать нулевого вектора
|
||||||
|
v = v / np.linalg.norm(v)
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=v,
|
||||||
|
text_hash=f"hash_{i}",
|
||||||
|
created_at=1000 + i,
|
||||||
|
post_id=i,
|
||||||
|
text=f"Post {i}",
|
||||||
|
)
|
||||||
|
assert vector_store.submitted_count == 5 # max_submitted
|
||||||
|
# Должны остаться посты 2, 3, 4, 5, 6 (удалены 0, 1)
|
||||||
|
post_ids = vector_store._submitted_post_ids
|
||||||
|
assert 0 not in post_ids
|
||||||
|
assert 1 not in post_ids
|
||||||
|
assert 2 in post_ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_similar_submitted_empty(vector_store, sample_vector):
|
||||||
|
"""Поиск в пустой коллекции возвращает пустой список."""
|
||||||
|
result = vector_store.find_similar_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
threshold=0.5,
|
||||||
|
hours=24,
|
||||||
|
)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_similar_submitted(vector_store, sample_vector):
|
||||||
|
"""Поиск похожих постов с фильтром по времени и threshold."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
# Похожий вектор
|
||||||
|
similar_v = np.array([0.99, 0.01, 0.0, 0.0], dtype=np.float32)
|
||||||
|
similar_v = similar_v / np.linalg.norm(similar_v)
|
||||||
|
# Непохожий вектор
|
||||||
|
different_v = np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)
|
||||||
|
different_v = different_v / np.linalg.norm(different_v)
|
||||||
|
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=similar_v,
|
||||||
|
text_hash="similar",
|
||||||
|
created_at=now - 3600, # 1 час назад
|
||||||
|
post_id=1,
|
||||||
|
text="Similar post",
|
||||||
|
rag_score=0.9,
|
||||||
|
)
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=different_v,
|
||||||
|
text_hash="different",
|
||||||
|
created_at=now - 3600,
|
||||||
|
post_id=2,
|
||||||
|
text="Different post",
|
||||||
|
rag_score=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = vector_store.find_similar_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
threshold=0.9,
|
||||||
|
hours=24,
|
||||||
|
)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["post_id"] == 1
|
||||||
|
assert result[0]["text"] == "Similar post"
|
||||||
|
assert result[0]["similarity"] >= 0.9
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_similar_submitted_time_filter(vector_store, sample_vector):
|
||||||
|
"""Фильтр по hours исключает старые посты."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="old",
|
||||||
|
created_at=now - 48 * 3600, # 48 часов назад
|
||||||
|
post_id=1,
|
||||||
|
text="Old post",
|
||||||
|
)
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="recent",
|
||||||
|
created_at=now - 3600, # 1 час назад
|
||||||
|
post_id=2,
|
||||||
|
text="Recent post",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = vector_store.find_similar_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
threshold=0.5,
|
||||||
|
hours=24,
|
||||||
|
)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["post_id"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_submitted_persistence(vector_store, sample_vector, tmp_path):
|
||||||
|
"""Сохранение и загрузка submitted-коллекции."""
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="persist",
|
||||||
|
created_at=12345,
|
||||||
|
post_id=999,
|
||||||
|
text="Persisted post",
|
||||||
|
rag_score=0.77,
|
||||||
|
)
|
||||||
|
vector_store.save_submitted_to_disk()
|
||||||
|
|
||||||
|
# Новый store загружает данные
|
||||||
|
store2 = VectorStore(
|
||||||
|
vector_dim=4,
|
||||||
|
max_submitted=5,
|
||||||
|
storage_path=None,
|
||||||
|
submitted_path=str(tmp_path / "submitted.npz"),
|
||||||
|
)
|
||||||
|
assert store2.submitted_count == 1
|
||||||
|
assert store2._submitted_post_ids[0] == 999
|
||||||
|
assert store2._submitted_texts[0] == "Persisted post"
|
||||||
|
assert store2._submitted_rag_scores[0] == 0.77
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_stats_includes_submitted(vector_store, sample_vector):
|
||||||
|
"""get_stats включает submitted_count и max_submitted."""
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="stat",
|
||||||
|
created_at=1000,
|
||||||
|
text="For stats",
|
||||||
|
)
|
||||||
|
stats = vector_store.get_stats()
|
||||||
|
assert "submitted_count" in stats
|
||||||
|
assert stats["submitted_count"] == 1
|
||||||
|
assert "max_submitted" in stats
|
||||||
|
assert stats["max_submitted"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_submitted(vector_store, sample_vector):
|
||||||
|
"""clear() очищает submitted-коллекцию."""
|
||||||
|
vector_store.add_submitted(
|
||||||
|
vector=sample_vector,
|
||||||
|
text_hash="clear",
|
||||||
|
created_at=1000,
|
||||||
|
text="To clear",
|
||||||
|
)
|
||||||
|
vector_store.clear()
|
||||||
|
assert vector_store.submitted_count == 0
|
||||||
Reference in New Issue
Block a user