diff --git a/pyproject.toml b/pyproject.toml
index 82cec65..124176f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,7 @@ pandas = "^2.2.3"
 langchain-text-splitters = "^0.3.2"
 detect-secrets = "^1.5.0"
 pre-commit = "^4.0.1"
+lxml = "^5.3.0"
 
 [tool.poetry.group.dev.dependencies]
 pyright = "^1.1.389"
diff --git a/srdt_analysis/__main__.py b/srdt_analysis/__main__.py
index 87318e6..5e45c40 100644
--- a/srdt_analysis/__main__.py
+++ b/srdt_analysis/__main__.py
@@ -6,6 +6,7 @@
     FichesMTExploiter,
     FichesSPExploiter,
     PageInfosExploiter,
+    PagesContributionsExploiter,
 )
 from srdt_analysis.database_manager import get_data
 from srdt_analysis.llm_processor import LLMProcessor
@@ -15,10 +16,11 @@
 QUESTION = "Combien de jours de congé payé par mois de travail effectif ?"
 
 COLLECTION_IDS = [
-    "4462ceb9-6da9-4f76-8a63-b87d4cc5afa0",  # information
-    "fa1d5d19-ec81-493a-843d-b33ce438f630",  # page_fiche_ministere_travail
-    "ba380a00-660b-4b49-8a77-7b8b389c3200",  # code_du_travail
-    "8dfca31c-994b-41cd-b5d5-c12231eee5d9",  # fiches_service_public
+    "5755cf5f-1cb5-4ec6-a076-21047d069578",  # information
+    "0576c752-f097-403e-b2be-d6d806c3848a",  # page_fiche_ministere_travail
+    "0be5059b-762f-48ba-a8f0-fe10e81455c8",  # code_du_travail
+    "f8d66426-5c54-4503-aa30-a3abc19453d5",  # fiches_service_public
+    "d03df69b-9387-4359-80db-7d73f2b6f04a",  # contributions
 ]
 
 
@@ -33,8 +35,13 @@ def ingest():
             "code_du_travail",
             "page_fiche_ministere_travail",
             "fiches_service_public",
+            "contributions",
         ]
     )
+    page_contribs_exploiter = PagesContributionsExploiter()
+    page_contribs_exploiter.process_documents(
+        data["contributions"], "contributions", "html"
+    )
     page_infos_exploiter = PageInfosExploiter()
     page_infos_exploiter.process_documents(
         data["information"], "information", "markdown"
diff --git a/srdt_analysis/collections.py b/srdt_analysis/collections.py
index 1363678..a1afda8 100644
--- a/srdt_analysis/collections.py
+++ b/srdt_analysis/collections.py
@@ -97,6 +97,7 @@ def upload(
                         "id": dt["cdtn_id"],
                         "url": dt["url"],
                         "source": dt["source"],
+                        "title": dt["title"],
                     },
                 }
             )
@@ -126,7 +127,6 @@
 
             response.raise_for_status()
 
-            if i + COLLECTIONS_UPLOAD_BATCH_SIZE < len(result):
-                time.sleep(COLLECTIONS_UPLOAD_DELAY_IN_SECONDS)
+            time.sleep(COLLECTIONS_UPLOAD_DELAY_IN_SECONDS)
 
         return
diff --git a/srdt_analysis/constants.py b/srdt_analysis/constants.py
index 5966234..5b074e8 100644
--- a/srdt_analysis/constants.py
+++ b/srdt_analysis/constants.py
@@ -4,7 +4,7 @@
 CHUNK_SIZE = 4096
 CHUNK_OVERLAP = 0
 COLLECTIONS_UPLOAD_BATCH_SIZE = 50
-COLLECTIONS_UPLOAD_DELAY_IN_SECONDS = 10
+COLLECTIONS_UPLOAD_DELAY_IN_SECONDS = 5
 BASE_URL_CDTN = "https://code.travail.gouv.fr"
 LLM_ANSWER_PROMPT = """
 Instructions
diff --git a/srdt_analysis/data_exploiter.py b/srdt_analysis/data_exploiter.py
index ee3b75d..dd81e72 100644
--- a/srdt_analysis/data_exploiter.py
+++ b/srdt_analysis/data_exploiter.py
@@ -31,16 +31,11 @@ def process_documents(
         chunker_content_type: ChunkerContentType,
     ) -> ResultProcessDocumentType:
         results: list[DocumentData] = []
-        self.logger.info(f"Number of articles to be processed: {len(data)}")
-
         for doc in data:
            content = self.get_content(doc)
-
            chunks = self.chunker.split(content, chunker_content_type)
-
            doc_data = self.create_document_data(doc, content, chunks)
            results.append(doc_data)
-
        id = self.collections.create(collection_name, MODEL_VECTORISATION)
        self.logger.info(
            f"Uploading {len(results)} documents from {collection_name} in collection_id {id}"
@@ -64,11 +59,12 @@ def create_document_data(self, doc, content, content_chunked) -> DocumentData:
         }
 
     def _get_path_from_collection_name(self, collection_name: CollectionName) -> str:
-        mapping = {
+        mapping: dict[CollectionName, str] = {
             "code_du_travail": "code-du-travail",
             "fiches_service_public": "fiche-service-public",
             "page_fiche_ministere_travail": "fiche-ministere-travail",
             "information": "information",
+            "contributions": "contributions",
         }
         return mapping[collection_name]
 
@@ -98,3 +94,8 @@
             if block.get("type") == "markdown":
                 markdown += block.get("markdown", "")
         return markdown
+
+
+class PagesContributionsExploiter(BaseDataExploiter):
+    def get_content(self, doc: Document) -> FormattedTextContent:
+        return doc.document["content"]
diff --git a/srdt_analysis/database_manager.py b/srdt_analysis/database_manager.py
index 937c471..6317c66 100644
--- a/srdt_analysis/database_manager.py
+++ b/srdt_analysis/database_manager.py
@@ -33,12 +33,18 @@ async def get_connection(self):
         async with self.pool.acquire() as conn:
             yield conn
 
-    async def fetch_documents_by_source(self, source: str) -> DocumentsList:
+    async def fetch_documents_by_source(self, source: CollectionName) -> DocumentsList:
         async with self.get_connection() as conn:
-            result = await conn.fetch(
-                "SELECT * from public.documents WHERE source = $1 AND is_published = true AND is_available = true",
-                source,
-            )
+            query = """
+                SELECT * from public.documents
+                WHERE source = $1
+                AND is_published = true
+                AND is_available = true
+            """
+            if source == "contributions":
+                query += " AND document->>'content' IS NOT NULL AND document->>'idcc' = '0000'"
+
+            result = await conn.fetch(query, source)
             return [Document.from_record(r) for r in result]
 
     async def fetch_sources(
diff --git a/srdt_analysis/mapper.py b/srdt_analysis/mapper.py
index 10d85e0..aaa20d2 100644
--- a/srdt_analysis/mapper.py
+++ b/srdt_analysis/mapper.py
@@ -4,6 +4,7 @@
     FichesMTExploiter,
     FichesSPExploiter,
     PageInfosExploiter,
+    PagesContributionsExploiter,
 )
 from srdt_analysis.models import (
     CollectionName,
@@ -20,6 +21,7 @@ def __init__(self, documents_by_source: dict[CollectionName, DocumentsList]):
             "page_fiche_ministere_travail": FichesMTExploiter(),
             "fiches_service_public": FichesSPExploiter(),
             "information": PageInfosExploiter(),
+            "contributions": PagesContributionsExploiter(),
         }
         all_documents = [doc for docs in documents_by_source.values() for doc in docs]
         self.doc_map = {doc.cdtn_id: doc for doc in all_documents}
diff --git a/srdt_analysis/models.py b/srdt_analysis/models.py
index a350efb..8149ca1 100644
--- a/srdt_analysis/models.py
+++ b/srdt_analysis/models.py
@@ -10,6 +10,7 @@
     "fiches_service_public",
     "page_fiche_ministere_travail",
     "information",
+    "contributions",
 ]
 
 ChunkerContentType = Literal["markdown", "html", "character_recursive"]
@@ -149,6 +150,7 @@ class ChunkMetadata(TypedDict):
     document_created_at: int
     id: ID
     source: CollectionName
+    title: str
     url: str
     collection: str