diff --git a/api/app/enhance.py b/api/app/enhance.py
index a3228c5..b2172ae 100644
--- a/api/app/enhance.py
+++ b/api/app/enhance.py
@@ -1,10 +1,8 @@
-import logging
 import os
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
-from urllib.parse import urlencode
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
-import requests
+from diffbot_kg import DiffbotEnhanceClient
 from utils import graph
 
 CATEGORY_THRESHOLD = 0.50
@@ -12,6 +10,7 @@
 DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]
 
+client = DiffbotEnhanceClient(DIFF_TOKEN)
 
 def get_datetime(value: Optional[Union[str, int, float]]) -> datetime:
     if not value:
@@ -19,15 +18,15 @@ def get_datetime(value: Optional[Union[str, int, float]]) -> datetime:
     return datetime.fromtimestamp(float(value) / 1000.0)
 
 
-def process_entities(entity: str, type: str) -> Dict[str, Any]:
+
+async def process_entities(entity: str, type: str) -> Tuple[str, List[Dict]]:
     """
     Fetch relevant articles from Diffbot KG endpoint
     """
-    search_host = "https://kg.diffbot.com/kg/v3/enhance?"
-    params = {"type": type, "name": entity, "token": DIFF_TOKEN}
-    encoded_query = urlencode(params)
-    url = f"{search_host}{encoded_query}"
-    return entity, requests.get(url).json()
+    params = {"type": type, "name": entity}
+    response = await client.enhance(params)
+
+    return entity, response.entities
 
 
 def get_people_params(row: Dict) -> Optional[Dict]:
diff --git a/api/app/importing.py b/api/app/importing.py
index 6ba912e..fc39689 100644
--- a/api/app/importing.py
+++ b/api/app/importing.py
@@ -1,8 +1,8 @@
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional
 
-import requests
+from diffbot_kg import DiffbotSearchClient
 from utils import embeddings, text_splitter
 
 CATEGORY_THRESHOLD = 0.50
@@ -10,24 +10,31 @@
 DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]
 
+client = DiffbotSearchClient(token=DIFF_TOKEN)
 
-def get_articles(
-    query: Optional[str], tag: Optional[str], size: int = 5, offset: int = 0
-) -> Dict[str, Any]:
+
+async def get_articles(
+    query: Optional[str],
+    category: Optional[str],
+    tag: Optional[str],
+    size: int = 5,
+    offset: int = 0,
+) -> List[Dict]:
     """
     Fetch relevant articles from Diffbot KG endpoint
     """
+    search_query = "type:Article language:en sortBy:date"
+    if query:
+        search_query += f' strict:text:"{query}"'
+    if category:
+        # NOTE: `category` added so the call in main.py binds correctly;
+        # the exact Diffbot DQL field name here is an assumption.
+        search_query += f' categories.name:"{category}"'
+    if tag:
+        search_query += f' tags.label:"{tag}"'
+
+    params = {"query": search_query, "size": size, "offset": offset}
+
+    logging.info(f"Fetching articles with params: {params}")
+
     try:
-        search_host = "https://kg.diffbot.com/kg/v3/dql?"
-        search_query = f'query=type%3AArticle+strict%3Alanguage%3A"en"+sortBy%3Adate'
-        if query:
-            search_query += f'+text%3A"{query}"'
-        if tag:
-            search_query += f'+tags.label%3A"{tag}"'
-        url = (
-            f"{search_host}{search_query}&token={DIFF_TOKEN}&from={offset}&size={size}"
-        )
-        return requests.get(url).json()
+        response = await client.search(params)
+        return response.entities
     except Exception as ex:
         raise ex
diff --git a/api/app/main.py b/api/app/main.py
index e345a30..d5c1ba4 100644
--- a/api/app/main.py
+++ b/api/app/main.py
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
@@ -20,7 +21,7 @@
 )
 
 # Multithreading for Diffbot API
-MAX_WORKERS = min(os.cpu_count() * 5, 20)
+MAX_WORKERS = min((os.cpu_count() or 1) * 5, 20)
 
 app = FastAPI()
 
@@ -35,21 +36,24 @@
 
 @app.post("/import_articles/")
-def import_articles_endpoint(article_data: ArticleData) -> int:
+async def import_articles_endpoint(article_data: ArticleData) -> int:
     logging.info(f"Starting to process article import with params: {article_data}")
-    if not article_data.query and not article_data.tag:
+    if not article_data.query and not article_data.category and not article_data.tag:
         raise HTTPException(
-            status_code=500, detail="Either `query` or `tag` must be provided"
+            status_code=500,
+            detail="Either `query`, `category`, or `tag` must be provided",
         )
-    data = get_articles(article_data.query, article_data.tag, article_data.size)
-    logging.info(f"Articles fetched: {len(data['data'])} articles.")
+    articles = await get_articles(
+        article_data.query, article_data.category, article_data.tag, article_data.size
+    )
+    logging.info(f"Articles fetched: {len(articles)} articles.")
     try:
-        params = process_params(data)
+        params = process_params(articles)
     except Exception as e:
         # You could log the exception here if needed
-        raise HTTPException(status_code=500, detail=e)
+        raise HTTPException(status_code=500, detail=str(e)) from e
     graph.query(import_cypher_query, params={"data": params})
-    logging.info(f"Article import query executed successfully.")
+    logging.info("Article import query executed successfully.")
     return len(params)
@@ -124,7 +128,7 @@ def fetch_unprocessed_count(count_data: CountData) -> int:
 
 @app.post("/enhance_entities/")
-def enhance_entities(entity_data: EntityData) -> str:
+async def enhance_entities(entity_data: EntityData) -> str:
     entities = graph.query(
         "MATCH (a:Person|Organization) WHERE a.processed IS NULL "
         "WITH a LIMIT toInteger($limit) "
@@ -132,18 +136,33 @@ def enhance_entities(entity_data: EntityData) -> str:
         "AS label, collect(a.name) AS entities",
         params={"limit": entity_data.size},
     )
-    enhanced_data = []
-    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
-        # Submitting all tasks and creating a list of future objects
-        for row in entities:
-            futures = [
-                executor.submit(process_entities, el, row["label"])
-                for el in row["entities"]
-            ]
-
-            for future in futures:
-                response = future.result()
-                enhanced_data.append(response)
+    enhanced_data = {}
+
+    # Drain a queue of (entity, label) pairs with a small pool of worker tasks
+
+    queue = asyncio.Queue()
+    for row in entities:
+        for el in row["entities"]:
+            await queue.put((el, row["label"]))
+
+    async def worker():
+        while True:
+            el, label = await queue.get()
+            try:
+                response = await process_entities(el, label)
+                enhanced_data[response[0]] = response[1]
+            finally:
+                queue.task_done()
+
+    tasks = []
+    for _ in range(4):  # Number of workers
+        tasks.append(asyncio.create_task(worker()))
+
+    await queue.join()
+
+    for task in tasks:
+        task.cancel()
+
     store_enhanced_data(enhanced_data)
     return "Finished enhancing entities."
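
Side note: the queue/worker pattern that replaces the ThreadPoolExecutor in enhance_entities above boils down to the self-contained sketch below (plain asyncio, no Diffbot or Neo4j dependencies; process() is a hypothetical stand-in for process_entities, and the worker count of 4 mirrors the diff). One caveat the sketch makes explicit: a worker that lets an exception escape dies silently, and queue.join() can then wait forever on unfinished items, so failures are caught and recorded instead.

    import asyncio


    async def process(item: str) -> str:
        """Hypothetical stand-in for an async API call such as process_entities."""
        await asyncio.sleep(0.01)
        return item.upper()


    async def enhance_all(items: list[str], n_workers: int = 4) -> dict:
        results: dict = {}
        queue: asyncio.Queue = asyncio.Queue()
        for item in items:
            await queue.put(item)

        async def worker() -> None:
            while True:
                item = await queue.get()
                try:
                    results[item] = await process(item)
                except Exception as exc:
                    # Record the failure; letting it escape would kill this
                    # worker and could leave queue.join() waiting forever.
                    results[item] = f"error: {exc}"
                finally:
                    queue.task_done()

        tasks = [asyncio.create_task(worker()) for _ in range(n_workers)]
        await queue.join()  # every item fetched and marked task_done()
        for task in tasks:  # workers loop forever, so cancel them explicitly
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        return results


    if __name__ == "__main__":
        print(asyncio.run(enhance_all(["alice", "bob", "acme corp"])))

Compared with the diff, the sketch adds the except branch and a final gather() so cancellation is awaited; an asyncio.Semaphore plus gather() over one task per entity would be an equally idiomatic way to bound concurrency here.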