diff --git a/README.md b/README.md
index 4d40701..62ae972 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@
 # Introduction
 
 Large language model (LLM) chatbots like ChatGPT and GPT-4 get things wrong a lot, especially if the information you are looking for is recent ("Tell me about the 2024 Super Bowl.") or about less popular topics ("What are some good movies to watch from [insert your favorite foreign director]?").
-WikiChat uses Wikipedia and the following 7-stage pipeline to makes sure its responses are factual.
+WikiChat uses Wikipedia and the following 7-stage pipeline to make sure its responses are factual. Each numbered stage involves one or more LLM calls.
@@ -78,8 +78,14 @@ Installing WikiChat involves the following steps:
 4. Run WikiChat with your desired configuration.
 5. [Optional] Deploy WikiChat for multi-user access. We provide code to deploy a simple front-end and backend, as well as instructions to connect to an Azure Cosmos DB database for storing conversations.
 
-This project has been tested using Python 3.10 on Ubuntu Focal 20.04 (LTS), but should run on most recent Linux distributions.
-If you want to use this on Windows WSL or Mac, or with a different Python version, expect to do some troubleshooting in some of the installation steps.
+
+## System Requirements
+This project has been tested with Python 3.10 on Ubuntu 20.04 LTS (Focal Fossa), but it should be compatible with many other Linux distributions. If you plan to use this on Windows WSL or macOS, or with a different Python version, be prepared for potential troubleshooting during installation.
+
+Running WikiChat using LLM APIs and our Wikipedia search API does not have specific hardware requirements and can be performed on most systems. However, if you intend to host a search index locally, ensure you have sufficient disk space for the index. For large indices, retrieval latency is heavily dependent on disk speed, so we recommend using SSDs and preferably NVMe drives. For example, storage-optimized VMs like [`Standard_L8s_v3`](https://learn.microsoft.com/en-us/azure/virtual-machines/lsv3-series) on Azure are suitable for this.
+
+If you plan to use WikiChat with a local LLM, a GPU is necessary to host the model.
+
 
 ## Install Dependencies
diff --git a/retrieval/create_index.py b/retrieval/create_index.py
index dcc1da0..83cc908 100644
--- a/retrieval/create_index.py
+++ b/retrieval/create_index.py
@@ -96,7 +96,7 @@ def commit_to_index(
         desc="Indexing collection",
         miniters=1e-6,
         unit_scale=1,
-        unit=" Block",
+        unit=" Blocks",
         dynamic_ncols=True,
         smoothing=0,
         total=collection_size,
@@ -219,7 +219,7 @@ def batch_generator(collection_file, embedding_batch_size):
         "--collection_file", type=str, default=None, help=".jsonl file to read from."
    )
    parser.add_argument(
-        "--embedding_model",
+        "--embedding_model_name",
         type=str,
         choices=QdrantIndex.get_supported_embedding_models(),
         default="BAAI/bge-m3",
@@ -258,10 +258,10 @@ def batch_generator(collection_file, embedding_batch_size):
     args = parser.parse_args()
     model_port = args.model_port
-    embedding_size = QdrantIndex.get_embedding_model_parameters(args.embedding_model)[
+    embedding_size = QdrantIndex.get_embedding_model_parameters(args.embedding_model_name)[
         "embedding_dimension"
     ]
-    query_prefix = QdrantIndex.get_embedding_model_parameters(args.embedding_model)[
+    query_prefix = QdrantIndex.get_embedding_model_parameters(args.embedding_model_name)[
         "query_prefix"
     ]
@@ -325,7 +325,7 @@ def batch_generator(collection_file, embedding_batch_size):
     ]
 
     with QdrantIndex(
-        args.embedding_model, args.collection_name, use_onnx=True
+        args.embedding_model_name, args.collection_name, use_onnx=True
     ) as index:
         results = asyncio.run(index.search(queries, 5))
         logger.info(json.dumps(results, indent=2, ensure_ascii=False))
diff --git a/retrieval/retriever_server.py b/retrieval/retriever_server.py
index ba7ccaf..2c92368 100644
--- a/retrieval/retriever_server.py
+++ b/retrieval/retriever_server.py
@@ -164,7 +164,7 @@ async def search(request: Request, query_data: QueryData):
 def init():
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument(
-        "--embedding_model",
+        "--embedding_model_name",
         type=str,
         required=True,
         help="Path or the HuggingFace model name for the model used to encode the query.",
@@ -183,7 +183,7 @@ def init():
     if args.use_onnx:
         logger.info("Using ONNX for the embedding model.")
     qdrant_index = QdrantIndex(
-        embedding_model_name=args.embedding_model,
+        embedding_model_name=args.embedding_model_name,
         collection_name=args.collection_name,
         use_onnx=args.use_onnx,
     )
diff --git a/tasks/benchmark.py b/tasks/benchmark.py
index 136a912..f459f81 100644
--- a/tasks/benchmark.py
+++ b/tasks/benchmark.py
@@ -43,6 +43,12 @@ def simulate_users(
     draft_engine=CHATBOT_DEFAULT_CONFIG["draft_engine"],
     refine_engine=CHATBOT_DEFAULT_CONFIG["refine_engine"],
 ):
+    """
+    Simulate user dialogues with a chatbot using specified parameters.
+
+    Accepts all parameters that `inv demo` accepts, plus a few additional parameters for the user simulator.
+    """
+
     pipeline_flags = (
         f"--pipeline {pipeline} "
         f"--engine {engine} "
diff --git a/tasks/docker_utils.py b/tasks/docker_utils.py
index 675fafe..9997510 100644
--- a/tasks/docker_utils.py
+++ b/tasks/docker_utils.py
@@ -95,7 +95,7 @@ def wait_for_docker_container_to_be_ready(
     timeout = 60
     step_time = timeout // 10
     elapsed_time = 0
-    logger.info("Waiting for the container '%s' to be ready...", container)
+    logger.info("Waiting for container '%s' to be ready...", container.name)
 
     def is_ready():
         container_status = docker_client.containers.get(container.id).status
@@ -128,7 +128,7 @@ def start_qdrant_docker_container(
     c, workdir=DEFAULT_WORKDIR, rest_port=6333, grpc_port=6334
 ):
     """
-    Starts a Qdrant docker container if it is not already running.
+    Start a Qdrant docker container if it is not already running.
 
     This function checks if a Qdrant docker container named 'qdrant' is already running.
     If so, it logs that the container is already up and does nothing. If the container exists but is stopped, it simply restarts
@@ -200,7 +200,7 @@ def start_embedding_docker_container(
     embedding_model=DEFAULT_EMBEDDING_MODEL_NAME,
 ):
     """
-    Starts a Docker container for HuggingFace's text embedding inference (TEI) if it is not already running.
+    Start a Docker container for HuggingFace's text embedding inference (TEI) if it is not already running.
     See https://github.com/huggingface/text-embeddings-inference for TEI documentation.
 
     This function checks if a text-embedding-inference Docker container is already running and starts it if not.
diff --git a/tasks/main.py b/tasks/main.py
index f0048bb..312679d 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -15,6 +15,20 @@
 
 @task
 def load_api_keys(c):
+    """
+    Load API keys from a file named 'API_KEYS' and set them as environment variables.
+
+    This function reads the 'API_KEYS' file line by line, extracts key-value pairs,
+    and sets them as environment variables. Lines starting with '#' are treated as
+    comments and ignored. The expected format for each line in the file is 'KEY=VALUE'.
+
+    Parameters:
+    - c: Context, automatically passed by invoke.
+
+    Note:
+    - If an error occurs while reading the 'API_KEYS' file or setting the
+      environment variables, an error message is logged instead of raising an exception.
+    """
     try:
         with open("API_KEYS") as f:
             for line in f:
@@ -33,7 +47,18 @@ def load_api_keys(c):
 
 @task()
-def start_redis(c, redis_port=DEFAULT_REDIS_PORT):
+def start_redis(c, redis_port: int = DEFAULT_REDIS_PORT):
+    """
+    Start a Redis server if it is not already running.
+
+    This task attempts to connect to a Redis server on the specified port.
+    If the connection fails (indicating that the Redis server is not running),
+    it starts a new Redis server on that port.
+
+    Parameters:
+    - c: Context, automatically passed by invoke.
+    - redis_port (int): The port number on which to start the Redis server. Defaults to DEFAULT_REDIS_PORT.
+    """
     try:
         r = redis.Redis(host="localhost", port=redis_port)
         r.ping()
diff --git a/tasks/retrieval.py b/tasks/retrieval.py
index 56cd3db..d74e5de 100644
--- a/tasks/retrieval.py
+++ b/tasks/retrieval.py
@@ -11,6 +11,7 @@
 from huggingface_hub import snapshot_download
 from invoke import task
 from tqdm import tqdm
+from typing import Optional
 
 from tasks.docker_utils import (
     start_embedding_docker_container,
@@ -36,35 +37,6 @@
 logger = get_logger(__name__)
 
-@task
-def download_wikipedia_index(
-    c,
-    repo_id="stanford-oval/wikipedia_10-languages_bge-m3_qdrant_index",
-    workdir=DEFAULT_WORKDIR,
-    num_threads=10,
-):
-    # Download the files
-    snapshot_download(
-        repo_id=repo_id,
-        repo_type="dataset",
-        local_dir=workdir,
-        allow_patterns="*.tar.*",
-        max_workers=num_threads,
-    )
-
-    # Find the part files
-    part_files = " ".join(sorted(glob.glob(os.path.join(workdir, "*.tar.*"))))
-
-    # Ensure part_files is not empty
-    if not part_files:
-        raise FileNotFoundError("No part files found in the specified directory.")
-
-    # Decompress and extract the files
-    c.run(
-        f"cat {part_files} | pigz -d -p {num_threads} | tar --strip-components=2 -xv -C {workdir}"
-    )  # strip-components gets rid of the extra workdir/
-
-
 def get_latest_wikipedia_dump_date() -> str:
     """
     Fetches the latest Wikipedia HTML dump date from the Wikimedia dumps page.
@@ -111,7 +83,7 @@ def download_chunk_from_url(
     url, start, end, output_path, pbar, file_lock, num_retries=3
 ):
     """
-    Downloads a chunk of data from a specific URL, within a given byte range, and writes it to a part of a file.
+    Download a chunk of data from a specific URL, within a given byte range, and write it to a part of a file.
 
     This function attempts to download a specified range of bytes from a given URL
     and write this data into a part of a file denoted by the start byte.
    If the download fails due to a ChunkedEncodingError, it will retry up to
@@ -148,7 +120,7 @@ def download_chunk_from_url(
 
 def multithreaded_download(url: str, output_path: str, num_parts: int = 3) -> None:
     """
-    Downloads a file in parts concurrently using multiple threads to optimize the download process.
+    Download a file in parts concurrently using multiple threads to optimize the download process.
 
     This function breaks the download into several parts and downloads each part in parallel, thus potentially
     improving the download speed. It is especially useful when dealing with large files and/or rate-limited servers.
@@ -213,7 +185,48 @@ def multithreaded_download(url: str, output_path: str, num_parts: int = 3) -> No
 
 @task
-def print_wikipedia_dump_date(c, date=None):
+def download_wikipedia_index(
+    c,
+    repo_id: str = "stanford-oval/wikipedia_10-languages_bge-m3_qdrant_index",
+    workdir: str = DEFAULT_WORKDIR,
+    num_threads: int = 8,
+):
+    """
+    Download and extract a pre-built Qdrant index for Wikipedia from the 🤗 Hub.
+
+    Args:
+    - c: Context, automatically passed by invoke.
+    - repo_id (str): The 🤗 Hub repository ID from which to download the index files. Defaults to "stanford-oval/wikipedia_10-languages_bge-m3_qdrant_index".
+    - workdir (str): The working directory where the files will be downloaded and extracted. Defaults to DEFAULT_WORKDIR.
+    - num_threads (int): The number of threads to use for downloading and decompressing the files. Defaults to 8.
+
+    Raises:
+    - FileNotFoundError: If no part files are found in the specified directory.
+    """
+    # Download the files
+    snapshot_download(
+        repo_id=repo_id,
+        repo_type="dataset",
+        local_dir=workdir,
+        allow_patterns="*.tar.*",
+        max_workers=num_threads,
+    )
+
+    # Find the part files
+    part_files = " ".join(sorted(glob.glob(os.path.join(workdir, "*.tar.*"))))
+
+    # Ensure part_files is not empty
+    if not part_files:
+        raise FileNotFoundError("No part files found in the specified directory.")
+
+    # Decompress and extract the files
+    c.run(
+        f"cat {part_files} | pigz -d -p {num_threads} | tar --strip-components=2 -xv -C {workdir}"
+    )  # strip-components gets rid of the extra workdir/
+
+
+@task
+def print_wikipedia_dump_date(c, date: Optional[str] = None):
     """
     Prints the specified date in a human-readable format.
@@ -235,13 +248,27 @@ def print_wikipedia_dump_date(c, date=None):
 def start_retriever(
     c,
     collection_name=DEFAULT_QDRANT_COLLECTION_NAME,
-    embedding_model=DEFAULT_EMBEDDING_MODEL_NAME,
+    embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
     use_onnx=DEFAULT_EMBEDDING_USE_ONNX,
     retriever_port=DEFAULT_RETRIEVER_PORT,
 ):
+    """
+    Start the retriever server.
+
+    This task runs the retriever server, which is responsible for handling
+    retrieval requests. It uses Gunicorn to manage Uvicorn workers for asynchronous processing.
+    The retriever server sends search requests to the Qdrant docker container.
+
+    Args:
+    - c: Context, automatically passed by invoke.
+    - collection_name (str): The name of the Qdrant collection to query. Defaults to DEFAULT_QDRANT_COLLECTION_NAME.
+    - embedding_model_name (str): The HuggingFace ID of the embedding model to use for retrieval. Defaults to DEFAULT_EMBEDDING_MODEL_NAME.
+    - use_onnx (bool): Flag indicating whether to use the ONNX version of the embedding model. Defaults to DEFAULT_EMBEDDING_USE_ONNX.
+    - retriever_port (int): The port on which the retriever server will listen. Defaults to DEFAULT_RETRIEVER_PORT.
+ """ command = ( f"gunicorn -k uvicorn.workers.UvicornWorker 'retrieval.retriever_server:gunicorn_app(" - f'embedding_model="{embedding_model}", ' + f'embedding_model_name="{embedding_model_name}", ' f'use_onnx="{use_onnx}", ' f'collection_name="{collection_name}")\' ' f"--access-logfile=- " @@ -255,13 +282,24 @@ def start_retriever( @task(pre=[start_qdrant_docker_container]) def test_index( c, - collection_name=DEFAULT_QDRANT_COLLECTION_NAME, - embedding_model=DEFAULT_EMBEDDING_MODEL_NAME, + collection_name: str = DEFAULT_QDRANT_COLLECTION_NAME, + embedding_model_name: str = DEFAULT_EMBEDDING_MODEL_NAME, ): + """ + Test a Qdrant index. + + This task starts a Qdrant Docker container and then runs a test query to ensure that the index is working correctly. + Note that this task does not perform the actual indexing; it only tests an index that already exists. + + Args: + - c: Context, automatically passed by invoke. + - collection_name (str): Name of the Qdrant collection to test. Defaults to DEFAULT_QDRANT_COLLECTION_NAME. + - embedding_model_name (str): Name of the embedding model to use for testing. Defaults to DEFAULT_EMBEDDING_MODEL_NAME. + """ cmd = ( f"python retrieval/create_index.py " f"--collection_name {collection_name}" - f"--embedding_model {embedding_model} " + f"--embedding_model_name {embedding_model_name} " f"--test" # Just test, don't index ) c.run(cmd, pty=True) @@ -278,9 +316,9 @@ def test_index( def index_collection( c, collection_path, - collection_name=DEFAULT_QDRANT_COLLECTION_NAME, - embedding_model_port=DEFAULT_EMBEDDING_MODEL_PORT, - embedding_model=DEFAULT_EMBEDDING_MODEL_NAME, + collection_name: str = DEFAULT_QDRANT_COLLECTION_NAME, + embedding_model_port: int = DEFAULT_EMBEDDING_MODEL_PORT, + embedding_model_name: str = DEFAULT_EMBEDDING_MODEL_NAME, ): """ Creates a Qdrant index from a collection file using a specified embedding model. @@ -301,13 +339,13 @@ def index_collection( - collection_name (str): The name of the Qdrant collection where the indexed data will be stored. This parameter allows you to specify a custom name for the collection, which can be useful for organizing multiple indexes or distinguishing between different versions of the same dataset. - embedding_model_port (int): The port on which the embedding model server is running. - - embedding_model (str): The HuggingFace ID of the embedding model to use for indexing. + - embedding_model_name (str): The HuggingFace ID of the embedding model to use for indexing. """ c.run( f"python retrieval/create_index.py " f"--collection_file {collection_path} " f"--collection_name {collection_name} " - f"--embedding_model {embedding_model} " + f"--embedding_model_name {embedding_model_name} " f"--model_port {embedding_model_port} " f"--index", # But don't test, because it takes time for Qdrant to optimize the index after we have inserted vectors in bulk. pty=True, @@ -316,14 +354,17 @@ def index_collection( @task def download_wikipedia_dump( - c, workdir=DEFAULT_WORKDIR, language=DEFAULT_WIKIPEDIA_DUMP_LANGUAGE, wikipedia_date=None + c, + workdir: str = DEFAULT_WORKDIR, + language: str = DEFAULT_WIKIPEDIA_DUMP_LANGUAGE, + wikipedia_date: Optional[str] = None, ): """ - Downloads the Wikipedia HTML dump from https://dumps.wikimedia.org/other/enterprise_html/runs/ + Download a Wikipedia HTML dump from https://dumps.wikimedia.org/other/enterprise_html/runs/ Args: - c: Context, automatically passed by invoke. - - wikipedia_date: Date of the Wikipedia dump. 
+    - wikipedia_date (str, optional): The date of the Wikipedia dump to use. If not provided, the latest available dump is used.
     - language: Language edition of Wikipedia.
     """
     if not wikipedia_date:
@@ -342,22 +383,25 @@
 @task
 def preprocess_wikipedia_dump(
     c,
-    workdir=DEFAULT_WORKDIR,
-    language=DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
-    wikipedia_date=None,
-    pack_to_tokens=200,
-    num_exclude_frequent_words_from_translation=1000,
+    workdir: str = DEFAULT_WORKDIR,
+    language: str = DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
+    wikipedia_date: Optional[str] = None,
+    pack_to_tokens: int = 200,
+    num_exclude_frequent_words_from_translation: int = 1000,
 ):
     """
-    Process Wikipedia HTML dump into a collection.
-    This takes ~4 hours for English on a 24-core CPU VM. Processing is fully parallelizable so the time is proportional to number of cores available.
+    Process Wikipedia HTML dump into a JSONL collection file.
+    This takes ~4 hours for English on a 24-core CPU VM. Processing is fully parallelizable, so the time decreases roughly in proportion to the number of cores available.
     It might take more for other languages, if we need to also get entity translations from Wikidata. This is because of Wikidata's rate limit.
 
     Args:
-    - index_dir: Directory containing the HTML dump file (articles-html.json.tar.gz)
-    - workdir: Working directory for processing
-    - language: Language of the dump to process
+    - workdir (str): Working directory for processing
+    - language (str): Language of the dump to process
+    - wikipedia_date (str, optional): The date of the Wikipedia dump to use. If not provided, the latest available dump is used.
+    - pack_to_tokens (int): We try to pack smaller text chunks to get to this number of tokens.
+    - num_exclude_frequent_words_from_translation (int): For non-English Wikipedia dumps, we try to find English translations of all article names
+      in Wikidata. We will exclude the `num_exclude_frequent_words_from_translation` most common words because neural models are already familiar with these.
     """
     output_path = get_wikipedia_collection_path(workdir, language, wikipedia_date)
     if os.path.exists(output_path):
@@ -395,12 +439,12 @@ def preprocess_wikipedia_dump(
 )
 def index_wikipedia_dump(
     c,
-    collection_name=DEFAULT_QDRANT_COLLECTION_NAME,
-    embedding_model_port=DEFAULT_EMBEDDING_MODEL_PORT,
-    embedding_model=DEFAULT_EMBEDDING_MODEL_NAME,
-    workdir=DEFAULT_WORKDIR,
-    language=DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
-    wikipedia_date=None,
+    collection_name: str = DEFAULT_QDRANT_COLLECTION_NAME,
+    embedding_model_port: int = DEFAULT_EMBEDDING_MODEL_PORT,
+    embedding_model_name: str = DEFAULT_EMBEDDING_MODEL_NAME,
+    workdir: str = DEFAULT_WORKDIR,
+    language: str = DEFAULT_WIKIPEDIA_DUMP_LANGUAGE,
+    wikipedia_date: Optional[str] = None,
 ):
     """
     Orchestrates the indexing of a Wikipedia collection using a specified embedding model.
@@ -418,7 +462,7 @@ def index_wikipedia_dump(
     - collection_name (str): The name of the Qdrant collection where the indexed data will be stored. This parameter
       allows you to specify a custom name for the collection, which can be useful for organizing multiple indexes
       or distinguishing between different versions of the same dataset.
     - embedding_model_port (int): The port on which the embedding model server is running.
-    - embedding_model (str): The HuggingFace ID of the embedding model to use for indexing.
+    - embedding_model_name (str): The HuggingFace ID of the embedding model to use for indexing.
     - workdir (str): The working directory where intermediate and final files are stored.
    - language (str): The language edition of Wikipedia to index (e.g., "en" for English).
     - wikipedia_date (str, optional): The date of the Wikipedia dump to use. If not provided, the latest available dump is used.
@@ -441,7 +485,7 @@ def index_wikipedia_dump(
         collection_path=collection_path,
         collection_name=collection_name,
         embedding_model_port=embedding_model_port,
-        embedding_model=embedding_model,
+        embedding_model_name=embedding_model_name,
     )
 
@@ -455,10 +499,15 @@ def index_wikipedia_dump(
     iterable=["language"],
 )
 def index_multiple_wikipedia_dumps(
-    c, language, workdir=DEFAULT_WORKDIR, wikipedia_date=None
+    c, language, workdir: str = DEFAULT_WORKDIR, wikipedia_date: Optional[str] = None
 ):
+    """
+    Index Wikipedia dumps for multiple languages, one language at a time.
+    """
     if not wikipedia_date:
         wikipedia_date = get_latest_wikipedia_dump_date()
     for l in language:
         logger.info("Started indexing for language %s", l)
-        index_wikipedia_dump(c, workdir=workdir, language=l, wikipedia_date=wikipedia_date)
+        index_wikipedia_dump(
+            c, workdir=workdir, language=l, wikipedia_date=wikipedia_date
+        )
diff --git a/tasks/setup.py b/tasks/setup.py
index 706ca2e..0e87190 100644
--- a/tasks/setup.py
+++ b/tasks/setup.py
@@ -12,6 +12,23 @@
 
 @task
 def setup_nvme(c):
+    """
+    Set up an NVMe drive on the VM by performing the following steps. Only works on certain Linux distributions.
+
+    1. Installs the `nvme-cli` package to manage NVMe devices.
+    2. Lists available NVMe devices on the system.
+    3. Extracts NVMe device names from the listing output.
+    4. Checks if any NVMe devices are found; if none, logs a message and exits.
+    5. Formats the first NVMe device found with the XFS filesystem.
+    6. Creates a mount point at `/mnt/ephemeral_nvme`.
+    7. Mounts the NVMe device to the created mount point.
+    8. Changes ownership of the mount point to the current user to enable read and write access.
+    9. Creates a `workdir` directory on the NVMe drive.
+    10. Creates a symbolic link `./workdir` pointing to the `workdir` directory on the NVMe drive.
+
+    Args:
+        c: The context instance (passed automatically by invoke).
+    """
     # See if your VM has an NVMe drive
     c.run("sudo apt install -y nvme-cli")
@@ -52,7 +69,7 @@ def setup_nvme(c):
 
 @task
 def install_docker(c):
     """
-    Task to install Docker on an Ubuntu system by following https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository
+    Install Docker on an Ubuntu system by following https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository
 
     Args:
         c (invoke.context.Context): The context instance (passed automatically by the @task decorator).
@@ -89,7 +106,7 @@ def install_docker(c):
 
 @task
 def install_anaconda(c):
     """
-    Installs Anaconda if it is not already installed.
+    Install Anaconda if it is not already installed.
 
     This task checks if Anaconda (conda) is already installed on the system. If it is not installed,
     it downloads the Anaconda installer for Linux, runs the installer, and then removes the installer file.
@@ -112,7 +129,7 @@ def install_anaconda(c):
 
 @task(pre=[install_anaconda])
 def setup_conda_env(c):
     """
-    Sets up the Conda environment using the environment file.
+    Set up the Conda environment using the environment file.
 
     This task creates a Conda environment based on the specifications in the 'conda_env.yaml' file.
    After creating the environment, it activates the environment named 'wikichat' and downloads
@@ -130,7 +147,7 @@ def setup_conda_env(c):
 
 @task
 def download_azcopy(c):
     """
-    Downloads and installs AzCopy, a command-line utility for copying data to and from Microsoft Azure.
+    Download and install AzCopy, a command-line utility for copying data to and from Microsoft Azure.
 
     This task performs the following steps:
     1. Downloads the AzCopy tarball from the official Microsoft Azure link.
@@ -167,7 +184,7 @@ def download_azcopy(c):
 )
 def install(c):
     """
-    Installs various tools and sets up the environment.
+    Install various tools and set up the environment.
 
     This task orchestrates the installation and setup of several tools and environments required for the project.
     It performs the following steps in sequence:
diff --git a/wikipedia_preprocessing/preprocess_html_dump.py b/wikipedia_preprocessing/preprocess_html_dump.py
index 242e194..b06a6e9 100644
--- a/wikipedia_preprocessing/preprocess_html_dump.py
+++ b/wikipedia_preprocessing/preprocess_html_dump.py
@@ -687,7 +687,7 @@ def build_redirection_map(file_path: str) -> dict:
         desc="Building the Wikipedia redirection graph",
         miniters=1e-6,
         unit_scale=1,
-        unit=" Article",
+        unit=" Articles",
         smoothing=0,
     ):
         if is_disambiguation(article):
@@ -743,7 +743,7 @@ def articles_without_disambiguation_or_redirections(
         desc="Extracting blocks",
         miniters=1e-6,
         unit_scale=1,
-        unit=" Block",
+        unit=" Blocks",
         smoothing=0,
         total=len(redirect_map),
     )
diff --git a/wikipedia_preprocessing/upload_collections_to_hf_hub.py b/wikipedia_preprocessing/upload_collections_to_hf_hub.py
new file mode 100644
index 0000000..0c39b50
--- /dev/null
+++ b/wikipedia_preprocessing/upload_collections_to_hf_hub.py
@@ -0,0 +1,63 @@
+import argparse
+from huggingface_hub import HfApi
+import gzip
+import shutil
+import os
+
+from tqdm import tqdm
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Upload preprocessed Wikipedia collection files to HuggingFace Hub"
+    )
+    parser.add_argument(
+        "--dates",
+        nargs="+",
+        required=True,
+        help="List of Wikipedia dump dates (e.g., 20240401)",
+    )
+    parser.add_argument(
+        "--languages",
+        nargs="+",
+        required=True,
+        help="List of language codes (e.g., de en fr)",
+    )
+
+    args = parser.parse_args()
+
+    api = HfApi()
+    upload_futures = []
+    for date, language in tqdm(
+        [(date, lang) for date in args.dates for lang in args.languages],
+        desc="Uploading",
+    ):
+        # Extract the .gz file
+        gz_file = f"workdir/{language}/wikipedia_{date}/collection.jsonl.gz"
+        extracted_file = f"workdir/{language}/wikipedia_{date}/collection.jsonl"
+
+        with gzip.open(gz_file, "rb") as f_in:
+            with open(extracted_file, "wb") as f_out:
+                shutil.copyfileobj(f_in, f_out)
+
+        # Now the extracted file is ready for upload
+        for file in ["collection.jsonl", "collection_histogram.png"]:
+            future = api.upload_file(
+                path_or_fileobj=f"workdir/{language}/wikipedia_{date}/{file}",
+                path_in_repo=f"{date}/{language}/{file}",
+                repo_id="stanford-oval/wikipedia",
+                repo_type="dataset",
+                run_as_future=True,
+            )
+            upload_futures.append(future)
+
+    # Wait for all background uploads to finish before deleting the local files,
+    # since the futures may still be reading them.
+    for future in upload_futures:
+        future.result()
+
+    # Remove the extracted files
+    for date, language in [
+        (date, lang) for date in args.dates for lang in args.languages
+    ]:
+        extracted_file = f"workdir/{language}/wikipedia_{date}/collection.jsonl"
+        os.remove(extracted_file)
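For readers trying out the renamed flags above, the `--test` branch of `retrieval/create_index.py` suggests how an index can be queried in-process. The sketch below mirrors that code; the import path `retrieval.qdrant_index` and the collection name `"wikipedia"` are assumptions, since the diff never shows where `QdrantIndex` is defined or what `DEFAULT_QDRANT_COLLECTION_NAME` resolves to.

```python
# Minimal in-process sketch mirroring the --test branch of retrieval/create_index.py.
import asyncio
import json

from retrieval.qdrant_index import QdrantIndex  # assumed import path

# Same constructor arguments as in the diff: embedding model name,
# collection name (assumed), and the ONNX flag.
with QdrantIndex("BAAI/bge-m3", "wikipedia", use_onnx=True) as index:
    # index.search takes a list of queries and the number of results to return,
    # as in the diff's test code.
    results = asyncio.run(index.search(["Who won the 2024 Super Bowl?"], 5))
    print(json.dumps(results, indent=2, ensure_ascii=False))
```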
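Once `start_retriever` is running, the server can also be queried over HTTP. This is only a sketch under stated assumptions: the diff shows the FastAPI handler signature `search(request: Request, query_data: QueryData)` but not the route path, port, or the fields of `QueryData`, so the `/search` path, port `5100`, and the `query`/`num_blocks` fields below are all hypothetical and should be checked against the actual server code.

```python
# Hypothetical HTTP query against a locally running retriever server.
import requests

response = requests.post(
    "http://localhost:5100/search",  # route and port are assumptions; match --retriever_port
    json={"query": ["What is the capital of France?"], "num_blocks": 3},  # assumed QueryData fields
    timeout=60,
)
response.raise_for_status()
print(response.json())
```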