Merge pull request #26 from stanford-oval/auto-sync-2024-08-23-02-45-07
auto-sync-2024-08-23-02-45-07
s-jse authored Aug 23, 2024
2 parents b98c711 + 013c1f2 commit 9220e7a
Showing 10 changed files with 245 additions and 85 deletions.
12 changes: 9 additions & 3 deletions README.md
@@ -33,7 +33,7 @@
# Introduction

Large language model (LLM) chatbots like ChatGPT and GPT-4 get things wrong a lot, especially if the information you are looking for is recent ("Tell me about the 2024 Super Bowl.") or about less popular topics ("What are some good movies to watch from [insert your favorite foreign director]?").
-WikiChat uses Wikipedia and the following 7-stage pipeline to make sure its responses are factual.
+WikiChat uses Wikipedia and the following 7-stage pipeline to make sure its responses are factual. Each numbered stage involves one or more LLM calls.


<p align="center">
@@ -78,8 +78,14 @@ Installing WikiChat involves the following steps:
4. Run WikiChat with your desired configuration.
5. [Optional] Deploy WikiChat for multi-user access. We provide code to deploy a simple front-end and backend, as well as instructions to connect to an Azure Cosmos DB database for storing conversations.

-This project has been tested using Python 3.10 on Ubuntu Focal 20.04 (LTS), but should run on most recent Linux distributions.
-If you want to use this on Windows WSL or Mac, or with a different Python version, expect to do some troubleshooting in some of the installation steps.

+## System Requirements
+This project has been tested with Python 3.10 on Ubuntu 20.04 LTS (Focal Fossa), but it should be compatible with many other Linux distributions. If you plan to use this on Windows WSL or macOS, or with a different Python version, be prepared for potential troubleshooting during installation.

+Running WikiChat using LLM APIs and our Wikipedia search API does not have specific hardware requirements and can be performed on most systems. However, if you intend to host a search index locally, ensure you have sufficient disk space for the index. For large indices, retrieval latency is heavily dependent on disk speed, so we recommend using SSDs and preferably NVMe drives. For example, storage-optimized VMs like [`Standard_L8s_v3`](https://learn.microsoft.com/en-us/azure/virtual-machines/lsv3-series) on Azure are suitable for this.

+If you plan to use WikiChat with a local LLM, a GPU is necessary to host the model.


## Install Dependencies

10 changes: 5 additions & 5 deletions retrieval/create_index.py
@@ -96,7 +96,7 @@ def commit_to_index(
desc="Indexing collection",
miniters=1e-6,
unit_scale=1,
unit=" Block",
unit=" Blocks",
dynamic_ncols=True,
smoothing=0,
total=collection_size,
@@ -219,7 +219,7 @@ def batch_generator(collection_file, embedding_batch_size):
"--collection_file", type=str, default=None, help=".jsonl file to read from."
)
parser.add_argument(
"--embedding_model",
"--embedding_model_name",
type=str,
choices=QdrantIndex.get_supported_embedding_models(),
default="BAAI/bge-m3",
@@ -258,10 +258,10 @@ def batch_generator(collection_file, embedding_batch_size):
args = parser.parse_args()

model_port = args.model_port
-embedding_size = QdrantIndex.get_embedding_model_parameters(args.embedding_model)[
+embedding_size = QdrantIndex.get_embedding_model_parameters(args.embedding_model_name)[
"embedding_dimension"
]
-query_prefix = QdrantIndex.get_embedding_model_parameters(args.embedding_model)[
+query_prefix = QdrantIndex.get_embedding_model_parameters(args.embedding_model_name)[
"query_prefix"
]

@@ -325,7 +325,7 @@ def batch_generator(collection_file, embedding_batch_size):
]

with QdrantIndex(
-args.embedding_model, args.collection_name, use_onnx=True
+args.embedding_model_name, args.collection_name, use_onnx=True
) as index:
results = asyncio.run(index.search(queries, 5))
logger.info(json.dumps(results, indent=2, ensure_ascii=False))
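The practical effect of this rename (also applied in retrieval/retriever_server.py below) is that both the CLI flag and the downstream attribute are now spelled `embedding_model_name`. A minimal post-rename usage sketch, mirroring the calls visible in this diff; the import path and collection name below are assumptions, not taken from the repository:

```python
import asyncio
import json

from retrieval.qdrant_index import QdrantIndex  # assumed import path

model_name = "BAAI/bge-m3"  # default of --embedding_model_name in create_index.py

# Look up per-model parameters, as create_index.py does above.
params = QdrantIndex.get_embedding_model_parameters(model_name)
embedding_size = params["embedding_dimension"]
query_prefix = params["query_prefix"]

# Query an existing index; "wikipedia_en" is a hypothetical collection name.
with QdrantIndex(
    embedding_model_name=model_name, collection_name="wikipedia_en", use_onnx=True
) as index:
    results = asyncio.run(index.search(["2024 Super Bowl"], 5))
    print(json.dumps(results, indent=2, ensure_ascii=False))
```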
4 changes: 2 additions & 2 deletions retrieval/retriever_server.py
@@ -164,7 +164,7 @@ async def search(request: Request, query_data: QueryData):
def init():
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
"--embedding_model",
"--embedding_model_name",
type=str,
required=True,
help="Path or the HuggingFace model name for the model used to encode the query.",
@@ -183,7 +183,7 @@ def init():
if args.use_onnx:
logger.info("Using ONNX for the embedding model.")
qdrant_index = QdrantIndex(
-embedding_model_name=args.embedding_model,
+embedding_model_name=args.embedding_model_name,
collection_name=args.collection_name,
use_onnx=args.use_onnx,
)
6 changes: 6 additions & 0 deletions tasks/benchmark.py
@@ -43,6 +43,12 @@ def simulate_users(
draft_engine=CHATBOT_DEFAULT_CONFIG["draft_engine"],
refine_engine=CHATBOT_DEFAULT_CONFIG["refine_engine"],
):
"""
Simulate user dialogues with a chatbot using specified parameters.
Accepts all parameters that `inv demo` accepts, plus a few additional parameters for the user simulator.
"""

pipeline_flags = (
f"--pipeline {pipeline} "
f"--engine {engine} "
6 changes: 3 additions & 3 deletions tasks/docker_utils.py
@@ -95,7 +95,7 @@ def wait_for_docker_container_to_be_ready(
timeout = 60
step_time = timeout // 10
elapsed_time = 0
logger.info("Waiting for the container '%s' to be ready...", container)
logger.info("Waiting for container '%s' to be ready...", container.name)

def is_ready():
container_status = docker_client.containers.get(container.id).status
@@ -128,7 +128,7 @@ def start_qdrant_docker_container(
c, workdir=DEFAULT_WORKDIR, rest_port=6333, grpc_port=6334
):
"""
-Starts a Qdrant docker container if it is not already running.
+Start a Qdrant docker container if it is not already running.
This function checks if a Qdrant docker container named 'qdrant' is already running. If so, it logs that
the container is already up and does nothing. If the container exists but is stopped, it simply restarts
@@ -200,7 +200,7 @@ def start_embedding_docker_container(
embedding_model=DEFAULT_EMBEDDING_MODEL_NAME,
):
"""
-Starts a Docker container for HuggingFace's text embedding inference (TEI) if it is not already running.
+Start a Docker container for HuggingFace's text embedding inference (TEI) if it is not already running.
See https://github.com/huggingface/text-embeddings-inference for TEI documentation.
This function checks if a text-embedding-inference Docker container is already running and starts it if not.
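These docstrings also spell out the container-management pattern the helpers follow: reuse a running container, restart a stopped one, and only create a new one if none exists. A rough sketch of that pattern with the docker SDK (illustrative only; the repository's helpers additionally handle ports, volumes, and readiness checks):

```python
import docker
from docker.errors import NotFound

client = docker.from_env()
try:
    container = client.containers.get("qdrant")
    if container.status == "running":
        print("Container 'qdrant' is already up; nothing to do.")
    else:
        # Container exists but is stopped: simply restart it.
        container.start()
except NotFound:
    # No container named 'qdrant' exists yet; one would be created here,
    # e.g. client.containers.run("qdrant/qdrant", name="qdrant", detach=True)
    pass
```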
27 changes: 26 additions & 1 deletion tasks/main.py
@@ -15,6 +15,20 @@

@task
def load_api_keys(c):
"""
Load API keys from a file named 'API_KEYS' and set them as environment variables.
This function reads the 'API_KEYS' file line by line, extracts key-value pairs,
and sets them as environment variables. Lines starting with '#' are treated as
comments and ignored. The expected format for each line in the file is 'KEY=VALUE'.
Parameters:
- c: Context, automatically passed by invoke.
Raises:
- Exception: If there is an error while reading the 'API_KEYS' file or setting
the environment variables, an error message is logged.
"""
try:
with open("API_KEYS") as f:
for line in f:
@@ -33,7 +47,18 @@ def load_api_keys(c):


@task()
-def start_redis(c, redis_port=DEFAULT_REDIS_PORT):
+def start_redis(c, redis_port: int = DEFAULT_REDIS_PORT):
"""
Start a Redis server if it is not already running.
This task attempts to connect to a Redis server on the specified port.
If the connection fails (indicating that the Redis server is not running),
it starts a new Redis server on that port.
Parameters:
- c: Context, automatically passed by invoke.
- redis_port (int): The port number on which to start the Redis server. Defaults to DEFAULT_REDIS_PORT.
"""
try:
r = redis.Redis(host="localhost", port=redis_port)
r.ping()
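The new `load_api_keys` docstring pins down the expected `API_KEYS` file format: one `KEY=VALUE` pair per line, with `#` lines treated as comments. A minimal sketch of that parsing behavior (illustrative; the repository's actual implementation is only partially visible in this diff):

```python
import os

# Read API_KEYS line by line, skip comments and blank lines, and export
# each KEY=VALUE pair as an environment variable.
with open("API_KEYS") as f:
    for line in f:
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition("=")
        os.environ[key.strip()] = value.strip()
```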