Initial commit: RAG Service
This commit is contained in:
139
.gitignore
vendored
Normal file
139
.gitignore
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# .python-version # Не игнорируем для фиксации версии Python
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Project specific
|
||||
data/models/
|
||||
data/vectors/*.npz
|
||||
|
||||
# Keep data directories
|
||||
!data/models/.gitkeep
|
||||
!data/vectors/.gitkeep
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.11.9
|
||||
43
Dockerfile
Normal file
43
Dockerfile
Normal file
@@ -0,0 +1,43 @@
|
||||
# RAG Service Dockerfile
|
||||
# Python 3.11.9 для совместимости с основным ботом
|
||||
|
||||
FROM python:3.11.9-slim
|
||||
|
||||
# Рабочая директория
|
||||
WORKDIR /app
|
||||
|
||||
# Системные зависимости
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Копируем зависимости
|
||||
COPY requirements.txt .
|
||||
|
||||
# Устанавливаем зависимости
|
||||
# --no-cache-dir для уменьшения размера образа
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Копируем код приложения
|
||||
COPY app/ ./app/
|
||||
|
||||
# Создаем директории для данных
|
||||
RUN mkdir -p data/models data/vectors
|
||||
|
||||
# Переменные окружения по умолчанию
|
||||
ENV RAG_MODEL=DeepPavlov/rubert-base-cased
|
||||
ENV RAG_CACHE_DIR=/app/data/models
|
||||
ENV RAG_VECTORS_PATH=/app/data/vectors/vectors.npz
|
||||
ENV RAG_API_HOST=0.0.0.0
|
||||
ENV RAG_API_PORT=8000
|
||||
ENV LOG_LEVEL=INFO
|
||||
|
||||
# Порт приложения
|
||||
EXPOSE 8000
|
||||
|
||||
# Healthcheck
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
||||
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/api/v1/health')" || exit 1
|
||||
|
||||
# Запуск приложения
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
92
README.md
Normal file
92
README.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# RAG Service
|
||||
|
||||
Сервис векторного скоринга текстов с использованием ruBERT.
|
||||
|
||||
## Возможности
|
||||
|
||||
- **Скоринг** — оценка текстов на основе векторного сходства с примерами
|
||||
- **Примеры** — добавление положительных и отрицательных примеров для обучения
|
||||
- **Персистентность** — автоматическое сохранение векторов на диск
|
||||
- **API авторизация** — защита через API ключ
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
```bash
|
||||
# Клонировать репозиторий
|
||||
git clone <repository-url>
|
||||
cd rag-service
|
||||
|
||||
# Создать .env файл
|
||||
cp env.example .env
|
||||
|
||||
# Сгенерировать API ключ
|
||||
python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
# Добавить ключ в .env (RAG_API_KEY=...)
|
||||
|
||||
# Запустить
|
||||
docker-compose up -d --build
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Метод | URL | Описание | Авторизация |
|
||||
|-------|-----|----------|-------------|
|
||||
| GET | `/api/v1/health` | Проверка здоровья | Нет |
|
||||
| POST | `/api/v1/score` | Расчет скора текста | Да |
|
||||
| POST | `/api/v1/examples/positive` | Добавить положительный пример | Да |
|
||||
| POST | `/api/v1/examples/negative` | Добавить отрицательный пример | Да |
|
||||
| GET | `/api/v1/stats` | Статистика сервиса | Да |
|
||||
| POST | `/api/v1/warmup` | Прогрев модели | Да |
|
||||
| POST | `/api/v1/save` | Сохранить векторы | Да |
|
||||
|
||||
### Авторизация
|
||||
|
||||
Передавать API ключ в заголовке `X-API-Key`:
|
||||
|
||||
```bash
|
||||
curl -H "X-API-Key: YOUR_API_KEY" http://localhost/api/v1/stats
|
||||
```
|
||||
|
||||
### Примеры запросов
|
||||
|
||||
```bash
|
||||
# Health check
|
||||
curl http://localhost/api/v1/health
|
||||
|
||||
# Расчет скора
|
||||
curl -X POST http://localhost/api/v1/score \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"text": "Текст для оценки"}'
|
||||
|
||||
# Добавить положительный пример
|
||||
curl -X POST http://localhost/api/v1/examples/positive \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"text": "Хороший пост"}'
|
||||
```
|
||||
|
||||
## Конфигурация
|
||||
|
||||
Переменные окружения (см. `env.example`):
|
||||
|
||||
| Переменная | Описание | По умолчанию |
|
||||
|------------|----------|--------------|
|
||||
| `RAG_API_KEY` | API ключ для авторизации | — |
|
||||
| `RAG_MODEL` | Модель HuggingFace | `DeepPavlov/rubert-base-cased` |
|
||||
| `RAG_MAX_EXAMPLES` | Макс. количество примеров | `10000` |
|
||||
| `RAG_AUTOSAVE_INTERVAL` | Интервал автосохранения (сек) | `600` |
|
||||
|
||||
## Swagger UI
|
||||
|
||||
Документация API доступна по адресу `/docs`.
|
||||
|
||||
## Технологии
|
||||
|
||||
- Python 3.11
|
||||
- FastAPI
|
||||
- Transformers (ruBERT)
|
||||
- NumPy
|
||||
- Docker
|
||||
5
app/__init__.py
Normal file
5
app/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
RAG Service - сервис векторного скоринга на FastAPI.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
3
app/api/__init__.py
Normal file
3
app/api/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
API слой FastAPI.
|
||||
"""
|
||||
71
app/api/auth.py
Normal file
71
app/api/auth.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Авторизация для API RAG сервиса.
|
||||
|
||||
Поддерживает авторизацию через API ключ в заголовке X-API-Key.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Схема авторизации через заголовок
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
async def verify_api_key(
|
||||
api_key: Annotated[Optional[str], Security(api_key_header)],
|
||||
settings: Annotated[Settings, Depends(get_settings)],
|
||||
) -> bool:
|
||||
"""
|
||||
Проверяет API ключ из заголовка запроса.
|
||||
|
||||
Args:
|
||||
api_key: Ключ из заголовка X-API-Key
|
||||
settings: Настройки приложения
|
||||
|
||||
Returns:
|
||||
True если авторизация успешна
|
||||
|
||||
Raises:
|
||||
HTTPException: Если ключ неверный или отсутствует
|
||||
"""
|
||||
# Если API ключ не настроен и разрешены запросы без авторизации
|
||||
if settings.api_key is None:
|
||||
if settings.allow_no_auth:
|
||||
logger.debug("Авторизация отключена (RAG_ALLOW_NO_AUTH=true)")
|
||||
return True
|
||||
else:
|
||||
logger.warning("API ключ не настроен! Установите RAG_API_KEY")
|
||||
# В продакшене без ключа сервис не должен работать
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="API ключ не настроен на сервере",
|
||||
)
|
||||
|
||||
# Проверяем ключ
|
||||
if api_key is None:
|
||||
logger.warning("Запрос без API ключа")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="API ключ не предоставлен. Используйте заголовок X-API-Key",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
|
||||
if api_key != settings.api_key:
|
||||
logger.warning("Неверный API ключ")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Неверный API ключ",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Dependency для использования в роутах
|
||||
AuthDep = Annotated[bool, Depends(verify_api_key)]
|
||||
353
app/api/routes.py
Normal file
353
app/api/routes.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
FastAPI endpoints для RAG сервиса.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from app import __version__
|
||||
from app.api.auth import AuthDep
|
||||
from app.exceptions import (
|
||||
InsufficientExamplesError,
|
||||
ModelNotLoadedError,
|
||||
ScoringError,
|
||||
TextTooShortError,
|
||||
)
|
||||
from app.schemas import (
|
||||
ErrorResponse,
|
||||
ExampleRequest,
|
||||
ExampleResponse,
|
||||
HealthResponse,
|
||||
ScoreMetadata,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
StatsResponse,
|
||||
VectorStoreStats,
|
||||
WarmupResponse,
|
||||
)
|
||||
from app.services.rag_service import RAGService, get_rag_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Dependency для получения RAG сервиса
|
||||
def get_service() -> RAGService:
|
||||
"""Возвращает экземпляр RAG сервиса."""
|
||||
return get_rag_service()
|
||||
|
||||
|
||||
RAGServiceDep = Annotated[RAGService, Depends(get_service)]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Health Check
|
||||
# =============================================================================
|
||||
|
||||
@router.get(
|
||||
"/health",
|
||||
response_model=HealthResponse,
|
||||
summary="Проверка здоровья сервиса",
|
||||
tags=["health"],
|
||||
)
|
||||
async def health_check(service: RAGServiceDep) -> HealthResponse:
|
||||
"""
|
||||
Проверяет состояние сервиса.
|
||||
|
||||
Returns:
|
||||
HealthResponse: Статус сервиса
|
||||
"""
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
model_loaded=service.is_model_loaded,
|
||||
version=__version__,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scoring
|
||||
# =============================================================================
|
||||
|
||||
@router.post(
|
||||
"/score",
|
||||
response_model=ScoreResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Ошибка в запросе"},
|
||||
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
|
||||
},
|
||||
summary="Расчет скора для текста",
|
||||
tags=["scoring"],
|
||||
)
|
||||
async def calculate_score(
|
||||
request: ScoreRequest,
|
||||
service: RAGServiceDep,
|
||||
_auth: AuthDep,
|
||||
) -> ScoreResponse:
|
||||
"""
|
||||
Рассчитывает скор для текста поста.
|
||||
|
||||
Args:
|
||||
request: Запрос с текстом
|
||||
service: RAG сервис
|
||||
|
||||
Returns:
|
||||
ScoreResponse: Результат скоринга
|
||||
|
||||
Raises:
|
||||
HTTPException: При ошибке расчета
|
||||
"""
|
||||
try:
|
||||
result = await service.calculate_score(request.text)
|
||||
response_dict = result.to_dict()
|
||||
|
||||
return ScoreResponse(
|
||||
rag_score=response_dict["rag_score"],
|
||||
rag_confidence=response_dict["rag_confidence"],
|
||||
rag_score_pos_only=response_dict["rag_score_pos_only"],
|
||||
meta=ScoreMetadata(**response_dict["meta"]),
|
||||
)
|
||||
|
||||
except TextTooShortError as e:
|
||||
logger.warning(f"Текст слишком короткий: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"detail": str(e), "error_type": "TextTooShortError"},
|
||||
)
|
||||
|
||||
except InsufficientExamplesError as e:
|
||||
logger.warning(f"Недостаточно примеров: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"detail": str(e), "error_type": "InsufficientExamplesError"},
|
||||
)
|
||||
|
||||
except ModelNotLoadedError as e:
|
||||
logger.error(f"Модель не загружена: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
|
||||
)
|
||||
|
||||
except ScoringError as e:
|
||||
logger.error(f"Ошибка скоринга: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"detail": str(e), "error_type": "ScoringError"},
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Examples
|
||||
# =============================================================================
|
||||
|
||||
@router.post(
|
||||
"/examples/positive",
|
||||
response_model=ExampleResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
|
||||
},
|
||||
summary="Добавить положительный пример",
|
||||
tags=["examples"],
|
||||
)
|
||||
async def add_positive_example(
|
||||
request: ExampleRequest,
|
||||
service: RAGServiceDep,
|
||||
_auth: AuthDep,
|
||||
) -> ExampleResponse:
|
||||
"""
|
||||
Добавляет текст как положительный пример (опубликованный пост).
|
||||
|
||||
Args:
|
||||
request: Запрос с текстом
|
||||
service: RAG сервис
|
||||
|
||||
Returns:
|
||||
ExampleResponse: Результат добавления
|
||||
"""
|
||||
try:
|
||||
added = await service.add_positive_example(request.text)
|
||||
|
||||
if added:
|
||||
message = "Положительный пример добавлен"
|
||||
else:
|
||||
message = "Пример не добавлен (дубликат или слишком короткий текст)"
|
||||
|
||||
return ExampleResponse(
|
||||
success=added,
|
||||
message=message,
|
||||
positive_count=service.vector_store.positive_count,
|
||||
negative_count=service.vector_store.negative_count,
|
||||
)
|
||||
|
||||
except ModelNotLoadedError as e:
|
||||
logger.error(f"Модель не загружена: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/examples/negative",
|
||||
response_model=ExampleResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||
503: {"model": ErrorResponse, "description": "Сервис недоступен"},
|
||||
},
|
||||
summary="Добавить отрицательный пример",
|
||||
tags=["examples"],
|
||||
)
|
||||
async def add_negative_example(
|
||||
request: ExampleRequest,
|
||||
service: RAGServiceDep,
|
||||
_auth: AuthDep,
|
||||
) -> ExampleResponse:
|
||||
"""
|
||||
Добавляет текст как отрицательный пример (отклоненный пост).
|
||||
|
||||
Args:
|
||||
request: Запрос с текстом
|
||||
service: RAG сервис
|
||||
|
||||
Returns:
|
||||
ExampleResponse: Результат добавления
|
||||
"""
|
||||
try:
|
||||
added = await service.add_negative_example(request.text)
|
||||
|
||||
if added:
|
||||
message = "Отрицательный пример добавлен"
|
||||
else:
|
||||
message = "Пример не добавлен (дубликат или слишком короткий текст)"
|
||||
|
||||
return ExampleResponse(
|
||||
success=added,
|
||||
message=message,
|
||||
positive_count=service.vector_store.positive_count,
|
||||
negative_count=service.vector_store.negative_count,
|
||||
)
|
||||
|
||||
except ModelNotLoadedError as e:
|
||||
logger.error(f"Модель не загружена: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail={"detail": str(e), "error_type": "ModelNotLoadedError"},
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Stats & Warmup
|
||||
# =============================================================================
|
||||
|
||||
@router.get(
|
||||
"/stats",
|
||||
response_model=StatsResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||
},
|
||||
summary="Статистика сервиса",
|
||||
tags=["monitoring"],
|
||||
)
|
||||
async def get_stats(service: RAGServiceDep, _auth: AuthDep) -> StatsResponse:
|
||||
"""
|
||||
Возвращает статистику сервиса.
|
||||
|
||||
Args:
|
||||
service: RAG сервис
|
||||
|
||||
Returns:
|
||||
StatsResponse: Статистика
|
||||
"""
|
||||
stats = service.get_stats()
|
||||
|
||||
return StatsResponse(
|
||||
model_name=stats["model_name"],
|
||||
model_loaded=stats["model_loaded"],
|
||||
device=stats["device"],
|
||||
cache_dir=stats["cache_dir"],
|
||||
vector_store=VectorStoreStats(**stats["vector_store"]),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/warmup",
|
||||
response_model=WarmupResponse,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||
503: {"model": ErrorResponse, "description": "Не удалось загрузить модель"},
|
||||
},
|
||||
summary="Прогрев модели",
|
||||
tags=["management"],
|
||||
)
|
||||
async def warmup(service: RAGServiceDep, _auth: AuthDep) -> WarmupResponse:
|
||||
"""
|
||||
Прогревает модель (загружает если не загружена).
|
||||
|
||||
Args:
|
||||
service: RAG сервис
|
||||
|
||||
Returns:
|
||||
WarmupResponse: Результат прогрева
|
||||
"""
|
||||
success = await service.warmup()
|
||||
|
||||
if success:
|
||||
message = "Модель успешно загружена"
|
||||
else:
|
||||
message = "Не удалось загрузить модель"
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail={"detail": message, "error_type": "ModelNotLoadedError"},
|
||||
)
|
||||
|
||||
return WarmupResponse(
|
||||
success=success,
|
||||
model_loaded=service.is_model_loaded,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/save",
|
||||
response_model=dict,
|
||||
responses={
|
||||
401: {"model": ErrorResponse, "description": "Не авторизован"},
|
||||
403: {"model": ErrorResponse, "description": "Доступ запрещён"},
|
||||
},
|
||||
summary="Сохранить векторы на диск",
|
||||
tags=["management"],
|
||||
)
|
||||
async def save_vectors(service: RAGServiceDep, _auth: AuthDep) -> dict:
|
||||
"""
|
||||
Сохраняет векторы на диск.
|
||||
|
||||
Args:
|
||||
service: RAG сервис
|
||||
|
||||
Returns:
|
||||
dict: Результат сохранения
|
||||
"""
|
||||
try:
|
||||
service.save_vectors()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Векторы сохранены на диск",
|
||||
"positive_count": service.vector_store.positive_count,
|
||||
"negative_count": service.vector_store.negative_count,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка сохранения векторов: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"detail": str(e), "error_type": "VectorStoreError"},
|
||||
)
|
||||
104
app/config.py
Normal file
104
app/config.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Конфигурация RAG сервиса через переменные окружения.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class Settings:
|
||||
"""
|
||||
Настройки RAG сервиса.
|
||||
|
||||
Все параметры загружаются из переменных окружения.
|
||||
"""
|
||||
|
||||
# Модель
|
||||
model_name: str = field(
|
||||
default_factory=lambda: os.getenv("RAG_MODEL", "DeepPavlov/rubert-base-cased")
|
||||
)
|
||||
cache_dir: str = field(
|
||||
default_factory=lambda: os.getenv("RAG_CACHE_DIR", "data/models")
|
||||
)
|
||||
|
||||
# VectorStore
|
||||
vectors_path: str = field(
|
||||
default_factory=lambda: os.getenv("RAG_VECTORS_PATH", "data/vectors/vectors.npz")
|
||||
)
|
||||
max_examples: int = field(
|
||||
default_factory=lambda: int(os.getenv("RAG_MAX_EXAMPLES", "10000"))
|
||||
)
|
||||
score_multiplier: float = field(
|
||||
default_factory=lambda: float(os.getenv("RAG_SCORE_MULTIPLIER", "5.0"))
|
||||
)
|
||||
|
||||
# Батч-обработка
|
||||
batch_size: int = field(
|
||||
default_factory=lambda: int(os.getenv("RAG_BATCH_SIZE", "16"))
|
||||
)
|
||||
|
||||
# Минимальная длина текста
|
||||
min_text_length: int = field(
|
||||
default_factory=lambda: int(os.getenv("RAG_MIN_TEXT_LENGTH", "3"))
|
||||
)
|
||||
|
||||
# API настройки
|
||||
api_host: str = field(
|
||||
default_factory=lambda: os.getenv("RAG_API_HOST", "0.0.0.0")
|
||||
)
|
||||
api_port: int = field(
|
||||
default_factory=lambda: int(os.getenv("RAG_API_PORT", "8000"))
|
||||
)
|
||||
|
||||
# Безопасность
|
||||
# API ключ для авторизации (обязателен в продакшене!)
|
||||
api_key: Optional[str] = field(
|
||||
default_factory=lambda: os.getenv("RAG_API_KEY")
|
||||
)
|
||||
# Разрешить запросы без ключа (только для разработки)
|
||||
allow_no_auth: bool = field(
|
||||
default_factory=lambda: os.getenv("RAG_ALLOW_NO_AUTH", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# Логирование
|
||||
log_level: str = field(
|
||||
default_factory=lambda: os.getenv("LOG_LEVEL", "INFO")
|
||||
)
|
||||
|
||||
# Автосохранение (интервал в секундах, 0 = отключено)
|
||||
autosave_interval: int = field(
|
||||
default_factory=lambda: int(os.getenv("RAG_AUTOSAVE_INTERVAL", "600")) # 10 минут
|
||||
)
|
||||
|
||||
# Размерность векторов (768 для ruBERT)
|
||||
vector_dim: int = 768
|
||||
|
||||
@property
|
||||
def is_auth_required(self) -> bool:
|
||||
"""Проверяет, требуется ли авторизация."""
|
||||
return self.api_key is not None and not self.allow_no_auth
|
||||
|
||||
@staticmethod
|
||||
def generate_api_key() -> str:
|
||||
"""Генерирует случайный API ключ."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
# Глобальный экземпляр настроек
|
||||
_settings: Optional[Settings] = None
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
"""
|
||||
Возвращает глобальный экземпляр настроек.
|
||||
|
||||
Returns:
|
||||
Settings: Настройки приложения
|
||||
"""
|
||||
global _settings
|
||||
if _settings is None:
|
||||
_settings = Settings()
|
||||
return _settings
|
||||
33
app/exceptions.py
Normal file
33
app/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Исключения для RAG сервиса.
|
||||
"""
|
||||
|
||||
|
||||
class RAGServiceError(Exception):
|
||||
"""Базовое исключение для ошибок RAG сервиса."""
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotLoadedError(RAGServiceError):
|
||||
"""Модель не загружена или недоступна."""
|
||||
pass
|
||||
|
||||
|
||||
class VectorStoreError(RAGServiceError):
|
||||
"""Ошибка при работе с хранилищем векторов."""
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientExamplesError(RAGServiceError):
|
||||
"""Недостаточно примеров для расчета скора."""
|
||||
pass
|
||||
|
||||
|
||||
class TextTooShortError(RAGServiceError):
|
||||
"""Текст слишком короткий для векторизации."""
|
||||
pass
|
||||
|
||||
|
||||
class ScoringError(RAGServiceError):
|
||||
"""Ошибка при расчете скора."""
|
||||
pass
|
||||
199
app/main.py
Normal file
199
app/main.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
FastAPI приложение RAG сервиса.
|
||||
|
||||
Сервис для векторного скоринга текстов с использованием ruBERT.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app import __version__
|
||||
from app.api.routes import router
|
||||
from app.config import get_settings
|
||||
from app.services.rag_service import RAGService, get_rag_service
|
||||
|
||||
# Настройка логирования
|
||||
def setup_logging() -> None:
|
||||
"""Настраивает логирование для приложения."""
|
||||
settings = get_settings()
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, settings.log_level.upper()),
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout),
|
||||
],
|
||||
)
|
||||
|
||||
# Уменьшаем логи от библиотек
|
||||
logging.getLogger("transformers").setLevel(logging.WARNING)
|
||||
logging.getLogger("torch").setLevel(logging.WARNING)
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Глобальная задача автосохранения
|
||||
_autosave_task: Optional[asyncio.Task] = None
|
||||
|
||||
|
||||
async def autosave_loop(service: RAGService, interval: int) -> None:
|
||||
"""
|
||||
Фоновая задача для периодического сохранения векторов.
|
||||
|
||||
Args:
|
||||
service: RAG сервис
|
||||
interval: Интервал сохранения в секундах
|
||||
"""
|
||||
logger.info(f"Автосохранение запущено (интервал: {interval} сек)")
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
# Сохраняем только если есть данные
|
||||
if service.vector_store.total_count > 0:
|
||||
service.save_vectors()
|
||||
logger.info(
|
||||
f"Автосохранение: сохранено {service.vector_store.positive_count} pos, "
|
||||
f"{service.vector_store.negative_count} neg"
|
||||
)
|
||||
else:
|
||||
logger.debug("Автосохранение: нет данных для сохранения")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Автосохранение остановлено")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка автосохранения: {e}")
|
||||
# Продолжаем работу даже при ошибке
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""
|
||||
Lifespan контекст для FastAPI.
|
||||
|
||||
При запуске:
|
||||
- Настраивает логирование
|
||||
- Прогревает модель (опционально)
|
||||
|
||||
При остановке:
|
||||
- Сохраняет векторы на диск
|
||||
"""
|
||||
global _autosave_task
|
||||
|
||||
setup_logging()
|
||||
logger.info(f"RAG Service v{__version__} запускается...")
|
||||
|
||||
settings = get_settings()
|
||||
logger.info(f"Настройки: model={settings.model_name}, vectors_path={settings.vectors_path}")
|
||||
|
||||
# Получаем сервис (создается singleton)
|
||||
service = get_rag_service()
|
||||
|
||||
# Запускаем автосохранение если включено
|
||||
if settings.autosave_interval > 0:
|
||||
_autosave_task = asyncio.create_task(
|
||||
autosave_loop(service, settings.autosave_interval)
|
||||
)
|
||||
logger.info(f"Автосохранение включено: каждые {settings.autosave_interval} сек")
|
||||
else:
|
||||
logger.info("Автосохранение отключено")
|
||||
|
||||
# Прогреваем модель при запуске (опционально)
|
||||
# Можно раскомментировать если нужен автопрогрев
|
||||
# logger.info("Прогрев модели при запуске...")
|
||||
# await service.warmup()
|
||||
|
||||
logger.info("RAG Service готов к работе")
|
||||
|
||||
yield
|
||||
|
||||
# Останавливаем автосохранение
|
||||
if _autosave_task and not _autosave_task.done():
|
||||
_autosave_task.cancel()
|
||||
try:
|
||||
await _autosave_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# При остановке сохраняем векторы
|
||||
logger.info("RAG Service останавливается, финальное сохранение векторов...")
|
||||
try:
|
||||
service.save_vectors()
|
||||
logger.info("Векторы сохранены")
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка сохранения векторов: {e}")
|
||||
|
||||
logger.info("RAG Service остановлен")
|
||||
|
||||
|
||||
# Создание приложения
|
||||
app = FastAPI(
|
||||
title="RAG Service",
|
||||
description="""
|
||||
Сервис векторного скоринга текстов с использованием ruBERT.
|
||||
|
||||
## Возможности
|
||||
|
||||
* **Скоринг** - оценка текстов на основе векторного сходства с примерами
|
||||
* **Примеры** - добавление положительных и отрицательных примеров
|
||||
* **Статистика** - мониторинг состояния сервиса
|
||||
* **Управление** - прогрев модели, сохранение векторов
|
||||
|
||||
## Алгоритм скоринга
|
||||
|
||||
1. Текст преобразуется в вектор через ruBERT (768 измерений)
|
||||
2. Вычисляется косинусное сходство с положительными примерами
|
||||
3. Вычисляется косинусное сходство с отрицательными примерами
|
||||
4. Финальный скор = разница между сходствами, нормализованная в [0, 1]
|
||||
""",
|
||||
version=__version__,
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
)
|
||||
|
||||
# CORS middleware (для возможных веб-клиентов)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # В продакшене ограничить
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Подключение роутов
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
|
||||
# Корневой endpoint
|
||||
@app.get("/", tags=["root"])
|
||||
async def root() -> dict:
|
||||
"""Корневой endpoint с информацией о сервисе."""
|
||||
return {
|
||||
"service": "RAG Service",
|
||||
"version": __version__,
|
||||
"docs": "/docs",
|
||||
"health": "/api/v1/health",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
settings = get_settings()
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.api_host,
|
||||
port=settings.api_port,
|
||||
reload=True,
|
||||
)
|
||||
179
app/schemas.py
Normal file
179
app/schemas.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Pydantic схемы для API RAG сервиса.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Запросы
|
||||
# =============================================================================
|
||||
|
||||
class ScoreRequest(BaseModel):
|
||||
"""Запрос на расчет скора."""
|
||||
text: str = Field(..., min_length=1, description="Текст поста для оценки")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"text": "Это пример текста поста для оценки скоринга"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ExampleRequest(BaseModel):
|
||||
"""Запрос на добавление примера."""
|
||||
text: str = Field(..., min_length=1, description="Текст примера")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"text": "Это пример опубликованного/отклоненного поста"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Ответы
|
||||
# =============================================================================
|
||||
|
||||
class ScoreMetadata(BaseModel):
|
||||
"""Метаданные результата скоринга."""
|
||||
positive_examples: int = Field(..., description="Количество положительных примеров")
|
||||
negative_examples: int = Field(..., description="Количество отрицательных примеров")
|
||||
model: str = Field(..., description="Название модели")
|
||||
timestamp: int = Field(..., description="Время расчета (unix timestamp)")
|
||||
|
||||
|
||||
class ScoreResponse(BaseModel):
|
||||
"""Ответ с результатом скоринга."""
|
||||
rag_score: float = Field(..., ge=0.0, le=1.0, description="Основной скор (neg/pos формула)")
|
||||
rag_confidence: float = Field(..., ge=0.0, le=1.0, description="Уверенность в оценке")
|
||||
rag_score_pos_only: float = Field(..., ge=0.0, le=1.0, description="Скор только по положительным примерам")
|
||||
meta: ScoreMetadata = Field(..., description="Метаданные")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"rag_score": 0.7523,
|
||||
"rag_confidence": 0.85,
|
||||
"rag_score_pos_only": 0.6891,
|
||||
"meta": {
|
||||
"positive_examples": 500,
|
||||
"negative_examples": 350,
|
||||
"model": "DeepPavlov/rubert-base-cased",
|
||||
"timestamp": 1706270000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ExampleResponse(BaseModel):
|
||||
"""Ответ на добавление примера."""
|
||||
success: bool = Field(..., description="Успешность добавления")
|
||||
message: str = Field(..., description="Сообщение о результате")
|
||||
positive_count: int = Field(..., description="Текущее количество положительных примеров")
|
||||
negative_count: int = Field(..., description="Текущее количество отрицательных примеров")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Положительный пример добавлен",
|
||||
"positive_count": 501,
|
||||
"negative_count": 350
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class VectorStoreStats(BaseModel):
|
||||
"""Статистика хранилища векторов."""
|
||||
positive_count: int = Field(..., description="Количество положительных примеров")
|
||||
negative_count: int = Field(..., description="Количество отрицательных примеров")
|
||||
total_count: int = Field(..., description="Общее количество примеров")
|
||||
vector_dim: int = Field(..., description="Размерность векторов")
|
||||
max_examples: int = Field(..., description="Максимальное количество примеров")
|
||||
storage_path: Optional[str] = Field(None, description="Путь к файлу хранилища")
|
||||
|
||||
|
||||
class StatsResponse(BaseModel):
|
||||
"""Ответ со статистикой сервиса."""
|
||||
model_name: str = Field(..., description="Название модели")
|
||||
model_loaded: bool = Field(..., description="Загружена ли модель")
|
||||
device: Optional[str] = Field(None, description="Устройство (cpu/cuda)")
|
||||
cache_dir: str = Field(..., description="Директория кеша модели")
|
||||
vector_store: VectorStoreStats = Field(..., description="Статистика хранилища векторов")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"model_name": "DeepPavlov/rubert-base-cased",
|
||||
"model_loaded": True,
|
||||
"device": "cpu",
|
||||
"cache_dir": "data/models",
|
||||
"vector_store": {
|
||||
"positive_count": 500,
|
||||
"negative_count": 350,
|
||||
"total_count": 850,
|
||||
"vector_dim": 768,
|
||||
"max_examples": 10000,
|
||||
"storage_path": "data/vectors/vectors.npz"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class WarmupResponse(BaseModel):
|
||||
"""Ответ на прогрев модели."""
|
||||
success: bool = Field(..., description="Успешность загрузки")
|
||||
model_loaded: bool = Field(..., description="Загружена ли модель")
|
||||
message: str = Field(..., description="Сообщение о результате")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"success": True,
|
||||
"model_loaded": True,
|
||||
"message": "Модель успешно загружена"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Ответ с ошибкой."""
|
||||
detail: str = Field(..., description="Описание ошибки")
|
||||
error_type: str = Field(..., description="Тип ошибки")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"detail": "Недостаточно примеров для расчета скора",
|
||||
"error_type": "InsufficientExamplesError"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Ответ проверки здоровья сервиса."""
|
||||
status: str = Field(..., description="Статус сервиса")
|
||||
model_loaded: bool = Field(..., description="Загружена ли модель")
|
||||
version: str = Field(..., description="Версия сервиса")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"status": "healthy",
|
||||
"model_loaded": True,
|
||||
"version": "0.1.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
3
app/services/__init__.py
Normal file
3
app/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Сервисы RAG: ядро логики скоринга.
|
||||
"""
|
||||
488
app/services/rag_service.py
Normal file
488
app/services/rag_service.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
RAG сервис для скоринга постов с использованием ruBERT.
|
||||
|
||||
Использует модель DeepPavlov/rubert-base-cased для создания эмбеддингов
|
||||
и сравнивает их с эталонными примерами через VectorStore.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.exceptions import (
|
||||
InsufficientExamplesError,
|
||||
ModelNotLoadedError,
|
||||
ScoringError,
|
||||
TextTooShortError,
|
||||
)
|
||||
from app.storage.vector_store import VectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoringResult:
|
||||
"""
|
||||
Результат оценки поста.
|
||||
|
||||
Attributes:
|
||||
score: Оценка от 0.0 до 1.0 (вероятность публикации)
|
||||
confidence: Уверенность в оценке
|
||||
score_pos_only: Оценка только по положительным примерам
|
||||
positive_examples: Количество положительных примеров
|
||||
negative_examples: Количество отрицательных примеров
|
||||
model: Название используемой модели
|
||||
timestamp: Время получения оценки
|
||||
"""
|
||||
score: float
|
||||
confidence: float
|
||||
score_pos_only: float
|
||||
positive_examples: int
|
||||
negative_examples: int
|
||||
model: str
|
||||
timestamp: int = field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Преобразует результат в словарь."""
|
||||
return {
|
||||
"rag_score": round(self.score, 4),
|
||||
"rag_confidence": round(self.confidence, 4),
|
||||
"rag_score_pos_only": round(self.score_pos_only, 4),
|
||||
"meta": {
|
||||
"positive_examples": self.positive_examples,
|
||||
"negative_examples": self.negative_examples,
|
||||
"model": self.model,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RAGService:
|
||||
"""
|
||||
RAG сервис для оценки постов на основе векторного сходства.
|
||||
|
||||
Использует ruBERT для создания эмбеддингов текста и сравнивает
|
||||
их с эталонными примерами (опубликованные vs отклоненные посты).
|
||||
|
||||
Attributes:
|
||||
model_name: Название модели HuggingFace
|
||||
vector_store: Хранилище векторов
|
||||
min_text_length: Минимальная длина текста для обработки
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[Settings] = None,
|
||||
vector_store: Optional[VectorStore] = None,
|
||||
):
|
||||
"""
|
||||
Инициализация RAG сервиса.
|
||||
|
||||
Args:
|
||||
settings: Настройки сервиса (берутся из get_settings() если не переданы)
|
||||
vector_store: Хранилище векторов (создается автоматически если не передано)
|
||||
"""
|
||||
self._settings = settings or get_settings()
|
||||
self.model_name = self._settings.model_name
|
||||
self.cache_dir = self._settings.cache_dir
|
||||
self.min_text_length = self._settings.min_text_length
|
||||
|
||||
# Модель и токенизатор загружаются лениво
|
||||
self._model = None
|
||||
self._tokenizer = None
|
||||
self._device = None
|
||||
self._model_loaded = False
|
||||
|
||||
# Хранилище векторов
|
||||
self.vector_store = vector_store or VectorStore(
|
||||
vector_dim=self._settings.vector_dim,
|
||||
max_examples=self._settings.max_examples,
|
||||
storage_path=self._settings.vectors_path,
|
||||
score_multiplier=self._settings.score_multiplier,
|
||||
)
|
||||
|
||||
logger.info(f"RAGService инициализирован (model={self.model_name})")
|
||||
|
||||
@property
|
||||
def is_model_loaded(self) -> bool:
|
||||
"""Проверяет, загружена ли модель."""
|
||||
return self._model_loaded
|
||||
|
||||
async def load_model(self) -> None:
|
||||
"""
|
||||
Загружает модель и токенизатор.
|
||||
|
||||
Выполняется асинхронно в отдельном потоке чтобы не блокировать event loop.
|
||||
"""
|
||||
if self._model_loaded:
|
||||
return
|
||||
|
||||
logger.info(f"RAGService: Загрузка модели {self.model_name}...")
|
||||
|
||||
try:
|
||||
# Загрузка в отдельном потоке
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._load_model_sync)
|
||||
|
||||
self._model_loaded = True
|
||||
logger.info(f"RAGService: Модель {self.model_name} успешно загружена")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка загрузки модели: {e}")
|
||||
raise ModelNotLoadedError(f"Не удалось загрузить модель {self.model_name}: {e}")
|
||||
|
||||
def _load_model_sync(self) -> None:
|
||||
"""Синхронная загрузка модели (вызывается в executor)."""
|
||||
logger.info("RAGService: Начало _load_model_sync, импорт transformers...")
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import torch
|
||||
|
||||
# Определяем устройство
|
||||
self._device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
logger.info(f"RAGService: Устройство определено: {self._device}")
|
||||
|
||||
# Загружаем токенизатор
|
||||
logger.info(f"RAGService: Загрузка токенизатора из {self.model_name}...")
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
logger.info("RAGService: Токенизатор загружен")
|
||||
|
||||
# Загружаем модель
|
||||
logger.info(f"RAGService: Загрузка модели из {self.model_name} (это может занять несколько минут)...")
|
||||
self._model = AutoModel.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
logger.info("RAGService: Модель загружена, перенос на устройство...")
|
||||
self._model.to(self._device)
|
||||
self._model.eval() # Режим инференса
|
||||
|
||||
logger.info(f"RAGService: Модель готова на устройстве: {self._device}")
|
||||
|
||||
def _get_embedding_sync(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Получает эмбеддинг текста (синхронно).
|
||||
|
||||
Использует [CLS] токен как представление всего текста.
|
||||
|
||||
Args:
|
||||
text: Текст для векторизации
|
||||
|
||||
Returns:
|
||||
Numpy массив с эмбеддингом (768 измерений для ruBERT)
|
||||
"""
|
||||
import torch
|
||||
|
||||
# Токенизация с ограничением длины
|
||||
inputs = self._tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинг
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# Используем [CLS] токен (первый токен)
|
||||
embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
return embedding.flatten()
|
||||
|
||||
def _get_embeddings_batch_sync(self, texts: List[str], batch_size: int = 16) -> List[np.ndarray]:
|
||||
"""
|
||||
Получает эмбеддинги для батча текстов (синхронно).
|
||||
|
||||
Обрабатывает тексты пачками для эффективного использования GPU/CPU.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для векторизации
|
||||
batch_size: Размер батча
|
||||
|
||||
Returns:
|
||||
Список numpy массивов с эмбеддингами
|
||||
"""
|
||||
import torch
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i + batch_size]
|
||||
|
||||
# Токенизация батча
|
||||
inputs = self._tokenizer(
|
||||
batch_texts,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=512,
|
||||
padding=True,
|
||||
)
|
||||
inputs = {k: v.to(self._device) for k, v in inputs.items()}
|
||||
|
||||
# Получаем эмбеддинги
|
||||
with torch.no_grad():
|
||||
outputs = self._model(**inputs)
|
||||
# [CLS] токен для каждого текста в батче
|
||||
batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||
|
||||
# Разбиваем на отдельные эмбеддинги
|
||||
for j in range(len(batch_texts)):
|
||||
all_embeddings.append(batch_embeddings[j])
|
||||
|
||||
if i > 0 and i % (batch_size * 10) == 0:
|
||||
logger.info(f"RAGService: Обработано {i}/{len(texts)} текстов")
|
||||
|
||||
return all_embeddings
|
||||
|
||||
async def get_embeddings_batch(self, texts: List[str], batch_size: Optional[int] = None) -> List[np.ndarray]:
|
||||
"""
|
||||
Получает эмбеддинги для батча текстов (асинхронно).
|
||||
|
||||
Args:
|
||||
texts: Список текстов для векторизации
|
||||
batch_size: Размер батча (берется из настроек если не указан)
|
||||
|
||||
Returns:
|
||||
Список numpy массивов с эмбеддингами
|
||||
"""
|
||||
if not self._model_loaded:
|
||||
await self.load_model()
|
||||
|
||||
if not self._model_loaded:
|
||||
raise ModelNotLoadedError("Модель не загружена")
|
||||
|
||||
batch_size = batch_size or self._settings.batch_size
|
||||
|
||||
# Очищаем тексты
|
||||
clean_texts = [self._clean_text(text) for text in texts]
|
||||
|
||||
# Выполняем батч-обработку в thread pool
|
||||
loop = asyncio.get_event_loop()
|
||||
embeddings = await loop.run_in_executor(
|
||||
None,
|
||||
self._get_embeddings_batch_sync,
|
||||
clean_texts,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def get_embedding(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
Получает эмбеддинг текста (асинхронно).
|
||||
|
||||
Args:
|
||||
text: Текст для векторизации
|
||||
|
||||
Returns:
|
||||
Numpy массив с эмбеддингом
|
||||
|
||||
Raises:
|
||||
ModelNotLoadedError: Если модель не загружена
|
||||
TextTooShortError: Если текст слишком короткий
|
||||
"""
|
||||
if not self._model_loaded:
|
||||
await self.load_model()
|
||||
|
||||
if not self._model_loaded:
|
||||
raise ModelNotLoadedError("Модель не загружена")
|
||||
|
||||
# Очищаем текст
|
||||
clean_text = self._clean_text(text)
|
||||
|
||||
if len(clean_text) < self.min_text_length:
|
||||
raise TextTooShortError(
|
||||
f"Текст слишком короткий (минимум {self.min_text_length} символов)"
|
||||
)
|
||||
|
||||
# Выполняем в отдельном потоке
|
||||
loop = asyncio.get_event_loop()
|
||||
embedding = await loop.run_in_executor(
|
||||
None,
|
||||
self._get_embedding_sync,
|
||||
clean_text
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""Очищает текст от лишних символов."""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# Удаляем лишние пробелы и переносы строк
|
||||
clean = " ".join(text.split())
|
||||
|
||||
# Удаляем служебные символы (например "^" для helper сообщений)
|
||||
if clean == "^":
|
||||
return ""
|
||||
|
||||
return clean.strip()
|
||||
|
||||
async def calculate_score(self, text: str) -> ScoringResult:
|
||||
"""
|
||||
Рассчитывает скор для текста поста.
|
||||
|
||||
Args:
|
||||
text: Текст поста для оценки
|
||||
|
||||
Returns:
|
||||
ScoringResult с оценкой
|
||||
|
||||
Raises:
|
||||
ScoringError: При ошибке расчета
|
||||
InsufficientExamplesError: Если недостаточно примеров
|
||||
TextTooShortError: Если текст слишком короткий
|
||||
"""
|
||||
try:
|
||||
# Получаем эмбеддинг текста
|
||||
embedding = await self.get_embedding(text)
|
||||
|
||||
# Логируем первые элементы вектора для отладки
|
||||
logger.debug(
|
||||
f"RAGService: embedding[:3]={embedding[:3].tolist()}, "
|
||||
f"text_preview='{text[:30]}'"
|
||||
)
|
||||
|
||||
# Рассчитываем скор через VectorStore
|
||||
score, confidence, score_pos_only = self.vector_store.calculate_similarity_score(embedding)
|
||||
|
||||
return ScoringResult(
|
||||
score=score,
|
||||
confidence=confidence,
|
||||
score_pos_only=score_pos_only,
|
||||
positive_examples=self.vector_store.positive_count,
|
||||
negative_examples=self.vector_store.negative_count,
|
||||
model=self.model_name,
|
||||
)
|
||||
|
||||
except (InsufficientExamplesError, TextTooShortError):
|
||||
# Пробрасываем ожидаемые исключения
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка расчета скора: {e}")
|
||||
raise ScoringError(f"Ошибка расчета скора: {e}")
|
||||
|
||||
async def add_positive_example(self, text: str) -> bool:
|
||||
"""
|
||||
Добавляет текст как положительный пример (опубликованный пост).
|
||||
|
||||
Args:
|
||||
text: Текст опубликованного поста
|
||||
|
||||
Returns:
|
||||
True если пример добавлен, False если дубликат/короткий текст
|
||||
"""
|
||||
try:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) < self.min_text_length:
|
||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||
return False
|
||||
|
||||
# Получаем эмбеддинг
|
||||
embedding = await self.get_embedding(clean_text)
|
||||
|
||||
# Вычисляем хеш для дедупликации
|
||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||
|
||||
# Добавляем в хранилище
|
||||
added = self.vector_store.add_positive(embedding, text_hash)
|
||||
|
||||
if added:
|
||||
logger.info("RAGService: Добавлен положительный пример")
|
||||
|
||||
return added
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка добавления положительного примера: {e}")
|
||||
return False
|
||||
|
||||
async def add_negative_example(self, text: str) -> bool:
|
||||
"""
|
||||
Добавляет текст как отрицательный пример (отклоненный пост).
|
||||
|
||||
Args:
|
||||
text: Текст отклоненного поста
|
||||
|
||||
Returns:
|
||||
True если пример добавлен, False если дубликат/короткий текст
|
||||
"""
|
||||
try:
|
||||
clean_text = self._clean_text(text)
|
||||
if len(clean_text) < self.min_text_length:
|
||||
logger.debug("RAGService: Текст слишком короткий для примера, пропускаем")
|
||||
return False
|
||||
|
||||
# Получаем эмбеддинг
|
||||
embedding = await self.get_embedding(clean_text)
|
||||
|
||||
# Вычисляем хеш для дедупликации
|
||||
text_hash = VectorStore.compute_text_hash(clean_text)
|
||||
|
||||
# Добавляем в хранилище
|
||||
added = self.vector_store.add_negative(embedding, text_hash)
|
||||
|
||||
if added:
|
||||
logger.info("RAGService: Добавлен отрицательный пример")
|
||||
|
||||
return added
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка добавления отрицательного примера: {e}")
|
||||
return False
|
||||
|
||||
async def warmup(self) -> bool:
|
||||
"""
|
||||
Прогревает модель (загружает если не загружена).
|
||||
|
||||
Returns:
|
||||
True если модель загружена успешно
|
||||
"""
|
||||
try:
|
||||
await self.load_model()
|
||||
return self._model_loaded
|
||||
except Exception as e:
|
||||
logger.error(f"RAGService: Ошибка прогрева модели: {e}")
|
||||
return False
|
||||
|
||||
def save_vectors(self) -> None:
|
||||
"""Сохраняет векторы на диск."""
|
||||
if self.vector_store.storage_path:
|
||||
self.vector_store.save_to_disk()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Возвращает статистику сервиса."""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"model_loaded": self._model_loaded,
|
||||
"device": self._device,
|
||||
"cache_dir": self.cache_dir,
|
||||
"vector_store": self.vector_store.get_stats(),
|
||||
}
|
||||
|
||||
|
||||
# Глобальный экземпляр сервиса (singleton)
|
||||
_rag_service: Optional[RAGService] = None
|
||||
|
||||
|
||||
def get_rag_service() -> RAGService:
|
||||
"""
|
||||
Возвращает глобальный экземпляр RAG сервиса.
|
||||
|
||||
Returns:
|
||||
RAGService: Экземпляр сервиса
|
||||
"""
|
||||
global _rag_service
|
||||
if _rag_service is None:
|
||||
_rag_service = RAGService()
|
||||
return _rag_service
|
||||
3
app/storage/__init__.py
Normal file
3
app/storage/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Хранилище векторов.
|
||||
"""
|
||||
402
app/storage/vector_store.py
Normal file
402
app/storage/vector_store.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
In-memory хранилище векторов на numpy.
|
||||
|
||||
Хранит векторные представления постов для быстрого сравнения.
|
||||
Поддерживает персистентность через сохранение/загрузку с диска.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.exceptions import InsufficientExamplesError, VectorStoreError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""
|
||||
In-memory хранилище векторов для RAG.
|
||||
|
||||
Хранит отдельно положительные (опубликованные) и отрицательные (отклоненные)
|
||||
примеры. Использует косинусное сходство для расчета скора.
|
||||
|
||||
Attributes:
|
||||
vector_dim: Размерность векторов (768 для ruBERT)
|
||||
max_examples: Максимальное количество примеров каждого типа
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_dim: int = 768,
|
||||
max_examples: int = 10000,
|
||||
storage_path: Optional[str] = None,
|
||||
score_multiplier: float = 5.0,
|
||||
):
|
||||
"""
|
||||
Инициализация хранилища.
|
||||
|
||||
Args:
|
||||
vector_dim: Размерность векторов
|
||||
max_examples: Максимальное количество примеров каждого типа
|
||||
storage_path: Путь для сохранения/загрузки векторов (опционально)
|
||||
score_multiplier: Множитель для усиления разницы в скорах
|
||||
"""
|
||||
self.vector_dim = vector_dim
|
||||
self.max_examples = max_examples
|
||||
self.storage_path = storage_path
|
||||
self.score_multiplier = score_multiplier
|
||||
|
||||
# Инициализируем пустые массивы
|
||||
# Используем список для динамического добавления, потом конвертируем в numpy
|
||||
self._positive_vectors: list = []
|
||||
self._negative_vectors: list = []
|
||||
self._positive_hashes: list = [] # Хеши текстов для дедупликации
|
||||
self._negative_hashes: list = []
|
||||
|
||||
# Lock для потокобезопасности
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Пытаемся загрузить сохраненные векторы
|
||||
if storage_path and os.path.exists(storage_path):
|
||||
self._load_from_disk()
|
||||
|
||||
@property
|
||||
def positive_count(self) -> int:
|
||||
"""Количество положительных примеров."""
|
||||
return len(self._positive_vectors)
|
||||
|
||||
@property
|
||||
def negative_count(self) -> int:
|
||||
"""Количество отрицательных примеров."""
|
||||
return len(self._negative_vectors)
|
||||
|
||||
@property
|
||||
def total_count(self) -> int:
|
||||
"""Общее количество примеров."""
|
||||
return self.positive_count + self.negative_count
|
||||
|
||||
@staticmethod
|
||||
def compute_text_hash(text: str) -> str:
|
||||
"""Вычисляет хеш текста для дедупликации."""
|
||||
return hashlib.md5(text.encode('utf-8')).hexdigest()
|
||||
|
||||
def _normalize_vector(self, vector: np.ndarray) -> np.ndarray:
|
||||
"""Нормализует вектор для косинусного сходства."""
|
||||
norm = np.linalg.norm(vector)
|
||||
if norm == 0:
|
||||
return vector
|
||||
return vector / norm
|
||||
|
||||
def add_positive(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Добавляет положительный пример (опубликованный пост).
|
||||
|
||||
Args:
|
||||
vector: Векторное представление текста
|
||||
text_hash: Хеш текста для дедупликации (опционально)
|
||||
|
||||
Returns:
|
||||
True если добавлен, False если дубликат или превышен лимит
|
||||
"""
|
||||
with self._lock:
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._positive_hashes:
|
||||
logger.debug("VectorStore: Пропуск дубликата положительного примера")
|
||||
return False
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._positive_vectors) >= self.max_examples:
|
||||
# Удаляем самый старый пример (FIFO)
|
||||
self._positive_vectors.pop(0)
|
||||
self._positive_hashes.pop(0)
|
||||
logger.debug("VectorStore: Удален старый положительный пример (лимит)")
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._positive_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._positive_hashes.append(text_hash)
|
||||
|
||||
logger.info(f"VectorStore: Добавлен положительный пример (всего: {self.positive_count})")
|
||||
return True
|
||||
|
||||
def add_positive_batch(
|
||||
self,
|
||||
vectors: List[np.ndarray],
|
||||
text_hashes: Optional[List[str]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Добавляет батч положительных примеров.
|
||||
|
||||
Args:
|
||||
vectors: Список векторов
|
||||
text_hashes: Список хешей текстов для дедупликации
|
||||
|
||||
Returns:
|
||||
Количество добавленных примеров
|
||||
"""
|
||||
if text_hashes is None:
|
||||
text_hashes = [None] * len(vectors)
|
||||
|
||||
added = 0
|
||||
with self._lock:
|
||||
for vector, text_hash in zip(vectors, text_hashes):
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._positive_hashes:
|
||||
continue
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._positive_vectors) >= self.max_examples:
|
||||
self._positive_vectors.pop(0)
|
||||
self._positive_hashes.pop(0)
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._positive_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._positive_hashes.append(text_hash)
|
||||
added += 1
|
||||
|
||||
logger.info(f"VectorStore: Добавлено {added} положительных примеров батчем (всего: {self.positive_count})")
|
||||
return added
|
||||
|
||||
def add_negative(self, vector: np.ndarray, text_hash: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Добавляет отрицательный пример (отклоненный пост).
|
||||
|
||||
Args:
|
||||
vector: Векторное представление текста
|
||||
text_hash: Хеш текста для дедупликации (опционально)
|
||||
|
||||
Returns:
|
||||
True если добавлен, False если дубликат или превышен лимит
|
||||
"""
|
||||
with self._lock:
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._negative_hashes:
|
||||
logger.debug("VectorStore: Пропуск дубликата отрицательного примера")
|
||||
return False
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._negative_vectors) >= self.max_examples:
|
||||
# Удаляем самый старый пример (FIFO)
|
||||
self._negative_vectors.pop(0)
|
||||
self._negative_hashes.pop(0)
|
||||
logger.debug("VectorStore: Удален старый отрицательный пример (лимит)")
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._negative_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._negative_hashes.append(text_hash)
|
||||
|
||||
logger.info(f"VectorStore: Добавлен отрицательный пример (всего: {self.negative_count})")
|
||||
return True
|
||||
|
||||
def add_negative_batch(
|
||||
self,
|
||||
vectors: List[np.ndarray],
|
||||
text_hashes: Optional[List[str]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Добавляет батч отрицательных примеров.
|
||||
|
||||
Args:
|
||||
vectors: Список векторов
|
||||
text_hashes: Список хешей текстов для дедупликации
|
||||
|
||||
Returns:
|
||||
Количество добавленных примеров
|
||||
"""
|
||||
if text_hashes is None:
|
||||
text_hashes = [None] * len(vectors)
|
||||
|
||||
added = 0
|
||||
with self._lock:
|
||||
for vector, text_hash in zip(vectors, text_hashes):
|
||||
# Проверяем дубликат по хешу
|
||||
if text_hash and text_hash in self._negative_hashes:
|
||||
continue
|
||||
|
||||
# Проверяем лимит
|
||||
if len(self._negative_vectors) >= self.max_examples:
|
||||
self._negative_vectors.pop(0)
|
||||
self._negative_hashes.pop(0)
|
||||
|
||||
# Нормализуем и добавляем
|
||||
normalized = self._normalize_vector(vector)
|
||||
self._negative_vectors.append(normalized)
|
||||
if text_hash:
|
||||
self._negative_hashes.append(text_hash)
|
||||
added += 1
|
||||
|
||||
logger.info(f"VectorStore: Добавлено {added} отрицательных примеров батчем (всего: {self.negative_count})")
|
||||
return added
|
||||
|
||||
def calculate_similarity_score(self, vector: np.ndarray) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Рассчитывает скор на основе сходства с примерами.
|
||||
|
||||
Алгоритм:
|
||||
1. Вычисляем среднее косинусное сходство с положительными примерами
|
||||
2. Вычисляем среднее косинусное сходство с отрицательными примерами
|
||||
3. Финальный скор = pos_sim / (pos_sim + neg_sim + eps)
|
||||
|
||||
Args:
|
||||
vector: Векторное представление нового поста
|
||||
|
||||
Returns:
|
||||
Tuple (score, confidence, score_pos_only):
|
||||
- score: Оценка от 0.0 до 1.0 (neg/pos формула)
|
||||
- confidence: Уверенность (зависит от количества примеров)
|
||||
- score_pos_only: Оценка только по положительным примерам
|
||||
|
||||
Raises:
|
||||
InsufficientExamplesError: Если недостаточно примеров
|
||||
"""
|
||||
with self._lock:
|
||||
if self.positive_count == 0:
|
||||
raise InsufficientExamplesError(
|
||||
"Нет положительных примеров для сравнения"
|
||||
)
|
||||
|
||||
# Нормализуем входной вектор
|
||||
normalized = self._normalize_vector(vector)
|
||||
|
||||
# Конвертируем в numpy массивы для быстрых вычислений
|
||||
pos_matrix = np.array(self._positive_vectors)
|
||||
|
||||
# Косинусное сходство с положительными примерами
|
||||
# Для нормализованных векторов это просто скалярное произведение
|
||||
pos_similarities = np.dot(pos_matrix, normalized)
|
||||
pos_sim = float(np.mean(pos_similarities))
|
||||
|
||||
# Косинусное сходство с отрицательными примерами
|
||||
if self.negative_count > 0:
|
||||
neg_matrix = np.array(self._negative_vectors)
|
||||
neg_similarities = np.dot(neg_matrix, normalized)
|
||||
neg_sim = float(np.mean(neg_similarities))
|
||||
else:
|
||||
# Если нет отрицательных примеров, используем нейтральное значение
|
||||
neg_sim = pos_sim # Нейтральный скор = 0.5
|
||||
|
||||
# === Вариант 1: neg/pos (разница между положительными и отрицательными) ===
|
||||
diff = pos_sim - neg_sim
|
||||
score_neg_pos = 0.5 + (diff * self.score_multiplier)
|
||||
score_neg_pos = max(0.0, min(1.0, score_neg_pos))
|
||||
|
||||
# === Вариант 2: pos only (только положительные, топ-k ближайших) ===
|
||||
# Берём топ-5 ближайших положительных примеров
|
||||
top_k = min(5, len(pos_similarities))
|
||||
top_k_sim = float(np.mean(np.sort(pos_similarities)[-top_k:]))
|
||||
# Нормализуем: 0.85 -> 0.0, 0.95 -> 1.0 (типичный диапазон для BERT)
|
||||
score_pos_only = (top_k_sim - 0.85) / 0.10
|
||||
score_pos_only = max(0.0, min(1.0, score_pos_only))
|
||||
|
||||
# Основной скор — neg/pos
|
||||
score = score_neg_pos
|
||||
|
||||
# Confidence зависит от количества примеров (100% при 1000 примерах)
|
||||
total_examples = self.positive_count + self.negative_count
|
||||
confidence = min(1.0, total_examples / 1000)
|
||||
|
||||
logger.info(
|
||||
f"VectorStore: pos_sim={pos_sim:.4f}, neg_sim={neg_sim:.4f}, "
|
||||
f"top_k_sim={top_k_sim:.4f}, score_neg_pos={score_neg_pos:.4f}, "
|
||||
f"score_pos_only={score_pos_only:.4f}"
|
||||
)
|
||||
|
||||
return score, confidence, score_pos_only
|
||||
|
||||
def save_to_disk(self, path: Optional[str] = None) -> None:
|
||||
"""
|
||||
Сохраняет векторы на диск.
|
||||
|
||||
Args:
|
||||
path: Путь для сохранения (если не указан, используется storage_path)
|
||||
"""
|
||||
save_path = path or self.storage_path
|
||||
if not save_path:
|
||||
raise VectorStoreError("Путь для сохранения не указан")
|
||||
|
||||
with self._lock:
|
||||
# Создаем директорию если нужно
|
||||
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Сохраняем в npz формате
|
||||
np.savez_compressed(
|
||||
save_path,
|
||||
positive_vectors=np.array(self._positive_vectors) if self._positive_vectors else np.array([]),
|
||||
negative_vectors=np.array(self._negative_vectors) if self._negative_vectors else np.array([]),
|
||||
positive_hashes=np.array(self._positive_hashes, dtype=object),
|
||||
negative_hashes=np.array(self._negative_hashes, dtype=object),
|
||||
vector_dim=self.vector_dim,
|
||||
max_examples=self.max_examples,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"VectorStore: Сохранено на диск ({self.positive_count} pos, "
|
||||
f"{self.negative_count} neg): {save_path}"
|
||||
)
|
||||
|
||||
def _load_from_disk(self) -> None:
|
||||
"""Загружает векторы с диска."""
|
||||
if not self.storage_path or not os.path.exists(self.storage_path):
|
||||
return
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
data = np.load(self.storage_path, allow_pickle=True)
|
||||
|
||||
# Загружаем векторы
|
||||
pos_vectors = data.get('positive_vectors', np.array([]))
|
||||
neg_vectors = data.get('negative_vectors', np.array([]))
|
||||
|
||||
if pos_vectors.size > 0:
|
||||
self._positive_vectors = list(pos_vectors)
|
||||
if neg_vectors.size > 0:
|
||||
self._negative_vectors = list(neg_vectors)
|
||||
|
||||
# Загружаем хеши
|
||||
pos_hashes = data.get('positive_hashes', np.array([]))
|
||||
neg_hashes = data.get('negative_hashes', np.array([]))
|
||||
|
||||
if pos_hashes.size > 0:
|
||||
self._positive_hashes = list(pos_hashes)
|
||||
if neg_hashes.size > 0:
|
||||
self._negative_hashes = list(neg_hashes)
|
||||
|
||||
logger.info(
|
||||
f"VectorStore: Загружено с диска ({self.positive_count} pos, "
|
||||
f"{self.negative_count} neg): {self.storage_path}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VectorStore: Ошибка загрузки с диска: {e}")
|
||||
# Продолжаем с пустым хранилищем
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Очищает все векторы."""
|
||||
with self._lock:
|
||||
self._positive_vectors.clear()
|
||||
self._negative_vectors.clear()
|
||||
self._positive_hashes.clear()
|
||||
self._negative_hashes.clear()
|
||||
logger.info("VectorStore: Хранилище очищено")
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Возвращает статистику хранилища."""
|
||||
return {
|
||||
"positive_count": self.positive_count,
|
||||
"negative_count": self.negative_count,
|
||||
"total_count": self.total_count,
|
||||
"vector_dim": self.vector_dim,
|
||||
"max_examples": self.max_examples,
|
||||
"storage_path": self.storage_path,
|
||||
}
|
||||
0
data/models/.gitkeep
Normal file
0
data/models/.gitkeep
Normal file
0
data/vectors/.gitkeep
Normal file
0
data/vectors/.gitkeep
Normal file
48
docker-compose.yml
Normal file
48
docker-compose.yml
Normal file
@@ -0,0 +1,48 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
rag-service:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: rag-service
|
||||
restart: unless-stopped
|
||||
# Порт открываем только для localhost (NGINX будет проксировать)
|
||||
# Для прямого доступа используй: "8000:8000"
|
||||
ports:
|
||||
- "127.0.0.1:8000:8000"
|
||||
volumes:
|
||||
# Персистентность данных модели и векторов
|
||||
- ./data/models:/app/data/models
|
||||
- ./data/vectors:/app/data/vectors
|
||||
environment:
|
||||
- RAG_MODEL=${RAG_MODEL:-DeepPavlov/rubert-base-cased}
|
||||
- RAG_CACHE_DIR=/app/data/models
|
||||
- RAG_VECTORS_PATH=/app/data/vectors/vectors.npz
|
||||
- RAG_MAX_EXAMPLES=${RAG_MAX_EXAMPLES:-10000}
|
||||
- RAG_SCORE_MULTIPLIER=${RAG_SCORE_MULTIPLIER:-5.0}
|
||||
- RAG_BATCH_SIZE=${RAG_BATCH_SIZE:-16}
|
||||
- RAG_MIN_TEXT_LENGTH=${RAG_MIN_TEXT_LENGTH:-3}
|
||||
- RAG_API_HOST=0.0.0.0
|
||||
- RAG_API_PORT=8000
|
||||
# Безопасность
|
||||
- RAG_API_KEY=${RAG_API_KEY}
|
||||
- RAG_ALLOW_NO_AUTH=${RAG_ALLOW_NO_AUTH:-false}
|
||||
# Автосохранение
|
||||
- RAG_AUTOSAVE_INTERVAL=${RAG_AUTOSAVE_INTERVAL:-600}
|
||||
- LOG_LEVEL=${LOG_LEVEL:-INFO}
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/api/v1/health')"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
# Ограничения ресурсов (рекомендуется для продакшена)
|
||||
# deploy:
|
||||
# resources:
|
||||
# limits:
|
||||
# cpus: '2'
|
||||
# memory: 4G
|
||||
# reservations:
|
||||
# cpus: '1'
|
||||
# memory: 2G
|
||||
33
env.example
Normal file
33
env.example
Normal file
@@ -0,0 +1,33 @@
|
||||
# RAG Service Configuration
|
||||
|
||||
# Модель
|
||||
RAG_MODEL=DeepPavlov/rubert-base-cased
|
||||
RAG_CACHE_DIR=data/models
|
||||
|
||||
# VectorStore
|
||||
RAG_VECTORS_PATH=data/vectors/vectors.npz
|
||||
RAG_MAX_EXAMPLES=10000
|
||||
RAG_SCORE_MULTIPLIER=5.0
|
||||
|
||||
# Батч-обработка
|
||||
RAG_BATCH_SIZE=16
|
||||
|
||||
# Минимальная длина текста
|
||||
RAG_MIN_TEXT_LENGTH=3
|
||||
|
||||
# API настройки
|
||||
RAG_API_HOST=0.0.0.0
|
||||
RAG_API_PORT=8000
|
||||
|
||||
# Безопасность (ОБЯЗАТЕЛЬНО для продакшена!)
|
||||
# Сгенерировать ключ: python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
RAG_API_KEY=your-super-secret-api-key-here
|
||||
|
||||
# Разрешить запросы без ключа (только для разработки, в продакшене = false)
|
||||
RAG_ALLOW_NO_AUTH=false
|
||||
|
||||
# Автосохранение векторов (секунды, 0 = отключено)
|
||||
RAG_AUTOSAVE_INTERVAL=600
|
||||
|
||||
# Логирование
|
||||
LOG_LEVEL=INFO
|
||||
44
pyproject.toml
Normal file
44
pyproject.toml
Normal file
@@ -0,0 +1,44 @@
|
||||
[project]
|
||||
name = "rag-service"
|
||||
version = "0.1.0"
|
||||
description = "RAG Service - сервис векторного скоринга на FastAPI с ruBERT"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = {text = "MIT"}
|
||||
authors = [
|
||||
{name = "Developer"}
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"fastapi>=0.109.0",
|
||||
"uvicorn[standard]>=0.27.0",
|
||||
"pydantic>=2.5.0",
|
||||
"torch>=2.1.0",
|
||||
"transformers>=4.36.0",
|
||||
"numpy>=1.24.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.4.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"httpx>=0.26.0",
|
||||
"ruff>=0.1.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py311"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "I", "N", "UP"]
|
||||
ignore = ["E501"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
# FastAPI и веб-сервер
|
||||
fastapi>=0.109.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
pydantic>=2.5.0
|
||||
|
||||
# ML / NLP
|
||||
torch>=2.1.0
|
||||
transformers>=4.36.0
|
||||
numpy>=1.24.0
|
||||
|
||||
# Утилиты
|
||||
python-dotenv>=1.0.0
|
||||
Reference in New Issue
Block a user