Skip to content

Commit

Permalink
Auto-sync-2024-08-23-02-42-37
Browse files Browse the repository at this point in the history
  • Loading branch information
s-jse authored Aug 23, 2024
1 parent b98c711 commit b7e6c9d
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 28 deletions.
2 changes: 1 addition & 1 deletion retrieval/create_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def commit_to_index(
desc="Indexing collection",
miniters=1e-6,
unit_scale=1,
unit=" Block",
unit=" Blocks",
dynamic_ncols=True,
smoothing=0,
total=collection_size,
Expand Down
2 changes: 1 addition & 1 deletion tasks/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def wait_for_docker_container_to_be_ready(
timeout = 60
step_time = timeout // 10
elapsed_time = 0
logger.info("Waiting for the container '%s' to be ready...", container)
logger.info("Waiting for container '%s' to be ready...", container.name)

def is_ready():
container_status = docker_client.containers.get(container.id).status
Expand Down
82 changes: 58 additions & 24 deletions tasks/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from huggingface_hub import snapshot_download
from invoke import task
from tqdm import tqdm
from typing import Optional

from tasks.docker_utils import (
start_embedding_docker_container,
Expand Down Expand Up @@ -213,7 +214,7 @@ def multithreaded_download(url: str, output_path: str, num_parts: int = 3) -> No


@task
def print_wikipedia_dump_date(c, date=None):
def print_wikipedia_dump_date(c, date: Optional[str] = None):
"""
Prints the specified date in a human-readable format.
Expand All @@ -239,6 +240,20 @@ def start_retriever(
use_onnx=DEFAULT_EMBEDDING_USE_ONNX,
retriever_port=DEFAULT_RETRIEVER_PORT,
):
"""
Starts the retriever server.
This task runs the retriever server, which is responsible for handling
retrieval requests. It uses Gunicorn to manage Uvicorn workers for asynchronous processing.
The retriever server sends search requests to the Qdrant docker container.
Args:
- c: Context, automatically passed by invoke.
- collection_name (str): The name of the Qdrant collection to query. Defaults to DEFAULT_QDRANT_COLLECTION_NAME.
- embedding_model (str): The HuggingFace ID of the embedding model to use for retrieval. Defaults to DEFAULT_EMBEDDING_MODEL_NAME.
- use_onnx (bool): Flag indicating whether to use the ONNX version of the embedding model. Defaults to DEFAULT_EMBEDDING_USE_ONNX.
- retriever_port (int): The port on which the retriever server will listen. Defaults to DEFAULT_RETRIEVER_PORT.
"""
command = (
f"gunicorn -k uvicorn.workers.UvicornWorker 'retrieval.retriever_server:gunicorn_app("
f'embedding_model="{embedding_model}", '
Expand All @@ -255,9 +270,20 @@ def start_retriever(
@task(pre=[start_qdrant_docker_container])
def test_index(
c,
collection_name=DEFAULT_QDRANT_COLLECTION_NAME,
embedding_model=DEFAULT_EMBEDDING_MODEL_NAME,
collection_name: str = DEFAULT_QDRANT_COLLECTION_NAME,
embedding_model: str = DEFAULT_EMBEDDING_MODEL_NAME,
):
"""
Tests a Qdrant index.
This task starts a Qdrant Docker container and then runs a test query to ensure that the index is working correctly.
Note that this task does not perform the actual indexing; it only tests an index that already exists.
Args:
- c: Context, automatically passed by invoke.
- collection_name (str): Name of the Qdrant collection to test. Defaults to DEFAULT_QDRANT_COLLECTION_NAME.
- embedding_model (str): Name of the embedding model to use for testing. Defaults to DEFAULT_EMBEDDING_MODEL_NAME.
"""
cmd = (
f"python retrieval/create_index.py "
f"--collection_name {collection_name}"
Expand All @@ -278,9 +304,9 @@ def test_index(
def index_collection(
c,
collection_path,
collection_name=DEFAULT_QDRANT_COLLECTION_NAME,
embedding_model_port=DEFAULT_EMBEDDING_MODEL_PORT,
embedding_model=DEFAULT_EMBEDDING_MODEL_NAME,
collection_name: str = DEFAULT_QDRANT_COLLECTION_NAME,
embedding_model_port: int = DEFAULT_EMBEDDING_MODEL_PORT,
embedding_model: str = DEFAULT_EMBEDDING_MODEL_NAME,
):
"""
Creates a Qdrant index from a collection file using a specified embedding model.
Expand Down Expand Up @@ -316,14 +342,17 @@ def index_collection(

@task
def download_wikipedia_dump(
c, workdir=DEFAULT_WORKDIR, language=DEFAULT_WIKIPEDIA_DUMP_LANGUAGE, wikipedia_date=None
c,
workdir: str = DEFAULT_WORKDIR,
language: str = DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
wikipedia_date: Optional[str] = None,
):
"""
Downloads the Wikipedia HTML dump from https://dumps.wikimedia.org/other/enterprise_html/runs/
Args:
- c: Context, automatically passed by invoke.
- wikipedia_date: Date of the Wikipedia dump. Note that currently Wikipedia keeps its HTML dumps available for a limited period of time, so older dates might not be available.
- wikipedia_date (str, optional): The date of the Wikipedia dump to use. If not provided, the latest available dump is used.
- language: Language edition of Wikipedia.
"""
if not wikipedia_date:
Expand All @@ -342,22 +371,25 @@ def download_wikipedia_dump(
@task
def preprocess_wikipedia_dump(
c,
workdir=DEFAULT_WORKDIR,
language=DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
wikipedia_date=None,
pack_to_tokens=200,
num_exclude_frequent_words_from_translation=1000,
workdir: str = DEFAULT_WORKDIR,
language: str = DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
wikipedia_date: Optional[str] = None,
pack_to_tokens: int = 200,
num_exclude_frequent_words_from_translation: int = 1000,
):
"""
Process Wikipedia HTML dump into a collection.
Process Wikipedia HTML dump into a JSONL collection file.
This takes ~4 hours for English on a 24-core CPU VM. Processing is fully parallelizable so the time is proportional to number of cores available.
It might take more for other languages, if we need to also get entity translations from Wikidata. This is because of Wikidata's rate limit.
Args:
- index_dir: Directory containing the HTML dump file (articles-html.json.tar.gz)
- workdir: Working directory for processing
- language: Language of the dump to process
- workdir (str): Working directory for processing
- language (str): Language of the dump to process
- wikipedia_date (str, optional): The date of the Wikipedia dump to use. If not provided, the latest available dump is used.
- pack_to_tokens(int): We try to pack smaller text chunks to get to this number of tokens.
- num_exclude_frequent_words_from_translation (int): For non-English Wikipedia dumps, we try to find English translations of all article names
in Wikidata. We will exclude the `num_exclude_frequent_words_from_translation` most common words because neural models are already familiar with these.
"""
output_path = get_wikipedia_collection_path(workdir, language, wikipedia_date)
if os.path.exists(output_path):
Expand Down Expand Up @@ -395,12 +427,12 @@ def preprocess_wikipedia_dump(
)
def index_wikipedia_dump(
c,
collection_name=DEFAULT_QDRANT_COLLECTION_NAME,
embedding_model_port=DEFAULT_EMBEDDING_MODEL_PORT,
embedding_model=DEFAULT_EMBEDDING_MODEL_NAME,
workdir=DEFAULT_WORKDIR,
language=DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
wikipedia_date=None,
collection_name: str = DEFAULT_QDRANT_COLLECTION_NAME,
embedding_model_port: int = DEFAULT_EMBEDDING_MODEL_PORT,
embedding_model: str = DEFAULT_EMBEDDING_MODEL_NAME,
workdir: str = DEFAULT_WORKDIR,
language: str = DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
wikipedia_date: Optional[str] = None,
):
"""
Orchestrates the indexing of a Wikipedia collection using a specified embedding model.
Expand Down Expand Up @@ -461,4 +493,6 @@ def index_multiple_wikipedia_dumps(
wikipedia_date = get_latest_wikipedia_dump_date()
for l in language:
logger.info("Started indexing for language %s", l)
index_wikipedia_dump(c, workdir=workdir, language=l, wikipedia_date=wikipedia_date)
index_wikipedia_dump(
c, workdir=workdir, language=l, wikipedia_date=wikipedia_date
)
4 changes: 2 additions & 2 deletions wikipedia_preprocessing/preprocess_html_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def build_redirection_map(file_path: str) -> dict:
desc="Building the Wikipedia redirection graph",
miniters=1e-6,
unit_scale=1,
unit=" Article",
unit=" Articles",
smoothing=0,
):
if is_disambiguation(article):
Expand Down Expand Up @@ -743,7 +743,7 @@ def articles_without_disambiguation_or_redirections(
desc="Extracting blocks",
miniters=1e-6,
unit_scale=1,
unit=" Block",
unit=" Blocks",
smoothing=0,
total=len(redirect_map),
)
Expand Down

0 comments on commit b7e6c9d

Please sign in to comment.