Initial commit: RAG Service
This commit is contained in:
199
app/main.py
Normal file
199
app/main.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
FastAPI приложение RAG сервиса.
|
||||
|
||||
Сервис для векторного скоринга текстов с использованием ruBERT.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app import __version__
|
||||
from app.api.routes import router
|
||||
from app.config import get_settings
|
||||
from app.services.rag_service import RAGService, get_rag_service
|
||||
|
||||
# Настройка логирования
|
||||
def setup_logging() -> None:
|
||||
"""Настраивает логирование для приложения."""
|
||||
settings = get_settings()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, settings.log_level.upper()),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
)
|
||||
|
||||
# Уменьшаем логи от библиотек
|
||||
logging.getLogger("transformers").setLevel(logging.WARNING)
|
||||
logging.getLogger("torch").setLevel(logging.WARNING)
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Глобальная задача автосохранения
|
||||
_autosave_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
async def autosave_loop(service: RAGService, interval: int) -> None:
|
||||
"""
|
||||
Фоновая задача для периодического сохранения векторов.
|
||||
|
||||
Args:
|
||||
service: RAG сервис
|
||||
interval: Интервал сохранения в секундах
|
||||
"""
|
||||
logger.info(f"Автосохранение запущено (интервал: {interval} сек)")
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
# Сохраняем только если есть данные
|
||||
if service.vector_store.total_count > 0:
|
||||
service.save_vectors()
|
||||
logger.info(
|
||||
f"Автосохранение: сохранено {service.vector_store.positive_count} pos, "
|
||||
f"{service.vector_store.negative_count} neg"
|
||||
)
|
||||
else:
|
||||
logger.debug("Автосохранение: нет данных для сохранения")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Автосохранение остановлено")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка автосохранения: {e}")
|
||||
# Продолжаем работу даже при ошибке
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""
|
||||
Lifespan контекст для FastAPI.
|
||||
|
||||
При запуске:
|
||||
- Настраивает логирование
|
||||
- Прогревает модель (опционально)
|
||||
|
||||
При остановке:
|
||||
- Сохраняет векторы на диск
|
||||
"""
|
||||
global _autosave_task
|
||||
|
||||
setup_logging()
|
||||
logger.info(f"RAG Service v{__version__} запускается...")
|
||||
|
||||
settings = get_settings()
|
||||
logger.info(f"Настройки: model={settings.model_name}, vectors_path={settings.vectors_path}")
|
||||
|
||||
# Получаем сервис (создается singleton)
|
||||
service = get_rag_service()
|
||||
|
||||
# Запускаем автосохранение если включено
|
||||
if settings.autosave_interval > 0:
|
||||
_autosave_task = asyncio.create_task(
|
||||
autosave_loop(service, settings.autosave_interval)
|
||||
)
|
||||
logger.info(f"Автосохранение включено: каждые {settings.autosave_interval} сек")
|
||||
else:
|
||||
logger.info("Автосохранение отключено")
|
||||
|
||||
# Прогреваем модель при запуске (опционально)
|
||||
# Можно раскомментировать если нужен автопрогрев
|
||||
# logger.info("Прогрев модели при запуске...")
|
||||
# await service.warmup()
|
||||
|
||||
logger.info("RAG Service готов к работе")
|
||||
|
||||
yield
|
||||
|
||||
# Останавливаем автосохранение
|
||||
if _autosave_task and not _autosave_task.done():
|
||||
_autosave_task.cancel()
|
||||
try:
|
||||
await _autosave_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# При остановке сохраняем векторы
|
||||
logger.info("RAG Service останавливается, финальное сохранение векторов...")
|
||||
try:
|
||||
service.save_vectors()
|
||||
logger.info("Векторы сохранены")
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка сохранения векторов: {e}")
|
||||
|
||||
logger.info("RAG Service остановлен")
|
||||
|
||||
|
||||
# Создание приложения
|
||||
app = FastAPI(
|
||||
title="RAG Service",
|
||||
description="""
|
||||
Сервис векторного скоринга текстов с использованием ruBERT.
|
||||
|
||||
## Возможности
|
||||
|
||||
* **Скоринг** - оценка текстов на основе векторного сходства с примерами
|
||||
* **Примеры** - добавление положительных и отрицательных примеров
|
||||
* **Статистика** - мониторинг состояния сервиса
|
||||
* **Управление** - прогрев модели, сохранение векторов
|
||||
|
||||
## Алгоритм скоринга
|
||||
|
||||
1. Текст преобразуется в вектор через ruBERT (768 измерений)
|
||||
2. Вычисляется косинусное сходство с положительными примерами
|
||||
3. Вычисляется косинусное сходство с отрицательными примерами
|
||||
4. Финальный скор = разница между сходствами, нормализованная в [0, 1]
|
||||
""",
|
||||
version=__version__,
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
)
|
||||
|
||||
# CORS middleware (для возможных веб-клиентов)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # В продакшене ограничить
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Подключение роутов
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
|
||||
# Корневой endpoint
|
||||
@app.get("/", tags=["root"])
|
||||
async def root() -> dict:
|
||||
"""Корневой endpoint с информацией о сервисе."""
|
||||
return {
|
||||
"service": "RAG Service",
|
||||
"version": __version__,
|
||||
"docs": "/docs",
|
||||
"health": "/api/v1/health",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
settings = get_settings()
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.api_host,
|
||||
port=settings.api_port,
|
||||
reload=True,
|
||||
)
|
||||
Reference in New Issue
Block a user