From 19b0563983e5d68d85aad86efc7bc00ebe674362 Mon Sep 17 00:00:00 2001
From: Shicheng Liu <shicheng@cs.stanford.edu>
Date: Sat, 28 Sep 2024 04:09:27 +0000
Subject: [PATCH] refactor dependencies

---
 setup.py                    | 12 +++++++---
 src/suql/faiss_embedding.py | 48 +++++++++++++++++++++++--------------
 src/suql/utils.py           |  6 ++---
 3 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/setup.py b/setup.py
index 530fc96..3a7efc7 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 # Package metadata
 name = "suql"
-version = "1.1.7a7"
+version = "1.1.7a8"
 description = "Structured and Unstructured Query Language (SUQL) Python API"
 author = "Shicheng Liu"
 author_email = "shicheng@cs.stanford.edu"
@@ -18,15 +18,18 @@
     'Flask-Cors==4.0.0',
     'Flask-RESTful==0.3.10',
     'requests==2.31.0',
-    'spacy==3.6.0',
     'tiktoken==0.4.0',
     'psycopg2-binary==2.9.7',
     'pglast==5.3',
-    'FlagEmbedding~=1.2.5',
     'litellm==1.34.34',
     'platformdirs>=4.0.0'
 ]
 
+install_dev_requires = [
+    'spacy==3.6.0',
+    'FlagEmbedding~=1.2.5',
+]
+
 # Additional package information
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
@@ -49,6 +52,9 @@
     packages=packages,
     package_dir={"": "src"},
     install_requires=install_requires,
+    extras_require={
+        "dev": install_dev_requires
+    },
     url=url,
     classifiers=classifiers,
     package_data={
diff --git a/src/suql/faiss_embedding.py b/src/suql/faiss_embedding.py
index 85ceab3..4852440 100644
--- a/src/suql/faiss_embedding.py
+++ b/src/suql/faiss_embedding.py
@@ -3,10 +3,8 @@
 from collections import OrderedDict
 import os
 
-import faiss
 import hashlib
 import pickle
-from FlagEmbedding import FlagModel
 from flask import Flask, request
 from tqdm import tqdm
 from platformdirs import user_cache_dir
@@ -21,20 +19,21 @@
 # number of rows to consider for multi-column operations
 MULTIPLE_COLUMN_SEL = 1000
 
-# currently using https://huggingface.co/BAAI/bge-large-en-v1.5
-# change this line for custom embedding model
-model = FlagModel(
-    "BAAI/bge-large-en-v1.5",
-    query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
-    use_fp16=True,
-)  # Setting use_fp16 to True speeds up computation with a slight performance degradation
-
 
 def embed_query(query):
     """
     Embed a query for dot product matching
     """
     # change this line for custom embedding model
+    # currently using https://huggingface.co/BAAI/bge-large-en-v1.5
+    # imported lazily so FlagEmbedding is only required when embeddings are computed
+    from FlagEmbedding import FlagModel
+
+    model = FlagModel(
+        "BAAI/bge-large-en-v1.5",
+        query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
+        use_fp16=True,
+    )  # Setting use_fp16 to True speeds up computation with a slight performance degradation
     q_embedding = model.encode_queries([query])
     return q_embedding
@@ -44,6 +43,15 @@ def embed_documents(documents):
     """
     Embed a list of documents to store in vector store
     """
     # change this line for custom embedding model
+    # currently using https://huggingface.co/BAAI/bge-large-en-v1.5
+    # imported lazily so FlagEmbedding is only required when embeddings are computed
+    from FlagEmbedding import FlagModel
+
+    model = FlagModel(
+        "BAAI/bge-large-en-v1.5",
+        query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
+        use_fp16=True,
+    )  # Setting use_fp16 to True speeds up computation with a slight performance degradation
     embeddings = model.encode(documents)
     return embeddings
@@ -90,6 +98,8 @@ def __len__(self):
 
 
 def compute_top_similarity_documents(documents, query, chunking_param=0, top=3):
+    import faiss  # lazy import: faiss is only needed when similarity search runs
+
     """
     Directly call the model to compute the top documents based 
     on dot product with query
@@ -140,6 +150,8 @@ def __init__(
         cache_embedding=True,
         force_recompute=False
     ) -> None:
+        import faiss  # lazy import: faiss is not required until an embedding store is created
+        self.faiss = faiss  # keep a handle so the other methods can reuse it
         # stores three lists:
         # 1. PSQL primary key for each row
         # 2. list of strings in this field
@@ -257,18 +269,18 @@ def initialize_embedding(self):
         if (os.path.exists(faiss_cache_location) and not self.force_recompute):
             try:
                 print(f"initializing from existing faiss embedding index at {faiss_cache_location}")
-                self.embeddings = faiss.read_index(faiss_cache_location)
+                self.embeddings = self.faiss.read_index(faiss_cache_location)
                 return
             except Exception:
                 print(f"reading {faiss_cache_location} failed. Re-computing embeddings")
 
-        self.embeddings = faiss.IndexFlatIP(EMBEDDING_DIMENSION)
+        self.embeddings = self.faiss.IndexFlatIP(EMBEDDING_DIMENSION)
         indexs = embed_documents(self.chunked_text)
         self.embeddings.add(indexs)
 
         print(f"writing computed faiss embedding to {faiss_cache_location}")
         os.makedirs(_user_cache_dir, exist_ok=True)
-        faiss.write_index(self.embeddings, faiss_cache_location)
+        self.faiss.write_index(self.embeddings, faiss_cache_location)
 
     def dot_product(self, id_list, query, top, individual_id_list=[]):
         # given a list of id and a particular query, return the top ids and documents according to similarity score ranking
@@ -294,18 +306,18 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
 
         query_embedding = embed_query(query)
 
-        sel = faiss.IDSelectorBatch(embedding_indices)
+        sel = self.faiss.IDSelectorBatch(embedding_indices)
         if top < 0:
             D, I = self.embeddings.search(
                 query_embedding,
                 len(embedding_indices),
-                params=faiss.SearchParametersIVF(sel=sel),
+                params=self.faiss.SearchParametersIVF(sel=sel),
             )
         else:
             if top > min(self.embeddings.ntotal, len(embedding_indices)):
                 top = min(self.embeddings.ntotal, len(embedding_indices))
             D, I = self.embeddings.search(
-                query_embedding, top, params=faiss.SearchParametersIVF(sel=sel)
+                query_embedding, top, params=self.faiss.SearchParametersIVF(sel=sel)
             )
 
         embeddings_indices_max = I[0]
@@ -364,11 +376,11 @@ def dot_product_with_value(self, id_list, query, individual_id_list=[]):
         # this is actually a 2-D array, matching what faiss expects
         query_embedding = embed_query(query)
 
-        sel = faiss.IDSelectorBatch(embedding_indices)
+        sel = self.faiss.IDSelectorBatch(embedding_indices)
         D, I = self.embeddings.search(
             query_embedding,
             MULTIPLE_COLUMN_SEL,
-            params=faiss.SearchParametersIVF(sel=sel),
+            params=self.faiss.SearchParametersIVF(sel=sel),
         )
         embedding_indices = I[0]
         dot_products = D[0]
diff --git a/src/suql/utils.py b/src/suql/utils.py
index 6a24a83..4431aa7 100644
--- a/src/suql/utils.py
+++ b/src/suql/utils.py
@@ -1,6 +1,3 @@
-import spacy
-
-nlp = spacy.load("en_core_web_sm")
 import hashlib
 
 import tiktoken
@@ -52,6 +49,9 @@ def chunk_text(text, k, use_spacy=True):
     if text == "":
         return [""]
     # in case of using spacy, k is the minimum number of words per chunk
+    import spacy  # lazy import: spacy is an optional dependency (see the "dev" extra in setup.py)
+    nlp = spacy.load("en_core_web_sm")
+
     chunks = [i.text for i in nlp(text).sents]
     res = []
     carryover = ""
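
Note: with this patch, spacy and FlagEmbedding are no longer pulled in by a
plain "pip install suql". They are only needed for text chunking and the
built-in BGE embedding model, and are exposed through the "dev" extra declared
with the setuptools extras_require keyword above. Users who need those code
paths install them with: pip install "suql[dev]".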
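
A side effect of moving the imports into the function bodies is that
embed_query, embed_documents, and chunk_text now rebuild the FlagModel /
spacy pipeline on every call, which is slow for repeated queries. Below is a
minimal sketch of one way to keep the lazy imports while paying the load cost
only once per process. It assumes the same model names as the patch; the
_get_flag_model and _get_nlp helpers are hypothetical and not part of this
commit.

from functools import lru_cache


@lru_cache(maxsize=1)
def _get_flag_model():
    # hypothetical helper, not part of this patch: import and construct
    # the embedding model once, then reuse it across calls
    from FlagEmbedding import FlagModel

    return FlagModel(
        "BAAI/bge-large-en-v1.5",
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
        use_fp16=True,
    )


@lru_cache(maxsize=1)
def _get_nlp():
    # hypothetical helper, not part of this patch: load the spacy
    # pipeline once, then reuse it across calls
    import spacy

    return spacy.load("en_core_web_sm")


def embed_query(query):
    """
    Embed a query for dot product matching
    """
    return _get_flag_model().encode_queries([query])

functools.lru_cache on a zero-argument function gives a process-wide
singleton, so the module still imports cleanly when the optional dependencies
are absent, and the heavy objects are constructed at most once.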