Commit: Utilize diffbot-kg client
brendancsmith committed Sep 18, 2024
1 parent 7bc7c78 commit af4edfa
Showing 3 changed files with 72 additions and 47 deletions.
19 changes: 9 additions & 10 deletions api/app/enhance.py
@@ -1,33 +1,32 @@
 import logging
 import os
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
-from urllib.parse import urlencode
+from typing import Dict, List, Literal, Optional, Tuple, Union
 
-import requests
+from diffbot_kg import DiffbotEnhanceClient
 from utils import graph
 
 CATEGORY_THRESHOLD = 0.50
-params = []
 
 DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]
 
+client = DiffbotEnhanceClient(DIFF_TOKEN)
+
 def get_datetime(value: Optional[Union[str, int, float]]) -> datetime:
     if not value:
         return value
     return datetime.fromtimestamp(float(value) / 1000.0)
 
 
-def process_entities(entity: str, type: str) -> Dict[str, Any]:
+async def process_entities(entity: str, type: str) -> Tuple[str, List[Dict]]:
     """
     Fetch relevant articles from Diffbot KG endpoint
     """
-    search_host = "https://kg.diffbot.com/kg/v3/enhance?"
-    params = {"type": type, "name": entity, "token": DIFF_TOKEN}
-    encoded_query = urlencode(params)
-    url = f"{search_host}{encoded_query}"
-    return entity, requests.get(url).json()
+    params = {"type": type, "name": entity}
+    response = await client.enhance(params)
+
+    return entity, response.entities
 
 
 def get_people_params(row: Dict) -> Optional[Dict]:
(diff truncated)
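For orientation: `process_entities` is now a coroutine backed by a module-level `DiffbotEnhanceClient`, so callers must drive it from an event loop (note the docstring it inherits still says "articles", but the function returns enhancement records for a single entity). A minimal usage sketch, assuming the module is importable as `app.enhance` and with purely illustrative entity names:

```python
import asyncio

from app.enhance import process_entities  # import path assumed from the repo layout


async def main():
    # DIFFBOT_API_KEY must be set in the environment; entity names are illustrative.
    results = await asyncio.gather(
        process_entities("Neo4j", "Organization"),
        process_entities("Ada Lovelace", "Person"),
    )
    for name, records in results:
        print(f"{name}: {len(records)} enhancement records")


asyncio.run(main())
```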
37 changes: 22 additions & 15 deletions api/app/importing.py
@@ -1,33 +1,40 @@
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Dict, List, Optional
 
-import requests
+from diffbot_kg import DiffbotSearchClient
 from utils import embeddings, text_splitter
 
 CATEGORY_THRESHOLD = 0.50
 params = []
 
 DIFF_TOKEN = os.environ["DIFFBOT_API_KEY"]
 
+client = DiffbotSearchClient(token=DIFF_TOKEN)
+
-def get_articles(
-    query: Optional[str], tag: Optional[str], size: int = 5, offset: int = 0
-) -> Dict[str, Any]:
+async def get_articles(
+    query: Optional[str],
+    tag: Optional[str],
+    size: int = 5,
+    offset: int = 0,
+) -> List[Dict]:
     """
     Fetch relevant articles from Diffbot KG endpoint
     """
+    search_query = "type:Article language:en sortBy:date"
+    if query:
+        search_query += f' strict:text:"{query}"'
+    if tag:
+        search_query += f' tags.label:"{tag}"'
+
+    params = {"query": search_query, "size": size, "offset": offset}
+
+    logging.info(f"Fetching articles with params: {params}")
+
     try:
-        search_host = "https://kg.diffbot.com/kg/v3/dql?"
-        search_query = f'query=type%3AArticle+strict%3Alanguage%3A"en"+sortBy%3Adate'
-        if query:
-            search_query += f'+text%3A"{query}"'
-        if tag:
-            search_query += f'+tags.label%3A"{tag}"'
-        url = (
-            f"{search_host}{search_query}&token={DIFF_TOKEN}&from={offset}&size={size}"
-        )
-        return requests.get(url).json()
+        response = await client.search(params)
+        return response.entities
     except Exception as ex:
         raise ex
 
(diff truncated)
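The rewrite above trades hand-assembled, percent-encoded DQL URLs for a readable query string that `DiffbotSearchClient` submits on the caller's behalf. A self-contained sketch of the same composition logic (values illustrative) shows what ends up in `params["query"]`. One caveat worth flagging: `main.py` below passes `article_data.category` to `get_articles`, but this signature has no `category` parameter, so that call would fail unless one is added.

```python
# Mirrors the query-string composition in get_articles; values are illustrative.
def build_search_query(query=None, tag=None):
    search_query = "type:Article language:en sortBy:date"
    if query:
        search_query += f' strict:text:"{query}"'
    if tag:
        search_query += f' tags.label:"{tag}"'
    return search_query


print(build_search_query())
# -> type:Article language:en sortBy:date
print(build_search_query(query="knowledge graph", tag="Technology"))
# -> type:Article language:en sortBy:date strict:text:"knowledge graph" tags.label:"Technology"
```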
63 changes: 41 additions & 22 deletions api/app/main.py
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
@@ -20,7 +21,7 @@
 )
 
 # Multithreading for Diffbot API
-MAX_WORKERS = min(os.cpu_count() * 5, 20)
+MAX_WORKERS = min((os.cpu_count() or 1) * 5, 20)
 
 app = FastAPI()

@@ -35,21 +36,24 @@
 
 
 @app.post("/import_articles/")
-def import_articles_endpoint(article_data: ArticleData) -> int:
+async def import_articles_endpoint(article_data: ArticleData) -> int:
     logging.info(f"Starting to process article import with params: {article_data}")
-    if not article_data.query and not article_data.tag:
+    if not article_data.query and not article_data.category and not article_data.tag:
         raise HTTPException(
-            status_code=500, detail="Either `query` or `tag` must be provided"
+            status_code=500,
+            detail="Either `query` or `category` or `tag` must be provided",
         )
-    data = get_articles(article_data.query, article_data.tag, article_data.size)
-    logging.info(f"Articles fetched: {len(data['data'])} articles.")
+    articles = await get_articles(
+        article_data.query, article_data.category, article_data.tag, article_data.size
+    )
+    logging.info(f"Articles fetched: {len(articles)} articles.")
     try:
-        params = process_params(data)
+        params = process_params(articles)
     except Exception as e:
         # You could log the exception here if needed
-        raise HTTPException(status_code=500, detail=e)
+        raise HTTPException(status_code=500, detail=e) from e
     graph.query(import_cypher_query, params={"data": params})
-    logging.info(f"Article import query executed successfully.")
+    logging.info("Article import query executed successfully.")
     return len(params)
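Since the endpoint is now a coroutine, FastAPI awaits it on its own event loop; clients call it over HTTP as before. A sketch of exercising it against a locally running app (host, port, and payload values are assumptions; field names follow the diff):

```python
import requests

# ArticleData fields as used above; values are illustrative.
resp = requests.post(
    "http://localhost:8000/import_articles/",
    json={"query": "knowledge graphs", "category": None, "tag": None, "size": 5},
)
print(resp.json())  # count of imported article rows
```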


@@ -124,26 +128,41 @@ def fetch_unprocessed_count(count_data: CountData) -> int:
 
 
 @app.post("/enhance_entities/")
-def enhance_entities(entity_data: EntityData) -> str:
+async def enhance_entities(entity_data: EntityData) -> str:
     entities = graph.query(
         "MATCH (a:Person|Organization) WHERE a.processed IS NULL "
         "WITH a LIMIT toInteger($limit) "
         "RETURN [el in labels(a) WHERE el <> '__Entity__' | el][0] "
         "AS label, collect(a.name) AS entities",
         params={"limit": entity_data.size},
     )
-    enhanced_data = []
-    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
-        # Submitting all tasks and creating a list of future objects
-        for row in entities:
-            futures = [
-                executor.submit(process_entities, el, row["label"])
-                for el in row["entities"]
-            ]
-
-            for future in futures:
-                response = future.result()
-                enhanced_data.append(response)
+    enhanced_data = {}
+
+    # Run the process_entities function in a TaskGroup
+
+    queue = asyncio.Queue()
+    for row in entities:
+        for el in row["entities"]:
+            await queue.put((el, row["label"]))
+
+    async def worker():
+        while True:
+            el, label = await queue.get()
+            try:
+                response = await process_entities(el, label)
+                enhanced_data[response[0]] = response[1]
+            finally:
+                queue.task_done()
+
+    tasks = []
+    for _ in range(4):  # Number of workers
+        tasks.append(asyncio.create_task(worker()))
+
+    await queue.join()
+
+    for task in tasks:
+        task.cancel()
+
     store_enhanced_data(enhanced_data)
     return "Finished enhancing entities."

(diff truncated)
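A note on the added comment: despite mentioning a TaskGroup, the new code fans out over an `asyncio.Queue` drained by four hand-managed worker tasks. On Python 3.11+, the same bounded fan-out could be written with an actual `asyncio.TaskGroup` plus a semaphore; the sketch below is an alternative under that assumption, not what the commit ships:

```python
import asyncio

CONCURRENCY = 4  # mirrors the hard-coded worker count in the diff


async def enhance_all(pairs, process_entities):
    """pairs: iterable of (entity_name, label) tuples."""
    sem = asyncio.Semaphore(CONCURRENCY)
    enhanced_data = {}

    async def enhance_one(el, label):
        async with sem:  # cap concurrent Diffbot calls
            name, records = await process_entities(el, label)
            enhanced_data[name] = records

    async with asyncio.TaskGroup() as tg:  # Python 3.11+
        for el, label in pairs:
            tg.create_task(enhance_one(el, label))

    return enhanced_data
```

One upside of this form: unlike the cancel-the-workers pattern above, a `TaskGroup` propagates exceptions from `process_entities` instead of losing them in cancelled tasks.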
