Skip to content

Commit

Permalink
updating workflow
Browse files (browse the repository at this point in the history)
Signed-off-by: Francisco Javier Arceo <[email protected]>
  • Loading branch information
franciscojavierarceo committed Jan 10, 2025
1 parent e482ba6 commit 3731e17
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions examples/rag/feature_repo/test_workflow.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn.functional as F
from feast import FeatureStore
from pymilvus import MilvusClient
from pymilvus import MilvusClient, DataType, FieldSchema
from transformers import AutoTokenizer, AutoModel
from example_repo import city_embeddings_feature_view, item
TOKENIZER = "sentence-transformers/all-MiniLM-L6-v2"
Expand Down Expand Up @@ -36,12 +36,43 @@ def run_model(sentences, tokenizer, model):
def run_demo():
store = FeatureStore(repo_path=".")
df = pd.read_parquet("./data/city_wikipedia_summaries_with_embeddings.parquet")
store.apply([city_embeddings_feature_view, item])
store.write_to_online_store_async("city_embeddings", df)
embedding_length = len(df['vector'][0])
print(f'embedding length = {embedding_length}')

print('\ndata=')
print(df.head().T)

client = MilvusClient(alias="feast", host="localhost", port="19530", token="username:password")
print(client.list_collections())
store.apply([city_embeddings_feature_view, item])
store.write_to_online_store("city_embeddings", df)

client = MilvusClient(uir="http://localhost:19530", token="username:password")
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
FieldSchema(name='state', dtype=DataType.STRING, description="State"),
FieldSchema(name='wiki_summary', dtype=DataType.STRING, description="State"),
FieldSchema(name='sentence_chunks', dtype=DataType.STRING, description="Sentence Chunks"),
FieldSchema(name="item_id", dtype=DataType.INT64, default_value=0, description="Item"),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=embedding_length, description="vector")
]
cols = [f.name for f in fields]
client.insert(
collection_name="demo_collection",
data=df[cols].to_dict(orient="records"),
schema=fields,
)
print('\n')
print('collections', client.list_collections())
print('query results =', client.query(
collection_name="rag_city_embeddings",
filter="item_id == 0",
# output_fields=['city_embeddings', 'item_id', 'city_name'],
))
print('query results2 =', client.query(
collection_name="rag_city_embeddings",
filter="item_id >= 0",
output_fields=["count(*)"]
# output_fields=['city_embeddings', 'item_id', 'city_name'],
))
question = "the most populous city in the U.S. state of Texas?"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
model = AutoModel.from_pretrained(MODEL)
Expand All @@ -50,7 +81,8 @@ def run_demo():

# Retrieve top k documents
features = store.retrieve_online_documents(
feature="city_embeddings:Embeddings",
feature=None,
features=["city_embeddings:vector", "city_embeddings:item_id", "city_embeddings:state"],
query=query,
top_k=3
)
Expand Down

0 comments on commit 3731e17

Please sign in to comment.