feat(contributions): add parsing and exploitation of generic contributions with IDCC 0000 (#43)

* fix(unpublished): remove documents which are not available on cdtn

* fix(fiche-sp): handle fiche-sp pages in plain-text format

* feat(chunk): add no chunker option

* feat: add contribs

* fix(calcul): optimize the computation with the new variables

* fix(retours): use the model at the SQL query level

* feat(title): add the title to the metadata

* fix(retours): simplify the DB fetch code
maxgfr authored Dec 19, 2024
1 parent ffeb1d9 commit 1b4b43e
Showing 8 changed files with 37 additions and 18 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ pandas = "^2.2.3"
 langchain-text-splitters = "^0.3.2"
 detect-secrets = "^1.5.0"
 pre-commit = "^4.0.1"
+lxml = "^5.3.0"
 
 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.389"
15 changes: 11 additions & 4 deletions srdt_analysis/__main__.py
@@ -6,6 +6,7 @@
     FichesMTExploiter,
     FichesSPExploiter,
     PageInfosExploiter,
+    PagesContributionsExploiter,
 )
 from srdt_analysis.database_manager import get_data
 from srdt_analysis.llm_processor import LLMProcessor
@@ -15,10 +16,11 @@
 
 QUESTION = "Combien de jours de congé payé par mois de travail effectif ?"
 COLLECTION_IDS = [
-    "4462ceb9-6da9-4f76-8a63-b87d4cc5afa0",  # information
-    "fa1d5d19-ec81-493a-843d-b33ce438f630",  # page_fiche_ministere_travail
-    "ba380a00-660b-4b49-8a77-7b8b389c3200",  # code_du_travail
-    "8dfca31c-994b-41cd-b5d5-c12231eee5d9",  # fiches_service_public
+    "5755cf5f-1cb5-4ec6-a076-21047d069578",  # information
+    "0576c752-f097-403e-b2be-d6d806c3848a",  # page_fiche_ministere_travail
+    "0be5059b-762f-48ba-a8f0-fe10e81455c8",  # code_du_travail
+    "f8d66426-5c54-4503-aa30-a3abc19453d5",  # fiches_service_public
+    "d03df69b-9387-4359-80db-7d73f2b6f04a",  # contributions
 ]
 
 
@@ -33,8 +35,13 @@ def ingest():
             "code_du_travail",
             "page_fiche_ministere_travail",
             "fiches_service_public",
+            "contributions",
         ]
     )
+    page_contribs_exploiter = PagesContributionsExploiter()
+    page_contribs_exploiter.process_documents(
+        data["contributions"], "contributions", "html"
+    )
     page_infos_exploiter = PageInfosExploiter()
     page_infos_exploiter.process_documents(
         data["information"], "information", "markdown"
4 changes: 2 additions & 2 deletions srdt_analysis/collections.py
@@ -97,6 +97,7 @@ def upload(
                     "id": dt["cdtn_id"],
                     "url": dt["url"],
                     "source": dt["source"],
+                    "title": dt["title"],
                 },
             }
         )
@@ -126,7 +127,6 @@ def upload(
 
             response.raise_for_status()
 
-            if i + COLLECTIONS_UPLOAD_BATCH_SIZE < len(result):
-                time.sleep(COLLECTIONS_UPLOAD_DELAY_IN_SECONDS)
+            time.sleep(COLLECTIONS_UPLOAD_DELAY_IN_SECONDS)
 
         return
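
The second hunk drops the "skip the last sleep" condition: the loop now pauses after every batch, including the final one. A minimal sketch of the surrounding batch-upload loop, assuming a requests-style HTTP client and a flat list of payloads (only the constants, raise_for_status(), and the sleep come from the diff; everything else is a guess):

    import time

    import requests  # assumed client; the diff only shows response.raise_for_status()

    COLLECTIONS_UPLOAD_BATCH_SIZE = 50
    COLLECTIONS_UPLOAD_DELAY_IN_SECONDS = 5

    def upload(result: list[dict], url: str) -> None:
        # Push documents in fixed-size batches, pausing between requests
        # so the vector store's rate limit is respected.
        for i in range(0, len(result), COLLECTIONS_UPLOAD_BATCH_SIZE):
            batch = result[i : i + COLLECTIONS_UPLOAD_BATCH_SIZE]
            response = requests.post(url, json={"chunks": batch})  # payload shape is hypothetical
            response.raise_for_status()
            # Before this commit the sleep ran only when another batch remained;
            # sleeping unconditionally is simpler and only adds one trailing pause.
            time.sleep(COLLECTIONS_UPLOAD_DELAY_IN_SECONDS)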
2 changes: 1 addition & 1 deletion srdt_analysis/constants.py
@@ -4,7 +4,7 @@
 CHUNK_SIZE = 4096
 CHUNK_OVERLAP = 0
 COLLECTIONS_UPLOAD_BATCH_SIZE = 50
-COLLECTIONS_UPLOAD_DELAY_IN_SECONDS = 10
+COLLECTIONS_UPLOAD_DELAY_IN_SECONDS = 5
 BASE_URL_CDTN = "https://code.travail.gouv.fr"
 LLM_ANSWER_PROMPT = """
 Instructions
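
Halving the delay doubles the theoretical upload ceiling; a back-of-the-envelope check, ignoring request latency:

    BATCH_SIZE = 50
    old_rate = BATCH_SIZE / 10  # 5 documents/second with the 10 s delay
    new_rate = BATCH_SIZE / 5   # 10 documents/second with the 5 s delay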
13 changes: 7 additions & 6 deletions srdt_analysis/data_exploiter.py
@@ -31,16 +31,11 @@ def process_documents(
         chunker_content_type: ChunkerContentType,
     ) -> ResultProcessDocumentType:
         results: list[DocumentData] = []
-        self.logger.info(f"Number of articles to be processed: {len(data)}")
-
         for doc in data:
             content = self.get_content(doc)
-
             chunks = self.chunker.split(content, chunker_content_type)
-
             doc_data = self.create_document_data(doc, content, chunks)
             results.append(doc_data)
-
         id = self.collections.create(collection_name, MODEL_VECTORISATION)
         self.logger.info(
             f"Uploading {len(results)} documents from {collection_name} in collection_id {id}"
@@ -64,11 +59,12 @@ def create_document_data(self, doc, content, content_chunked) -> DocumentData:
         }
 
     def _get_path_from_collection_name(self, collection_name: CollectionName) -> str:
-        mapping = {
+        mapping: dict[CollectionName, str] = {
             "code_du_travail": "code-du-travail",
             "fiches_service_public": "fiche-service-public",
             "page_fiche_ministere_travail": "fiche-ministere-travail",
             "information": "information",
+            "contributions": "contributions",
         }
         return mapping[collection_name]
 
@@ -98,3 +94,8 @@ def get_content(self, doc: Document) -> FormattedTextContent:
             if block.get("type") == "markdown":
                 markdown += block.get("markdown", "")
         return markdown
+
+
+class PagesContributionsExploiter(BaseDataExploiter):
+    def get_content(self, doc: Document) -> FormattedTextContent:
+        return doc.document["content"]
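
PagesContributionsExploiter hands the raw HTML straight to the chunker, and __main__.py processes it with the "html" content type, which is presumably why lxml enters pyproject.toml. The repository's Chunker internals are not shown in this diff; a plausible sketch of an HTML split with langchain-text-splitters (the splitter choice and header mapping are assumptions):

    from langchain_text_splitters import (
        HTMLHeaderTextSplitter,  # backed by lxml
        RecursiveCharacterTextSplitter,
    )

    CHUNK_SIZE = 4096
    CHUNK_OVERLAP = 0

    def split_html(content: str) -> list[str]:
        # First split on structural headers, then enforce the size budget.
        header_splitter = HTMLHeaderTextSplitter(
            headers_to_split_on=[("h1", "h1"), ("h2", "h2"), ("h3", "h3")]
        )
        size_splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )
        docs = header_splitter.split_text(content)   # header-delimited list[Document]
        docs = size_splitter.split_documents(docs)   # re-split any oversized section
        return [d.page_content for d in docs]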
16 changes: 11 additions & 5 deletions srdt_analysis/database_manager.py
@@ -33,12 +33,18 @@ async def get_connection(self):
         async with self.pool.acquire() as conn:
             yield conn
 
-    async def fetch_documents_by_source(self, source: str) -> DocumentsList:
+    async def fetch_documents_by_source(self, source: CollectionName) -> DocumentsList:
         async with self.get_connection() as conn:
-            result = await conn.fetch(
-                "SELECT * from public.documents WHERE source = $1 AND is_published = true AND is_available = true",
-                source,
-            )
+            query = """
+                SELECT * from public.documents
+                WHERE source = $1
+                AND is_published = true
+                AND is_available = true
+            """
+            if source == "contributions":
+                query += " AND document->>'content' IS NOT NULL AND document->>'idcc' = '0000'"
+
+            result = await conn.fetch(query, source)
             return [Document.from_record(r) for r in result]
 
     async def fetch_sources(
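
The appended predicate uses Postgres's ->> operator, which extracts a JSON field as text, so "generic" contributions are exactly the rows whose document payload has idcc = '0000' and a non-null content. A minimal asyncpg sketch of the resulting call path (connection setup and DSN are assumptions; the query text comes from the diff):

    import asyncio

    import asyncpg

    QUERY = """
        SELECT * from public.documents
        WHERE source = $1
        AND is_published = true
        AND is_available = true
        AND document->>'content' IS NOT NULL
        AND document->>'idcc' = '0000'
    """

    async def main() -> None:
        # Placeholder DSN; the real one lives in the project's configuration.
        conn = await asyncpg.connect(dsn="postgresql://user:pass@localhost/cdtn")
        try:
            rows = await conn.fetch(QUERY, "contributions")
            print(f"{len(rows)} generic contributions fetched")
        finally:
            await conn.close()

    asyncio.run(main())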
2 changes: 2 additions & 0 deletions srdt_analysis/mapper.py
@@ -4,6 +4,7 @@
     FichesMTExploiter,
     FichesSPExploiter,
     PageInfosExploiter,
+    PagesContributionsExploiter,
 )
 from srdt_analysis.models import (
     CollectionName,
@@ -20,6 +21,7 @@ def __init__(self, documents_by_source: dict[CollectionName, DocumentsList]):
             "page_fiche_ministere_travail": FichesMTExploiter(),
             "fiches_service_public": FichesSPExploiter(),
             "information": PageInfosExploiter(),
+            "contributions": PagesContributionsExploiter(),
         }
         all_documents = [doc for docs in documents_by_source.values() for doc in docs]
         self.doc_map = {doc.cdtn_id: doc for doc in all_documents}
2 changes: 2 additions & 0 deletions srdt_analysis/models.py
@@ -10,6 +10,7 @@
     "fiches_service_public",
     "page_fiche_ministere_travail",
     "information",
+    "contributions",
 ]
 
 ChunkerContentType = Literal["markdown", "html", "character_recursive"]
@@ -149,6 +150,7 @@ class ChunkMetadata(TypedDict):
     document_created_at: int
     id: ID
     source: CollectionName
+    title: str
     url: str
     collection: str
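
Typing fetch_documents_by_source with the widened CollectionName Literal (instead of the previous bare str) lets pyright reject unknown sources at check time. A small illustration; the stub below is hypothetical:

    from typing import Literal

    CollectionName = Literal[
        "code_du_travail",
        "fiches_service_public",
        "page_fiche_ministere_travail",
        "information",
        "contributions",
    ]

    def fetch_documents_by_source(source: CollectionName) -> None: ...

    fetch_documents_by_source("contributions")  # accepted
    fetch_documents_by_source("contribution")   # flagged by pyright: not a CollectionName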
