Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): ajout des différents endpoints #54

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@ POSTGRES_PASSWORD=postgrespassword
POSTGRES_DATABASE_NAME=postgres
POSTGRES_DATABASE_URL=localhost
ALBERT_API_KEY=xxx
HF_API_TOKEN=token
API_PORT=8000
API_HOST=localhost
API_KEY=abc
CHATGPT_API_KEY=abc
MISTRAL_API_KEY=abc
SRDT_ALLOW_ORIGINS=https://*.social.gouv.fr,http://localhost:3000
26 changes: 18 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
# Assistant virtuel SRDT

## Installation et lancement
## Installation

```sh
make install
poetry shell # to activate the virtual environment
pre-commit run --all-files
poetry run start # or poetry run python -m srdt_analysis
ruff check --fix
ruff format
pyright # for type checking
```

## Statistiques sur les documents
## Commands

```sh
poetry run ingest # for launching the ingestion of data
poetry run api # for launching the API
```

## Lint, format and type checking

```sh
poetry run ruff check --fix # for checking and fixing
poetry run ruff format # for formatting
poetry run pyright # for type checking
poetry run pre-commit run --all-files # for running all the checks
```

## Stats

| Type de document | Nombre |
| -------------------- | ------ |
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ langchain-text-splitters = "^0.3.2"
detect-secrets = "^1.5.0"
pre-commit = "^4.0.1"
lxml = "^5.3.0"
fastapi = "^0.115.6"
uvicorn = "^0.34.0"
pydantic = "^2.10.4"
transformers = "^4.47.1"

[tool.poetry.group.dev.dependencies]
pyright = "^1.1.389"
Expand All @@ -25,7 +29,8 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
start = "srdt_analysis.__main__:main"
ingest = "srdt_analysis.scripts.ingest:start"
api = "srdt_analysis.api.launcher:start"

[tool.ruff]
exclude = [
Expand Down
104 changes: 104 additions & 0 deletions request.http
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
### Base API Test
GET http://localhost:8000/api/v1/
Authorization: Bearer abc

### Anonymize Endpoint Tests
POST http://localhost:8000/api/v1/anonymize
Authorization: Bearer abc
content-type: application/json

{
"user_question": "Je m'appelle Jean-Bernard, j'habite à Paris et j'ai 25 ans. Je veux savoir quels sont mes droits en terme de congé parental."
}

### Test anonymize with specific prompt
POST http://localhost:8000/api/v1/anonymize
Authorization: Bearer abc
content-type: application/json

{
"user_question": "Je travaille chez Microsoft à Lyon et mon manager Paul refuse mes congés",
"anonymization_prompt": "Anonymise en gardant le contexte professionnel"
}

### Rephrase Endpoint Tests
POST http://localhost:8000/api/v1/rephrase
Authorization: Bearer abc
content-type: application/json

{
"question": "Je voudrais savoir quels sont mes droits concernant les congés payés",
"rephrasing_prompt": "Reformule cette question de manière professionnelle"
}

### Test rephrase with query splitting
POST http://localhost:8000/api/v1/rephrase
Authorization: Bearer abc
content-type: application/json

{
"question": "Comment calculer mes heures supplémentaires et les récupérer ?",
"rephrasing_prompt": "Reformule cette question de manière professionnelle",
"queries_splitting_prompt": "Décompose la question en sous-questions spécifiques"
}

### Search Endpoint Tests
POST http://localhost:8000/api/v1/search
Authorization: Bearer abc
content-type: application/json

{
"prompts": ["congés payés calcul durée"]
}

### Test search with custom options
POST http://localhost:8000/api/v1/search
Authorization: Bearer abc
content-type: application/json

{
"prompts": ["droits congé parental durée"],
"options": {
"top_K": 5,
"threshold": 0.65,
"collections": ["code_du_travail", "conventions_collectives"]
}
}

### Generate Endpoint Tests
POST http://localhost:8000/api/v1/generate
Authorization: Bearer abc
content-type: application/json

{
"chat_history": [
{
"role": "user",
"content": "Quels sont mes droits concernant le congé parental ?"
}
],
"system_prompt": "Tu es un assistant juridique spécialisé en droit du travail. Réponds de manière concise et précise."
}

### Test generate with conversation history
POST http://localhost:8000/api/v1/generate
Authorization: Bearer abc
content-type: application/json

{
"chat_history": [
{
"role": "user",
"content": "Quel est le délai de préavis pour une démission ?"
},
{
"role": "assistant",
"content": "Le délai de préavis dépend de votre convention collective et de votre statut."
},
{
"role": "user",
"content": "Je suis cadre dans le secteur informatique."
}
],
"system_prompt": "Tu es un assistant juridique spécialisé en droit du travail. Réponds de manière précise."
}
22 changes: 0 additions & 22 deletions srdt_analysis/albert.py

This file was deleted.

19 changes: 19 additions & 0 deletions srdt_analysis/api/launcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os

import uvicorn
from dotenv import load_dotenv

load_dotenv()


def start():
uvicorn.run(
"srdt_analysis.api.main:app",
host=os.getenv("API_HOST", "localhost"),
port=int(os.getenv("API_PORT", 8000)),
reload=True,
)


if __name__ == "__main__":
start()
181 changes: 181 additions & 0 deletions srdt_analysis/api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import os
import secrets
import time

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import APIKeyHeader

from srdt_analysis.api.schemas import (
AnonymizeRequest,
AnonymizeResponse,
ChunkMetadata,
ChunkResult,
GenerateRequest,
GenerateResponse,
RephraseRequest,
RephraseResponse,
SearchRequest,
SearchResponse,
)
from srdt_analysis.collections import AlbertCollectionHandler
from srdt_analysis.constants import ALBERT_ENDPOINT, ALBERT_MODEL, BASE_API_URL
from srdt_analysis.llm_runner import LLMRunner
from srdt_analysis.tokenizer import Tokenizer

load_dotenv()

app = FastAPI()
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)


async def get_api_key(api_key: str = Security(api_key_header)):
if not api_key.startswith("Bearer "):
raise HTTPException(
status_code=401,
detail="Invalid authorization header format. Must start with 'Bearer '",
)
token = api_key.replace("Bearer ", "")
if token != os.getenv("API_KEY"):
raise HTTPException(status_code=401, detail="Invalid API key")
return api_key


app.add_middleware(
CORSMiddleware,
allow_origins=os.getenv("SRDT_ALLOW_ORIGINS", "").split(",")
if os.getenv("SRDT_ALLOW_ORIGINS")
else [],
allow_credentials=True,
allow_methods=["GET", "POST"],
allow_headers=["Authorization", "Content-Type"],
)


@app.get("/")
@app.get(f"{BASE_API_URL}/")
async def root(api_key: str = Depends(get_api_key)):
return {"status": "ok", "path": BASE_API_URL}


@app.post(f"{BASE_API_URL}/anonymize", response_model=AnonymizeResponse)
async def anonymize(request: AnonymizeRequest, _api_key: str = Depends(get_api_key)):
start_time = time.time()
tokenizer = Tokenizer(model=ALBERT_MODEL)
llm_runner = LLMRunner(
llm_api_token=os.getenv("ALBERT_API_KEY", ""),
llm_model=ALBERT_MODEL,
llm_url=ALBERT_ENDPOINT,
)
try:
anonymized_question = await llm_runner.anonymize(
request.user_question, request.anonymization_prompt
)
return AnonymizeResponse(
time=time.time() - start_time,
anonymized_question=anonymized_question,
nb_token_input=tokenizer.compute_nb_tokens(request.user_question),
nb_token_output=tokenizer.compute_nb_tokens(anonymized_question),
)
except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.post(f"{BASE_API_URL}/rephrase", response_model=RephraseResponse)
async def rephrase(request: RephraseRequest, api_key: str = Depends(get_api_key)):
start_time = time.time()
tokenizer = Tokenizer(model=ALBERT_MODEL)
llm_runner = LLMRunner(
llm_api_token=os.getenv("ALBERT_API_KEY", ""),
llm_model=ALBERT_MODEL,
llm_url=ALBERT_ENDPOINT,
)

try:
rephrased, queries = await llm_runner.rephrase_and_split(
request.question,
request.rephrasing_prompt,
request.queries_splitting_prompt,
)

return RephraseResponse(
time=time.time() - start_time,
rephrased_question=rephrased,
queries=queries,
nb_token_input=tokenizer.compute_nb_tokens(request.question),
nb_token_output=tokenizer.compute_nb_tokens(rephrased),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.post(f"{BASE_API_URL}/search", response_model=SearchResponse)
async def search(request: SearchRequest, api_key: str = Depends(get_api_key)):
start_time = time.time()
collections = AlbertCollectionHandler()
try:
transformed_results = []

for prompt in request.prompts:
maxgfr marked this conversation as resolved.
Show resolved Hide resolved
search_result = collections.search(
prompt=prompt,
id_collections=request.options.collections,
k=request.options.top_K,
score_threshold=request.options.threshold,
)

for item in search_result:
chunk_data = item["chunk"]
metadata = chunk_data["metadata"]

transformed_chunk = ChunkResult(
score=item["score"],
content=chunk_data["content"],
id_chunk=chunk_data["id"],
metadata=ChunkMetadata(
document_id=metadata["document_id"],
source=metadata["source"],
title=metadata["document_name"],
url=metadata["url"],
),
)
transformed_results.append(transformed_chunk)

return SearchResponse(
time=time.time() - start_time,
top_chunks=transformed_results,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.post(f"{BASE_API_URL}/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, api_key: str = Depends(get_api_key)):
start_time = time.time()
tokenizer = Tokenizer(model=ALBERT_MODEL)
llm_runner = LLMRunner(
llm_api_token=os.getenv("ALBERT_API_KEY", ""),
llm_model=ALBERT_MODEL,
llm_url=ALBERT_ENDPOINT,
)

try:
response = await llm_runner.chat_with_full_document(
chat_history=request.chat_history,
prompt=request.system_prompt,
)

chat_history_str = " ".join(
[msg.get("content", "") for msg in request.chat_history]
)

return GenerateResponse(
time=time.time() - start_time,
text=response,
nb_token_input=tokenizer.compute_nb_tokens(chat_history_str),
nb_token_output=tokenizer.compute_nb_tokens(response),
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
Loading