diff --git a/.env.sample b/.env.sample index 72f0258..1cf5d43 100644 --- a/.env.sample +++ b/.env.sample @@ -3,3 +3,10 @@ POSTGRES_PASSWORD=postgrespassword POSTGRES_DATABASE_NAME=postgres POSTGRES_DATABASE_URL=localhost ALBERT_API_KEY=xxx +HF_API_TOKEN=token +API_PORT=8000 +API_HOST=localhost +API_KEY=abc +CHATGPT_API_KEY=abc +MISTRAL_API_KEY=abc +SRDT_ALLOW_ORIGINS=https://*.social.gouv.fr,http://localhost:3000 \ No newline at end of file diff --git a/README.md b/README.md index 3013686..c4237ab 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,28 @@ # Assistant virtuel SRDT -## Installation et lancement +## Installation ```sh make install -poetry shell # to activate the virtual environment -pre-commit run --all-files -poetry run start # or poetry run python -m srdt_analysis -ruff check --fix -ruff format -pyright # for type checking ``` -## Statistiques sur les documents +## Commands + +```sh +poetry run ingest # for launching the ingestion of data +poetry run api # for launching the API +``` + +## Lint, format and type checking + +```sh +poetry run ruff check --fix # for checking and fixing +poetry run ruff format # for formatting +poetry run pyright # for type checking +poetry run pre-commit run --all-files # for running all the checks +``` + +## Stats | Type de document | Nombre | | -------------------- | ------ | diff --git a/pyproject.toml b/pyproject.toml index 124176f..186ced1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,10 @@ langchain-text-splitters = "^0.3.2" detect-secrets = "^1.5.0" pre-commit = "^4.0.1" lxml = "^5.3.0" +fastapi = "^0.115.6" +uvicorn = "^0.34.0" +pydantic = "^2.10.4" +transformers = "^4.47.1" [tool.poetry.group.dev.dependencies] pyright = "^1.1.389" @@ -25,7 +29,8 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -start = "srdt_analysis.__main__:main" +ingest = "srdt_analysis.scripts.ingest:start" +api = "srdt_analysis.api.launcher:start" [tool.ruff] exclude = [ diff --git a/request.http b/request.http new file mode 100644 index 0000000..867a8bc --- /dev/null +++ b/request.http @@ -0,0 +1,104 @@ +### Base API Test +GET http://localhost:8000/api/v1/ +Authorization: Bearer abc + +### Anonymize Endpoint Tests +POST http://localhost:8000/api/v1/anonymize +Authorization: Bearer abc +content-type: application/json + +{ + "user_question": "Je m'appelle Jean-Bernard, j'habite à Paris et j'ai 25 ans. Je veux savoir quels sont mes droits en terme de congé parental." +} + +### Test anonymize with specific prompt +POST http://localhost:8000/api/v1/anonymize +Authorization: Bearer abc +content-type: application/json + +{ + "user_question": "Je travaille chez Microsoft à Lyon et mon manager Paul refuse mes congés", + "anonymization_prompt": "Anonymise en gardant le contexte professionnel" +} + +### Rephrase Endpoint Tests +POST http://localhost:8000/api/v1/rephrase +Authorization: Bearer abc +content-type: application/json + +{ + "question": "Je voudrais savoir quels sont mes droits concernant les congés payés", + "rephrasing_prompt": "Reformule cette question de manière professionnelle" +} + +### Test rephrase with query splitting +POST http://localhost:8000/api/v1/rephrase +Authorization: Bearer abc +content-type: application/json + +{ + "question": "Comment calculer mes heures supplémentaires et les récupérer ?", + "rephrasing_prompt": "Reformule cette question de manière professionnelle", + "queries_splitting_prompt": "Décompose la question en sous-questions spécifiques" +} + +### Search Endpoint Tests +POST http://localhost:8000/api/v1/search +Authorization: Bearer abc +content-type: application/json + +{ + "prompts": ["congés payés calcul durée"] +} + +### Test search with custom options +POST http://localhost:8000/api/v1/search +Authorization: Bearer abc +content-type: application/json + +{ + "prompts": ["droits congé parental durée"], + "options": { + "top_K": 5, + "threshold": 0.65, + "collections": ["code_du_travail", "conventions_collectives"] + } +} + +### Generate Endpoint Tests +POST http://localhost:8000/api/v1/generate +Authorization: Bearer abc +content-type: application/json + +{ + "chat_history": [ + { + "role": "user", + "content": "Quels sont mes droits concernant le congé parental ?" + } + ], + "system_prompt": "Tu es un assistant juridique spécialisé en droit du travail. Réponds de manière concise et précise." +} + +### Test generate with conversation history +POST http://localhost:8000/api/v1/generate +Authorization: Bearer abc +content-type: application/json + +{ + "chat_history": [ + { + "role": "user", + "content": "Quel est le délai de préavis pour une démission ?" + }, + { + "role": "assistant", + "content": "Le délai de préavis dépend de votre convention collective et de votre statut." + }, + { + "role": "user", + "content": "Je suis cadre dans le secteur informatique." + } + ], + "system_prompt": "Tu es un assistant juridique spécialisé en droit du travail. Réponds de manière précise." +} diff --git a/srdt_analysis/albert.py b/srdt_analysis/albert.py deleted file mode 100644 index 3a94dac..0000000 --- a/srdt_analysis/albert.py +++ /dev/null @@ -1,22 +0,0 @@ -import os -from typing import Any, Dict - -import httpx - -from srdt_analysis.constants import ALBERT_ENDPOINT - - -class AlbertBase: - def __init__(self): - self.api_key = os.getenv("ALBERT_API_KEY") - if not self.api_key: - raise ValueError( - "API key must be provided either in constructor or as environment variable" - ) - self.headers = { - "Authorization": f"Bearer {self.api_key}", - } - - def get_models(self) -> Dict[str, Any]: - response = httpx.get(f"{ALBERT_ENDPOINT}/v1/models", headers=self.headers) - return response.json() diff --git a/srdt_analysis/api/launcher.py b/srdt_analysis/api/launcher.py new file mode 100644 index 0000000..b5f24f1 --- /dev/null +++ b/srdt_analysis/api/launcher.py @@ -0,0 +1,19 @@ +import os + +import uvicorn +from dotenv import load_dotenv + +load_dotenv() + + +def start(): + uvicorn.run( + "srdt_analysis.api.main:app", + host=os.getenv("API_HOST", "localhost"), + port=int(os.getenv("API_PORT", 8000)), + reload=True, + ) + + +if __name__ == "__main__": + start() diff --git a/srdt_analysis/api/main.py b/srdt_analysis/api/main.py new file mode 100644 index 0000000..8d25581 --- /dev/null +++ b/srdt_analysis/api/main.py @@ -0,0 +1,181 @@ +import os +import time + +from dotenv import load_dotenv +from fastapi import Depends, FastAPI, HTTPException, Security +from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import APIKeyHeader + +from srdt_analysis.api.schemas import ( + AnonymizeRequest, + AnonymizeResponse, + ChunkMetadata, + ChunkResult, + GenerateRequest, + GenerateResponse, + RephraseRequest, + RephraseResponse, + SearchRequest, + SearchResponse, +) +from srdt_analysis.collections import AlbertCollectionHandler +from srdt_analysis.constants import ALBERT_ENDPOINT, ALBERT_MODEL, BASE_API_URL +from srdt_analysis.llm_runner import LLMRunner +from srdt_analysis.tokenizer import Tokenizer + +load_dotenv() + +app = FastAPI() +api_key_header = APIKeyHeader(name="Authorization", auto_error=True) + + +async def get_api_key(api_key: str = Security(api_key_header)): + if not api_key.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail="Invalid authorization header format. Must start with 'Bearer '", + ) + token = api_key.replace("Bearer ", "") + if token != os.getenv("API_KEY"): + raise HTTPException(status_code=401, detail="Invalid API key") + return api_key + + +app.add_middleware( + CORSMiddleware, + allow_origins=os.getenv("SRDT_ALLOW_ORIGINS", "").split(",") + if os.getenv("SRDT_ALLOW_ORIGINS") + else [], + allow_credentials=True, + allow_methods=["GET", "POST"], + allow_headers=["Authorization", "Content-Type"], +) + + +@app.get("/") +@app.get(f"{BASE_API_URL}/") +async def root(api_key: str = Depends(get_api_key)): + return {"status": "ok", "path": BASE_API_URL} + + +@app.post(f"{BASE_API_URL}/anonymize", response_model=AnonymizeResponse) +async def anonymize(request: AnonymizeRequest, _api_key: str = Depends(get_api_key)): + start_time = time.time() + tokenizer = Tokenizer(model=ALBERT_MODEL) + llm_runner = LLMRunner( + llm_api_token=os.getenv("ALBERT_API_KEY", ""), + llm_model=ALBERT_MODEL, + llm_url=ALBERT_ENDPOINT, + ) + try: + anonymized_question = await llm_runner.anonymize( + request.user_question, request.anonymization_prompt + ) + return AnonymizeResponse( + time=time.time() - start_time, + anonymized_question=anonymized_question, + nb_token_input=tokenizer.compute_nb_tokens(request.user_question), + nb_token_output=tokenizer.compute_nb_tokens(anonymized_question), + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post(f"{BASE_API_URL}/rephrase", response_model=RephraseResponse) +async def rephrase(request: RephraseRequest, api_key: str = Depends(get_api_key)): + start_time = time.time() + tokenizer = Tokenizer(model=ALBERT_MODEL) + llm_runner = LLMRunner( + llm_api_token=os.getenv("ALBERT_API_KEY", ""), + llm_model=ALBERT_MODEL, + llm_url=ALBERT_ENDPOINT, + ) + + try: + rephrased, queries = await llm_runner.rephrase_and_split( + request.question, + request.rephrasing_prompt, + request.queries_splitting_prompt, + ) + + return RephraseResponse( + time=time.time() - start_time, + rephrased_question=rephrased, + queries=queries, + nb_token_input=tokenizer.compute_nb_tokens(request.question), + nb_token_output=tokenizer.compute_nb_tokens(rephrased), + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post(f"{BASE_API_URL}/search", response_model=SearchResponse) +async def search(request: SearchRequest, api_key: str = Depends(get_api_key)): + start_time = time.time() + collections = AlbertCollectionHandler() + try: + transformed_results = [] + + for prompt in request.prompts: + search_result = collections.search( + prompt=prompt, + id_collections=request.options.collections, + k=request.options.top_K, + score_threshold=request.options.threshold, + ) + + for item in search_result: + chunk_data = item["chunk"] + metadata = chunk_data["metadata"] + + transformed_chunk = ChunkResult( + score=item["score"], + content=chunk_data["content"], + id_chunk=chunk_data["id"], + metadata=ChunkMetadata( + document_id=metadata["document_id"], + source=metadata["source"], + title=metadata["document_name"], + url=metadata["url"], + ), + ) + transformed_results.append(transformed_chunk) + + return SearchResponse( + time=time.time() - start_time, + top_chunks=transformed_results, + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post(f"{BASE_API_URL}/generate", response_model=GenerateResponse) +async def generate(request: GenerateRequest, api_key: str = Depends(get_api_key)): + start_time = time.time() + tokenizer = Tokenizer(model=ALBERT_MODEL) + llm_runner = LLMRunner( + llm_api_token=os.getenv("ALBERT_API_KEY", ""), + llm_model=ALBERT_MODEL, + llm_url=ALBERT_ENDPOINT, + ) + + try: + response = await llm_runner.chat_with_full_document( + chat_history=request.chat_history, + prompt=request.system_prompt, + ) + + chat_history_str = " ".join( + [msg.get("content", "") for msg in request.chat_history] + ) + + return GenerateResponse( + time=time.time() - start_time, + text=response, + nb_token_input=tokenizer.compute_nb_tokens(chat_history_str), + nb_token_output=tokenizer.compute_nb_tokens(response), + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/srdt_analysis/api/schemas.py b/srdt_analysis/api/schemas.py new file mode 100644 index 0000000..74adc72 --- /dev/null +++ b/srdt_analysis/api/schemas.py @@ -0,0 +1,105 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field, field_validator + +from srdt_analysis.constants import COLLECTION_IDS +from srdt_analysis.models import ( + CHUNK_ID, + ID, + CollectionName, + UserLLMMessage, +) + + +class AnonymizeRequest(BaseModel): + user_question: str + anonymization_prompt: Optional[str] = None # TODO : to be removed in the future + + +class AnonymizeResponse(BaseModel): + time: float + anonymized_question: str + nb_token_input: int + nb_token_output: int + + +class RephraseRequest(BaseModel): + question: str + rephrasing_prompt: str # TODO : to be removed in the future + queries_splitting_prompt: Optional[str] = None # TODO : to be removed in the future + + +class RephraseResponse(BaseModel): + time: float + rephrased_question: str + queries: Optional[List[str]] = None + nb_token_input: int + nb_token_output: int + + +class SearchOptions(BaseModel): + top_K: int = Field(default=20) + threshold: float = Field(default=0.7, ge=0.0, le=1.0) + collections: List[str] = Field(default=COLLECTION_IDS) + + @field_validator("collections") + @classmethod + def validate_collections(cls, collections): + if collections is not None: + invalid_collections = [c for c in collections if c not in COLLECTION_IDS] + if invalid_collections: + raise ValueError( + f"Invalid collection IDs: {invalid_collections}. Must be one of {COLLECTION_IDS}" + ) + return collections + + +class SearchRequest(BaseModel): + prompts: List[str] = Field(max_length=10) + options: SearchOptions = Field(default_factory=SearchOptions) + + @classmethod + def model_validate( + cls, + obj, + *, + strict: Optional[bool] = None, + from_attributes: Optional[bool] = None, + context: Optional[dict] = None, + ): + if isinstance(obj, dict) and obj.get("options") is None: + obj["options"] = {} + return super().model_validate( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + + +class ChunkMetadata(BaseModel): + title: str + url: str + document_id: ID + source: CollectionName + + +class ChunkResult(BaseModel): + score: float + content: str + id_chunk: CHUNK_ID + metadata: ChunkMetadata + + +class SearchResponse(BaseModel): + time: float + top_chunks: List[ChunkResult] + + +class GenerateRequest(BaseModel): + chat_history: List[UserLLMMessage] + system_prompt: str # TODO : to be removed in the future + + +class GenerateResponse(BaseModel): + time: float + text: str + nb_token_input: int + nb_token_output: int diff --git a/srdt_analysis/collections.py b/srdt_analysis/collections.py index a1afda8..8e67aa4 100644 --- a/srdt_analysis/collections.py +++ b/srdt_analysis/collections.py @@ -1,10 +1,10 @@ import json +import os import time from io import BytesIO import httpx -from srdt_analysis.albert import AlbertBase from srdt_analysis.constants import ( ALBERT_ENDPOINT, COLLECTIONS_UPLOAD_BATCH_SIZE, @@ -17,11 +17,21 @@ CollectionName, DocumentData, ListOfDocumentData, - RAGChunkSearchResult, + RankedChunk, ) -class Collections(AlbertBase): +class AlbertCollectionHandler: + def __init__(self): + self.api_key = os.getenv("ALBERT_API_KEY") + if not self.api_key: + raise ValueError( + "API key must be provided either in constructor or as environment variable" + ) + self.headers = { + "Authorization": f"Bearer {self.api_key}", + } + def _create(self, collection_name: CollectionName, model: str) -> COLLECTION_ID: payload = {"name": collection_name, "model": model} response = httpx.post( @@ -30,13 +40,13 @@ def _create(self, collection_name: CollectionName, model: str) -> COLLECTION_ID: return response.json()["id"] def create(self, collection_name: CollectionName, model: str) -> COLLECTION_ID: - collections = self.list() + collections = self.list_collections() for collection in collections: if collection["name"] == collection_name: self.delete(collection["id"]) return self._create(collection_name, model) - def list(self) -> AlbertCollectionsList: + def list_collections(self) -> AlbertCollectionsList: try: response = httpx.get( f"{ALBERT_ENDPOINT}/v1/collections", headers=self.headers @@ -54,7 +64,7 @@ def delete(self, id_collection: str) -> None: response.raise_for_status() def delete_all(self, collection_name: CollectionName) -> None: - collections = self.list() + collections = self.list_collections() for collection in collections: if collection["name"] == collection_name: self.delete(collection["id"]) @@ -66,7 +76,7 @@ def search( id_collections: COLLECTIONS_ID, k: int = 5, score_threshold: float = 0, - ) -> RAGChunkSearchResult: + ) -> list[RankedChunk]: response = httpx.post( f"{ALBERT_ENDPOINT}/v1/search", headers=self.headers, @@ -77,7 +87,8 @@ def search( "score_threshold": score_threshold, }, ) - return response.json() + result = response.json() + return result.get("data", []) def upload( self, diff --git a/srdt_analysis/constants.py b/srdt_analysis/constants.py index 5b074e8..39f6c28 100644 --- a/srdt_analysis/constants.py +++ b/srdt_analysis/constants.py @@ -1,30 +1,94 @@ ALBERT_ENDPOINT = "https://albert.api.etalab.gouv.fr" MODEL_VECTORISATION = "BAAI/bge-m3" -LLM_MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct" +ALBERT_MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct" CHUNK_SIZE = 4096 CHUNK_OVERLAP = 0 COLLECTIONS_UPLOAD_BATCH_SIZE = 50 COLLECTIONS_UPLOAD_DELAY_IN_SECONDS = 5 BASE_URL_CDTN = "https://code.travail.gouv.fr" +BASE_API_URL = "/api/v1" +COLLECTION_IDS = [ + "5755cf5f-1cb5-4ec6-a076-21047d069578", # information + "0576c752-f097-403e-b2be-d6d806c3848a", # page_fiche_ministere_travail + "0be5059b-762f-48ba-a8f0-fe10e81455c8", # code_du_travail + "f8d66426-5c54-4503-aa30-a3abc19453d5", # fiches_service_public + "d03df69b-9387-4359-80db-7d73f2b6f04a", # contributions +] LLM_ANSWER_PROMPT = """ - Instructions - Rôle et objectif - L'assistant juridique est conçu pour répondre aux questions des usagers (salariés et employeurs du secteur privé) en France concernant le droit du travail, conformément aux normes et règlements du droit français. L'objectif est de fournir des informations juridiques précises et contextualisées, en utilisant des extraits de documents pertinents pour soutenir chaque réponse. - - Lignes directrices - A chaque fois que l'utilisateur sollicite l'assistant juridique, le chatbot va procéder ainsi : - - Reformuler la demande de l’utilisateur en deux parties : le contexte, et les points juridiques à traiter. Puis y répondre. - - Pour chaque point, citer la source juridique utilisée dans la base de connaissance externe, ou bien citer le passage correspondant - - Commencer par citer le principe général de droit qui répond au point, puis aller dans le détail en distinguant les cas particuliers, ou en posant des questions à l'utilisateur pour avoir plus de précisions quand cela est nécessaire - - Conclure en synthétisant la réponse et si nécessaire, en indiquant les prochaines étapes à suivre, ainsi qu’en posant des questions qui vont permettre de compléter la réponse - - Limites et contraintes - Il faut faire attention à ce que toutes les réponses aient une question. Mais si une question n'a pas de réponse, il ne faut pas inventer et plutôt simplement indiquer que la réponse n'a pas été trouvée. Si tu as besoin d’informations supplémentaires pour répondre à une question, tu demandes simplement ces informations à l’usager qui te les donnera. - - Style et ton. - Répondre dans un langage clair et accessible. + # Instructions + ## Rôle et objectif + L'assistant juridique est conçu pour répondre aux questions des usagers (salariés et employeurs du secteur privé) en France concernant le droit du travail, conformément aux normes et règlements du droit français. L'objectif est de fournir des informations juridiques précises et contextualisées, en utilisant des extraits de documents pertinents pour soutenir chaque réponse. + ## Lignes directrices + A chaque fois que l'utilisateur sollicite l'assistant juridique, le chatbot va procéder ainsi : + - Reformuler la demande de l’utilisateur en deux parties : le contexte, et les points juridiques à traiter. Puis y répondre. + - Pour chaque point, citer la source juridique utilisée dans la base de connaissance externe (qui se trouve ci-dessous), ou bien citer le passage correspondant. Commencer par citer le principe général de droit qui répond au point, puis aller dans le détail en distinguant les cas particuliers, ou en posant des questions à l'utilisateur pour avoir plus de précisions quand cela est nécessaire + - Conclure en synthétisant la réponse et si nécessaire, en indiquant les prochaines étapes à suivre, ainsi qu’en posant des questions qui vont permettre de compléter la réponse + ## Limites et contraintes + Il faut faire attention à ce que toutes les réponses aient une question. Mais si une question n'a pas de réponse, il ne faut pas inventer et plutôt simplement indiquer que la réponse n'a pas été trouvée. Si tu as besoin d’informations supplémentaires pour répondre à une question, tu demandes simplement ces informations à l’usager qui te les donnera. + ## Style et ton. + Répondre dans un langage clair et accessible. + ## Base de connaissance externe. + Voici les extraits de documents que tu peux utiliser, avec cette structure : titre du document, contenu du document, url_source. """ +LLM_ANONYMIZATION_PROMPT = """ + # Instructions + Anonymise le texte suivant en remplaçant toutes les informations personnelles par des balises standard, sauf le titre de poste et la civilité, qui doivent rester inchangés. Utilise [PERSONNE] pour les noms de personnes, [EMAIL] pour les adresses email, [TELEPHONE] pour les numéros de téléphone, [ADRESSE] pour les adresses physiques, [DATE] pour les dates, et [IDENTIFIANT] pour tout identifiant unique ou sensible. + # Exemple + - Texte : + "Bonjour, je suis employé chez ABC Construction à Lyon en tant que chef de chantier. Mon responsable, M. Dupont, m’a demandé de travailler deux week-ends consécutifs. J’aimerais savoir si c’est légal, car il n’a pas mentionné de rémunération supplémentaire. Mon numéro de salarié est 123456. Pouvez-vous me renseigner sur mes droits concernant les jours de repos et les heures supplémentaires ? Merci." + - Texte anonymisé : + Bonjour, je suis employé chez [ENTREPRISE] en tant que chef de chantier. Mon responsable, [PERSONNE], m’a demandé de travailler deux week-ends consécutifs. J’aimerais savoir si c’est légal, car il n’a pas mentionné de rémunération supplémentaire. Mon numéro de salarié est [IDENTIFIANT]. Pouvez-vous me renseigner sur mes droits concernant les jours de repos et les heures supplémentaires ? Merci. + """ +LLM_REPHRASING_PROMPT = """ + # Instructions + ## Objectif + L'assistant juridique a pour mission de reformuler des questions juridiques relatives au droit du travail posées par les salariés ou employeurs du secteur privé en France. L’objectif est de reformuler la question d'origine de façon claire sans perdre les détails, et de mettre en avant les points juridiques pour qu'un agent public puisse y répondre plus efficacement. Attention, dans la reformulation, c'est l'usager qui est à la première personne et non l'assistant (comme c'est le cas dans l'exemple plus bas). + ## Etape + - Identification des points juridiques : Repérer tous les points qui demandent une réponse juridique dans la question de l'utilisateur. Ne pas hésiter à anticiper et mentionner des questions juridiques à laquelle l’utilisateur n’aurait pas pensé. + - Reformulation claire et structurée : Formuler la question en deux parties : + Un paragraphe de contexte dans laquelle la personne raconte sa situation + Une synthèse des questions juridiques que soulève la personne. Cette synthèse reprend donc l’ensemble des points juridiques identifiés. + - Exemple + Question initiale : + "Bonjour, + J’ai effectuée un remplacement en CDD dans une micro crèche, mon contrat étant fini depuis le 22 septembre 2023 je suis toujours en attente de mon salaire. Après plusieurs relance auprès de la directrice aucun versement n’a été fait. J’aimerais savoir si elle est en droit de me faire patienter comme cela ou sinon qu’elle sont les délais pour qu’elle puisse me verser mon salaire. + Cordialement." + Reformulation attendue : + "Bonjour, + J’ai effectuée un remplacement en CDD dans une micro crèche, mon contrat étant fini depuis le 22 septembre 2023 je suis toujours en attente de mon salaire. Après plusieurs relance auprès de la directrice aucun versement n’a été fait. + Mes questions sont : + La directrice est-elle en droit de retarder le paiement de mon salaire ? + Quels sont les délais légaux pour qu’un employeur verse le salaire d’un employé à la fin d’un CDD ? + Quels sont les recours applicables et la procédure à suivre ? + Cordialement." + """ +LLM_SPLIT_MULTIPLE_QUERIES_PROMPT = """ + ## Objectif + Tu es chargé d'identifier toutes les questions qui ont été posées dans la fenêtre de chat, et de les renvoyer sous format d'un document json + ## Etape + - tu identifies toutes les questions qui sont formulées dans la fenêtre de chat + - tu les enregistres dans des variables (sans changer un mot) "question_i" où + i est le numéro de la question, et i va de 1 à N s'il y a N questions identifiées + ## Format de sortie + Format de json attendu en sortie (pour 2 questions) + { + "question_1": texte_question_1, + "question_2": texte_question_2, + } + ## Point d'attention + Je veux que la réponse que tu fais soit directement réutilisable dans un programme de code. Aussi je ne veux aucun caractère supplémentaire de type "/n", je veux seulement le json en sortie et absolument rien d'autre. + ## Exemple + - texte de la fenêtre de chat : "Bonjour, + Actuellement membre du CSE, de la CSSCT et de la RPX, ma direction souhaite me changer de roulement à compter de janvier. Lors de notre première rencontre non officielle, il a été dit que mes absences mettaient mes collègues en souffrance en raison du grand nombre de remplaçants et qu'il fallait séparer un binôme sur l'équipe inverse. Lors d'une seconde rencontre non officielle, ma direction m'a indiqué qu'ils n'avaient rien à reprocher à mon travail, mais qu'il fallait redynamiser un peu et continuer à séparer le binôme sur les deux équipes. + Je travaille en roulement amplitude de 12h, avec 10h travaillées et 2h de pause. Un week-end sur 2, et si elle me change de roulement, je travaillerais complètement à l'inverse de mon roulement actuel, ce qui rendrait impossible pour moi d'assurer la garde de mon enfant. Par conséquent, je risque de ne plus pouvoir venir travailler. + Ma direction souhaite effectuer ce changement début janvier, mais à ce jour, je n'ai reçu qu'une information officieuse, aucun entretien officiel ou courrier ne m'a été adressé. + Mes questions sont : + En tant que salariée protégée (membre du CSE, de la CSSCT et de la RPX), ma direction a-t-elle le droit de modifier mon roulement de travail ? + Quelles sont mes recours et la procédure à suivre si je considère que ce changement n'est pas légitime et impacte ma vie familiale ? + Je souhaite obtenir ces informations afin de les lui expliquer, avant d'envisager des démarches plus formelles auprès des services compétents. + Merci d'avance pour votre retour." + - réponse attendue : + {"question_1" : "En tant que salariée protégée (membre du CSE, de la CSSCT et de la RPX), ma direction a-t-elle le droit de modifier mon roulement de travail ?", + "question_2" : "Quelles sont mes recours et la procédure à suivre si je considère que ce changement n'est pas légitime et impacte ma vie familiale ?" + } +""" diff --git a/srdt_analysis/data_exploiter.py b/srdt_analysis/data_exploiter.py index dd81e72..63ba9d1 100644 --- a/srdt_analysis/data_exploiter.py +++ b/srdt_analysis/data_exploiter.py @@ -1,6 +1,5 @@ -from srdt_analysis.albert import AlbertBase from srdt_analysis.chunker import Chunker -from srdt_analysis.collections import Collections +from srdt_analysis.collections import AlbertCollectionHandler from srdt_analysis.constants import BASE_URL_CDTN, MODEL_VECTORISATION from srdt_analysis.logger import Logger from srdt_analysis.models import ( @@ -17,8 +16,7 @@ class BaseDataExploiter: def __init__(self): self.chunker = Chunker() - self.collections = Collections() - self.albert = AlbertBase() + self.collections = AlbertCollectionHandler() self.logger = Logger("BaseDataExploiter") def get_content(self, _doc: Document) -> FormattedTextContent: diff --git a/srdt_analysis/llm_client.py b/srdt_analysis/llm_client.py new file mode 100644 index 0000000..7884fc4 --- /dev/null +++ b/srdt_analysis/llm_client.py @@ -0,0 +1,81 @@ +import asyncio +from typing import Sequence, Union + +import httpx +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from srdt_analysis.logger import Logger +from srdt_analysis.models import ( + LLMChatPayload, + SystemLLMMessage, + UserLLMMessage, +) + + +class LLMClient: + def __init__(self, base_url, api_key, model): + super().__init__() + self.logger = Logger("LLMProcessor") + self.client = httpx.AsyncClient(timeout=30.0) + self.rate_limit = asyncio.Semaphore(10) + self.base_url = base_url + self.headers = { + "Authorization": f"Bearer {api_key}", + } + self.model = model + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=10), + retry=retry_if_exception_type((httpx.HTTPError, ValueError)), + ) + async def _make_chat_completions_async( + self, + system_prompt: str, + chat_history: list[UserLLMMessage], + ) -> str: + async with self.rate_limit: + try: + messages: Sequence[Union[SystemLLMMessage, UserLLMMessage]] = [ + SystemLLMMessage(role="system", content=system_prompt), + ] + chat_history + + payload: LLMChatPayload = { + "messages": messages, + "model": self.model, + } + + response = await self.client.post( + f"{self.base_url}/v1/chat/completions", + headers=self.headers, + json=payload, + ) + response.raise_for_status() + + response_json = response.json() + return response_json["choices"][0]["message"]["content"] + + except httpx.HTTPStatusError as e: + self.logger.error( + f"HTTP error occurred: {e.response.status_code} - {e.response.text}" + ) + raise + except httpx.RequestError as e: + self.logger.error(f"Request error occurred: {str(e)}") + raise + except Exception as e: + self.logger.error(f"Unexpected error: {str(e)}") + raise + + async def generate_completions_async( + self, + system_prompt: str, + chat_history: list[UserLLMMessage], + ) -> str: + self.logger.info("Generating a chat completions answer") + return await self._make_chat_completions_async(system_prompt, chat_history) diff --git a/srdt_analysis/llm_processor.py b/srdt_analysis/llm_processor.py deleted file mode 100644 index ae8c6a3..0000000 --- a/srdt_analysis/llm_processor.py +++ /dev/null @@ -1,104 +0,0 @@ -import asyncio -import json -from typing import AsyncGenerator, Iterator - -import httpx -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from srdt_analysis.albert import AlbertBase -from srdt_analysis.constants import ALBERT_ENDPOINT, LLM_ANSWER_PROMPT, LLM_MODEL -from srdt_analysis.logger import Logger -from srdt_analysis.models import RAGChunkSearchResultEnriched - - -class LLMProcessor(AlbertBase): - def __init__(self): - super().__init__() - self.logger = Logger("LLMProcessor") - self.client = httpx.AsyncClient(timeout=30.0) - self.rate_limit = asyncio.Semaphore(10) - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=1, max=10), - retry=retry_if_exception_type((httpx.HTTPError, ValueError)), - ) - async def _make_request_stream_async( - self, - message: str, - system_prompt: str, - ) -> AsyncGenerator[str, None]: - async with self.rate_limit: - try: - messages = [{"role": "system", "content": system_prompt}] - messages.append({"role": "user", "content": message}) - - async with self.client.stream( - "POST", - f"{ALBERT_ENDPOINT}/v1/chat/completions", - headers=self.headers, - json={ - "messages": messages, - "model": LLM_MODEL, - "stream": True, - }, - ) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if line.startswith("data: ") and line.strip() != "data: [DONE]": - try: - chunk = line[len("data: ") :] - chunk_data = json.loads(chunk) - if ( - content := chunk_data["choices"][0] - .get("delta", {}) - .get("content") - ): - yield content - except json.JSONDecodeError as e: - self.logger.error(f"Failed to parse chunk: {e}") - continue - - except httpx.HTTPStatusError as e: - self.logger.error( - f"HTTP error occurred: {e.response.status_code} - {e.response.text}" - ) - raise - except httpx.RequestError as e: - self.logger.error(f"Request error occurred: {str(e)}") - raise - except Exception as e: - self.logger.error(f"Unexpected error: {str(e)}") - raise - - async def get_answer_stream_async( - self, - message: str, - documents: RAGChunkSearchResultEnriched, - ) -> AsyncGenerator[str, None]: - self.logger.info("Generating streaming answer based on documents") - document_contents = [item["content"] for item in documents["data"]] - system_prompt = f"{LLM_ANSWER_PROMPT}\n Mes documents sont :\n{'\n'.join(document_contents)}\n" - async for token in self._make_request_stream_async(message, system_prompt): - yield token - - def get_answer_stream( - self, - message: str, - documents: RAGChunkSearchResultEnriched, - ) -> Iterator[str]: - async def collect_tokens(): - tokens = [] - async for token in self.get_answer_stream_async( - message, - documents, - ): - tokens.append(token) - return tokens - - return iter(asyncio.run(collect_tokens())) diff --git a/srdt_analysis/llm_runner.py b/srdt_analysis/llm_runner.py new file mode 100644 index 0000000..f5969be --- /dev/null +++ b/srdt_analysis/llm_runner.py @@ -0,0 +1,75 @@ +from typing import Optional, Tuple + +from srdt_analysis.collections import AlbertCollectionHandler +from srdt_analysis.constants import ( + LLM_ANONYMIZATION_PROMPT, + LLM_REPHRASING_PROMPT, + LLM_SPLIT_MULTIPLE_QUERIES_PROMPT, +) +from srdt_analysis.llm_client import LLMClient +from srdt_analysis.models import ( + UserLLMMessage, +) + + +class LLMRunner: + collections: AlbertCollectionHandler + llm_processor: LLMClient + + def __init__(self, llm_url: str, llm_api_token: str, llm_model: str): + self.collections = AlbertCollectionHandler() + self.llm_processor = LLMClient(llm_url, llm_api_token, llm_model) + + async def anonymize( + self, + user_message: str, + optional_prompt: Optional[str] = None, + ) -> str: + prompt = ( + optional_prompt if optional_prompt is not None else LLM_ANONYMIZATION_PROMPT + ) + result = await self.llm_processor.generate_completions_async( + prompt, + [UserLLMMessage(role="user", content=user_message)], + ) + return result + + async def rephrase_and_split( + self, + question: str, + optional_rephrasing_prompt: Optional[str] = None, + optional_queries_splitting_prompt: Optional[str] = None, + ) -> Tuple[str, Optional[list[str]]]: + rephrasing_prompt = ( + optional_rephrasing_prompt + if optional_rephrasing_prompt is not None + else LLM_REPHRASING_PROMPT + ) + queries_splitting_prompt = ( + optional_queries_splitting_prompt + if optional_queries_splitting_prompt is not None + else LLM_SPLIT_MULTIPLE_QUERIES_PROMPT + ) + rephrased_question = await self.llm_processor.generate_completions_async( + rephrasing_prompt, + [UserLLMMessage(role="user", content=question)], + ) + + queries = await self.llm_processor.generate_completions_async( + queries_splitting_prompt, + [UserLLMMessage(role="user", content=rephrased_question)], + ) + + query_list = [q.strip() for q in queries.split("\n") if q.strip()] + + return rephrased_question, query_list + + async def chat_with_full_document( + self, + chat_history: list[UserLLMMessage], + prompt: str, + ) -> str: + return await self.llm_processor.generate_completions_async( + system_prompt=prompt, + chat_history=chat_history, + ) diff --git a/srdt_analysis/mapper.py b/srdt_analysis/mapper.py index aaa20d2..1b1de00 100644 --- a/srdt_analysis/mapper.py +++ b/srdt_analysis/mapper.py @@ -9,8 +9,8 @@ from srdt_analysis.models import ( CollectionName, DocumentsList, - RAGChunkSearchResult, - RAGChunkSearchResultEnriched, + EnrichedRankedChunk, + RankedChunk, ) @@ -26,32 +26,32 @@ def __init__(self, documents_by_source: dict[CollectionName, DocumentsList]): all_documents = [doc for docs in documents_by_source.values() for doc in docs] self.doc_map = {doc.cdtn_id: doc for doc in all_documents} - def get_exploiter(self, source: CollectionName) -> BaseDataExploiter: + def _get_exploiter(self, source: CollectionName) -> BaseDataExploiter: exploiter = self.source_exploiters.get(source) if not exploiter: raise ValueError(f"No exploiter found for source: {source}") return exploiter - def get_original_docs( + def enrich_chunks( self, - rag_response: RAGChunkSearchResult, - ) -> RAGChunkSearchResultEnriched: - enriched_data = [] - for item in rag_response["data"]: - id = item["chunk"]["metadata"]["id"] - source = item["chunk"]["metadata"]["source"] + chunks: list[RankedChunk], + ) -> list[EnrichedRankedChunk]: + enriched_chunks = [] + for scored_chunk in chunks: + id = scored_chunk["chunk"]["metadata"]["id"] + source = scored_chunk["chunk"]["metadata"]["source"] if id in self.doc_map: - enriched_data.append( + enriched_chunks.append( { - "score": item["score"], - "chunk": item["chunk"], + "score": scored_chunk["score"], + "chunk": scored_chunk["chunk"], "document": self.doc_map[id], - "content": self.get_exploiter(source).get_content( + "content": self._get_exploiter(source).get_content( self.doc_map[id] ), } ) - enriched_data.sort(key=lambda x: x["score"], reverse=True) + enriched_chunks.sort(key=lambda x: x["score"], reverse=True) - return {"object": rag_response["object"], "data": enriched_data} + return enriched_chunks diff --git a/srdt_analysis/models.py b/srdt_analysis/models.py index 8149ca1..e9f148b 100644 --- a/srdt_analysis/models.py +++ b/srdt_analysis/models.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Literal, Optional, TypedDict +from typing import Any, Dict, Literal, Optional, Sequence, TypedDict, Union import asyncpg @@ -16,6 +16,7 @@ ChunkerContentType = Literal["markdown", "html", "character_recursive"] +CHUNK_ID = str ID = str HTML = str PlainText = str @@ -164,31 +165,19 @@ class Chunk(TypedDict): @dataclass -class RAGChunkData(TypedDict): +class RankedChunk(TypedDict): score: float chunk: Chunk @dataclass -class RAGChunkSearchResult(TypedDict): - object: str - data: list[RAGChunkData] - - -@dataclass -class RAGChunkDataEnriched(TypedDict): +class EnrichedRankedChunk(TypedDict): score: float chunk: Chunk document: Document content: str -@dataclass -class RAGChunkSearchResultEnriched(TypedDict): - object: str - data: list[RAGChunkDataEnriched] - - # Albert Collection @dataclass class AlbertCollectionData(TypedDict): @@ -203,3 +192,19 @@ class AlbertCollectionData(TypedDict): AlbertCollectionsList = list[AlbertCollectionData] + + +# LLM +class SystemLLMMessage(TypedDict): + role: Literal["system", "user", "assistant"] + content: str + + +class UserLLMMessage(TypedDict): + role: Literal["user", "assistant"] + content: str + + +class LLMChatPayload(TypedDict): + model: str + messages: Sequence[Union[SystemLLMMessage, UserLLMMessage]] diff --git a/srdt_analysis/database_manager.py b/srdt_analysis/postgresql_manager.py similarity index 97% rename from srdt_analysis/database_manager.py rename to srdt_analysis/postgresql_manager.py index 6317c66..f702276 100644 --- a/srdt_analysis/database_manager.py +++ b/srdt_analysis/postgresql_manager.py @@ -8,7 +8,7 @@ from srdt_analysis.models import CollectionName, Document, DocumentsList -class DatabaseManager: +class PostgreSQLManager: def __init__(self): self.pool: Optional[asyncpg.Pool] = None @@ -61,5 +61,5 @@ async def fetch_sources( def get_data( sources: Sequence[CollectionName], ) -> dict[CollectionName, DocumentsList]: - db = DatabaseManager() + db = PostgreSQLManager() return asyncio.run(db.fetch_sources(sources)) diff --git a/srdt_analysis/__main__.py b/srdt_analysis/scripts/ingest.py similarity index 51% rename from srdt_analysis/__main__.py rename to srdt_analysis/scripts/ingest.py index 5e45c40..d7d64f6 100644 --- a/srdt_analysis/__main__.py +++ b/srdt_analysis/scripts/ingest.py @@ -1,6 +1,5 @@ from dotenv import load_dotenv -from srdt_analysis.collections import Collections from srdt_analysis.data_exploiter import ( ArticlesCodeDuTravailExploiter, FichesMTExploiter, @@ -8,27 +7,12 @@ PageInfosExploiter, PagesContributionsExploiter, ) -from srdt_analysis.database_manager import get_data -from srdt_analysis.llm_processor import LLMProcessor -from srdt_analysis.mapper import Mapper +from srdt_analysis.postgresql_manager import get_data load_dotenv() -QUESTION = "Combien de jours de congé payé par mois de travail effectif ?" -COLLECTION_IDS = [ - "5755cf5f-1cb5-4ec6-a076-21047d069578", # information - "0576c752-f097-403e-b2be-d6d806c3848a", # page_fiche_ministere_travail - "0be5059b-762f-48ba-a8f0-fe10e81455c8", # code_du_travail - "f8d66426-5c54-4503-aa30-a3abc19453d5", # fiches_service_public - "d03df69b-9387-4359-80db-7d73f2b6f04a", # contributions -] - -def main(): - ingest() - - -def ingest(): +def start(): data = get_data( [ "information", @@ -60,29 +44,5 @@ def ingest(): ) -def run_llm(): - data = get_data( - [ - "information", - "code_du_travail", - "page_fiche_ministere_travail", - "fiches_service_public", - ] - ) - collections = Collections() - rag_response = collections.search( - QUESTION, - COLLECTION_IDS, - ) - mapper = Mapper(data) - data_to_send_to_llm = mapper.get_original_docs(rag_response) - llm_processor = LLMProcessor() - for token in llm_processor.get_answer_stream( - QUESTION, - data_to_send_to_llm, - ): - print(token, end="", flush=True) - - if __name__ == "__main__": - main() + start() diff --git a/srdt_analysis/tokenizer.py b/srdt_analysis/tokenizer.py new file mode 100644 index 0000000..8d2beb1 --- /dev/null +++ b/srdt_analysis/tokenizer.py @@ -0,0 +1,16 @@ +import os + +from transformers import AutoTokenizer + + +class Tokenizer: + def __init__(self, model: str): + if not os.getenv("HF_API_TOKEN"): + raise ValueError("HF_API_TOKEN not provided or found in environment") + self._tokenizer = AutoTokenizer.from_pretrained( + model, token=os.getenv("HF_API_TOKEN") + ) + + def compute_nb_tokens(self, text: str) -> int: + tokens = self._tokenizer.encode(text) + return len(tokens)