354 lines
11 KiB
Python
354 lines
11 KiB
Python
"""
|
||
FastAPI endpoints for the RAG service.
|
||
"""
|
||
|
||
import logging
|
||
from typing import Annotated
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
|
||
from app import __version__
|
||
from app.api.auth import AuthDep
|
||
from app.exceptions import (
|
||
InsufficientExamplesError,
|
||
ModelNotLoadedError,
|
||
ScoringError,
|
||
TextTooShortError,
|
||
)
|
||
from app.schemas import (
|
||
ErrorResponse,
|
||
ExampleRequest,
|
||
ExampleResponse,
|
||
HealthResponse,
|
||
ScoreMetadata,
|
||
ScoreRequest,
|
||
ScoreResponse,
|
||
StatsResponse,
|
||
VectorStoreStats,
|
||
WarmupResponse,
|
||
)
|
||
from app.services.rag_service import RAGService, get_rag_service
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
# Router on which the endpoints defined in this module are registered.
router = APIRouter()
|
||
|
||
|
||
# Dependency provider that hands the RAG service to the route handlers.
def get_service() -> RAGService:
    """Return the RAG service instance produced by ``get_rag_service``."""
    rag_service = get_rag_service()
    return rag_service


# Annotated alias so handlers can declare the dependency as a plain parameter type.
RAGServiceDep = Annotated[RAGService, Depends(get_service)]
|
||
|
||
|
||
# =============================================================================
|
||
# Health Check
|
||
# =============================================================================
|
||
|
||
@router.get(
    "/health",
    response_model=HealthResponse,
    summary="Проверка здоровья сервиса",
    tags=["health"],
)
async def health_check(service: RAGServiceDep) -> HealthResponse:
    """
    Report the status of the service.

    Args:
        service: RAG service whose ``is_model_loaded`` flag is reported.

    Returns:
        HealthResponse: ``status`` is always ``"healthy"``; ``model_loaded``
        mirrors the service flag and ``version`` is the application version.
    """
    model_ready = service.is_model_loaded
    return HealthResponse(
        status="healthy",
        model_loaded=model_ready,
        version=__version__,
    )
|
||
|
||
|
||
# =============================================================================
|
||
# Scoring
|
||
# =============================================================================
|
||
|
||
@router.post(
    "/score",
    response_model=ScoreResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
        401: {"model": ErrorResponse, "description": "Не авторизован"},
        403: {"model": ErrorResponse, "description": "Доступ запрещён"},
        # The handler maps ScoringError to 500, so the schema declares it too.
        500: {"model": ErrorResponse, "description": "Внутренняя ошибка скоринга"},
        503: {"model": ErrorResponse, "description": "Сервис недоступен"},
    },
    summary="Расчет скора для текста",
    tags=["scoring"],
)
async def calculate_score(
    request: ScoreRequest,
    service: RAGServiceDep,
    _auth: AuthDep,
) -> ScoreResponse:
    """
    Calculate the score for a post text.

    Args:
        request: Request body; only ``request.text`` is used.
        service: RAG service that performs the scoring.
        _auth: Value of the ``AuthDep`` dependency; unused in the body
            (presumably declared so the auth check runs - confirm in
            ``app.api.auth``).

    Returns:
        ScoreResponse: Built from the ``rag_score``, ``rag_confidence``,
        ``rag_score_pos_only`` and ``meta`` entries of ``result.to_dict()``.

    Raises:
        HTTPException: 400 for ``TextTooShortError`` and
            ``InsufficientExamplesError``, 503 for ``ModelNotLoadedError``,
            500 for ``ScoringError``. The original exception is chained as
            the cause.
    """
    try:
        result = await service.calculate_score(request.text)
        response_dict = result.to_dict()

        return ScoreResponse(
            rag_score=response_dict["rag_score"],
            rag_confidence=response_dict["rag_confidence"],
            rag_score_pos_only=response_dict["rag_score_pos_only"],
            meta=ScoreMetadata(**response_dict["meta"]),
        )

    # The handler order is kept as in the original: it matters if these
    # exception classes are related by inheritance.
    except TextTooShortError as e:
        logger.warning("Текст слишком короткий: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"detail": str(e), "error_type": "TextTooShortError"},
        ) from e

    except InsufficientExamplesError as e:
        logger.warning("Недостаточно примеров: %s", e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"detail": str(e), "error_type": "InsufficientExamplesError"},
        ) from e

    except ModelNotLoadedError as e:
        logger.error("Модель не загружена: %s", e)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
        ) from e

    except ScoringError as e:
        logger.error("Ошибка скоринга: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"detail": str(e), "error_type": "ScoringError"},
        ) from e
|
||
|
||
|
||
# =============================================================================
|
||
# Examples
|
||
# =============================================================================
|
||
|
||
@router.post(
    "/examples/positive",
    response_model=ExampleResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Не авторизован"},
        403: {"model": ErrorResponse, "description": "Доступ запрещён"},
        503: {"model": ErrorResponse, "description": "Сервис недоступен"},
    },
    summary="Добавить положительный пример",
    tags=["examples"],
)
async def add_positive_example(
    request: ExampleRequest,
    service: RAGServiceDep,
    _auth: AuthDep,
) -> ExampleResponse:
    """
    Add a text as a positive example (a published post).

    Args:
        request: Request body; only ``request.text`` is used.
        service: RAG service that stores the example.
        _auth: Value of the ``AuthDep`` dependency; unused in the body
            (presumably declared so the auth check runs - confirm in
            ``app.api.auth``).

    Returns:
        ExampleResponse: ``success`` is the value returned by
        ``service.add_positive_example``; the counts are read from
        ``service.vector_store`` after the call.

    Raises:
        HTTPException: 503 when the service raises ``ModelNotLoadedError``;
            the original exception is chained as the cause.
    """
    try:
        added = await service.add_positive_example(request.text)

        message = (
            "Положительный пример добавлен"
            if added
            else "Пример не добавлен (дубликат или слишком короткий текст)"
        )

        return ExampleResponse(
            success=added,
            message=message,
            positive_count=service.vector_store.positive_count,
            negative_count=service.vector_store.negative_count,
        )

    except ModelNotLoadedError as e:
        logger.error("Модель не загружена: %s", e)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
        ) from e
|
||
|
||
|
||
@router.post(
    "/examples/negative",
    response_model=ExampleResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Не авторизован"},
        403: {"model": ErrorResponse, "description": "Доступ запрещён"},
        503: {"model": ErrorResponse, "description": "Сервис недоступен"},
    },
    summary="Добавить отрицательный пример",
    tags=["examples"],
)
async def add_negative_example(
    request: ExampleRequest,
    service: RAGServiceDep,
    _auth: AuthDep,
) -> ExampleResponse:
    """
    Add a text as a negative example (a rejected post).

    Args:
        request: Request body; only ``request.text`` is used.
        service: RAG service that stores the example.
        _auth: Value of the ``AuthDep`` dependency; unused in the body
            (presumably declared so the auth check runs - confirm in
            ``app.api.auth``).

    Returns:
        ExampleResponse: ``success`` is the value returned by
        ``service.add_negative_example``; the counts are read from
        ``service.vector_store`` after the call.

    Raises:
        HTTPException: 503 when the service raises ``ModelNotLoadedError``;
            the original exception is chained as the cause.
    """
    try:
        added = await service.add_negative_example(request.text)

        message = (
            "Отрицательный пример добавлен"
            if added
            else "Пример не добавлен (дубликат или слишком короткий текст)"
        )

        return ExampleResponse(
            success=added,
            message=message,
            positive_count=service.vector_store.positive_count,
            negative_count=service.vector_store.negative_count,
        )

    except ModelNotLoadedError as e:
        logger.error("Модель не загружена: %s", e)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
        ) from e
|
||
|
||
|
||
# =============================================================================
|
||
# Stats, Warmup & Save
|
||
# =============================================================================
|
||
|
||
@router.get(
    "/stats",
    response_model=StatsResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Не авторизован"},
        403: {"model": ErrorResponse, "description": "Доступ запрещён"},
    },
    summary="Статистика сервиса",
    tags=["monitoring"],
)
async def get_stats(service: RAGServiceDep, _auth: AuthDep) -> StatsResponse:
    """
    Return service statistics.

    Args:
        service: RAG service queried through ``get_stats()``.
        _auth: Value of the ``AuthDep`` dependency; unused in the body.

    Returns:
        StatsResponse: Populated from the ``model_name``, ``model_loaded``,
        ``device``, ``cache_dir`` and ``vector_store`` entries of the
        mapping returned by ``service.get_stats()``.
    """
    snapshot = service.get_stats()

    # Scalar entries are copied through unchanged, in the original lookup order.
    scalar_fields = {
        key: snapshot[key]
        for key in ("model_name", "model_loaded", "device", "cache_dir")
    }
    # The nested entry is expanded into its own response model.
    vector_store_stats = VectorStoreStats(**snapshot["vector_store"])

    return StatsResponse(**scalar_fields, vector_store=vector_store_stats)
|
||
|
||
|
||
@router.post(
    "/warmup",
    response_model=WarmupResponse,
    responses={
        401: {"model": ErrorResponse, "description": "Не авторизован"},
        403: {"model": ErrorResponse, "description": "Доступ запрещён"},
        503: {"model": ErrorResponse, "description": "Не удалось загрузить модель"},
    },
    summary="Прогрев модели",
    tags=["management"],
)
async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
    """
    Warm up the model by awaiting ``service.warmup()``.

    Args:
        service: RAG service whose ``warmup()`` coroutine is awaited.
        _auth: Value of the ``AuthDep`` dependency; unused in the body.

    Returns:
        WarmupResponse: Returned only when ``service.warmup()`` gave a truthy
        result; carries that result, the current ``is_model_loaded`` flag and
        a success message.

    Raises:
        HTTPException: 503 when ``service.warmup()`` returns a falsy result.
    """
    loaded_ok = await service.warmup()

    # Guard clause: a failed warmup is reported as 503 and never reaches the response.
    if not loaded_ok:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail={
                "detail": "Не удалось загрузить модель",
                "error_type": "ModelNotLoadedError",
            },
        )

    return WarmupResponse(
        success=loaded_ok,
        model_loaded=service.is_model_loaded,
        message="Модель успешно загружена",
    )
|
||
|
||
|
||
@router.post(
    "/save",
    response_model=dict,
    responses={
        401: {"model": ErrorResponse, "description": "Не авторизован"},
        403: {"model": ErrorResponse, "description": "Доступ запрещён"},
        # The handler maps any failure to 500, so the schema declares it too.
        500: {"model": ErrorResponse, "description": "Ошибка сохранения векторов"},
    },
    summary="Сохранить векторы на диск",
    tags=["management"],
)
async def save_vectors(service: RAGServiceDep, _auth: AuthDep) -> dict:
    """
    Save the vectors to disk via ``service.save_vectors()``.

    Args:
        service: RAG service whose ``save_vectors()`` method is called.
        _auth: Value of the ``AuthDep`` dependency; unused in the body.

    Returns:
        dict: ``success`` (always ``True``), ``message``, and the
        ``positive_count`` / ``negative_count`` read from
        ``service.vector_store`` after saving.

    Raises:
        HTTPException: 500 when saving or reading the counts raises any
            ``Exception``; the original exception is chained as the cause.
    """
    try:
        # NOTE(review): this call is synchronous inside an async handler; if it
        # does blocking disk I/O it stalls the event loop - confirm and consider
        # offloading to a worker thread.
        service.save_vectors()
        return {
            "success": True,
            "message": "Векторы сохранены на диск",
            "positive_count": service.vector_store.positive_count,
            "negative_count": service.vector_store.negative_count,
        }
    except Exception as e:
        # This is the endpoint's catch-all boundary, so record the full
        # traceback rather than just the message.
        logger.exception("Ошибка сохранения векторов: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"detail": str(e), "error_type": "VectorStoreError"},
        ) from e
|