Skip to content

Commit

Permalink
almost have retrieval working, having to make a lot of changes to onl…
Browse files Browse the repository at this point in the history
…ine retrieval. long term this can all go in the FeatureView class and in get_online_features

Signed-off-by: Francisco Javier Arceo <[email protected]>
  • Loading branch information
franciscojavierarceo committed Dec 18, 2024
1 parent 2b18594 commit d05d601
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 67 deletions.
40 changes: 33 additions & 7 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,9 +1750,10 @@ async def get_online_features_async(

def retrieve_online_documents(
self,
feature: str,
feature: Optional[str],
query: Union[str, List[float]],
top_k: int,
features: Optional[List[str]] = None,
distance_metric: Optional[str] = None,
) -> OnlineResponse:
"""
Expand All @@ -1762,6 +1763,7 @@ def retrieve_online_documents(
feature: The list of document features that should be retrieved from the online document store. These features can be
specified either as a list of string document feature references or as a feature service. String feature
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
features: The list of features that should be retrieved from the online store.
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
distance_metric: The distance metric to use for retrieval.
Expand All @@ -1770,18 +1772,39 @@ def retrieve_online_documents(
raise ValueError(
"Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents."
)
feature_list = features or [feature]
(
available_feature_views,
_,
) = utils._get_feature_views_to_use(
registry=self._registry,
project=self.project,
features=[feature],
features=feature_list,
allow_cache=True,
hide_dummy_entity=False,
)
if features:
feature_view_set = set()
for feature in features:
feature_view_name = feature.split(":")[0]
feature_view = self.get_feature_view(feature_view_name)
feature_view_set.add(feature_view.name)
if len(feature_view_set) > 1:
raise ValueError(
"Document retrieval only supports a single feature view."
)
requested_feature = None
requested_features = [
f.split(":")[1] for f in features if isinstance(f, str) and ":" in f
]
else:
requested_feature = (
feature.split(":")[1] if isinstance(feature, str) else feature
)
requested_features = [requested_feature]

requested_feature_view_name = (
feature.split(":")[0] if isinstance(feature, str) else feature
feature.split(":")[0] if feature else list(feature_view_set)[0]
)
for feature_view in available_feature_views:
if feature_view.name == requested_feature_view_name:
Expand All @@ -1790,14 +1813,15 @@ def retrieve_online_documents(
raise ValueError(
f"Feature view {requested_feature_view} not found in the registry."
)
requested_feature = (
feature.split(":")[1] if isinstance(feature, str) else feature
)

requested_feature_view = available_feature_views[0]

provider = self._get_provider()
document_features = self._retrieve_from_online_store(
provider,
requested_feature_view,
requested_feature,
requested_features,
query,
top_k,
distance_metric,
Expand Down Expand Up @@ -1833,7 +1857,8 @@ def _retrieve_from_online_store(
self,
provider: Provider,
table: FeatureView,
requested_feature: str,
requested_feature: Optional[str],
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
distance_metric: Optional[str],
Expand All @@ -1849,6 +1874,7 @@ def _retrieve_from_online_store(
config=self.config,
table=table,
requested_feature=requested_feature,
requested_features=requested_features,
query=query,
top_k=top_k,
distance_metric=distance_metric,
Expand Down
81 changes: 32 additions & 49 deletions sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
)

PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = {
PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.STRING,
PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR,
PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_val"]: DataType.BOOL,
PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.STRING,
PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.VARCHAR,
PROTO_VALUE_TO_VALUE_TYPE_MAP["float_val"]: DataType.FLOAT,
PROTO_VALUE_TO_VALUE_TYPE_MAP["double_val"]: DataType.DOUBLE,
PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_val"]: DataType.INT32,
Expand Down Expand Up @@ -71,6 +71,8 @@
ValueType.DOUBLE,
]:
FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.FLOAT_VECTOR
elif base_value_type == ValueType.STRING:
FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.VARCHAR
elif base_value_type == ValueType.BOOL:
FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.BINARY_VECTOR

Expand Down Expand Up @@ -149,7 +151,14 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
dim=config.online_store.embedding_dim,
)
)

elif dtype == DataType.VARCHAR:
fields.append(
FieldSchema(
name=field.name,
dtype=dtype,
max_length=512,
)
)
else:
fields.append(FieldSchema(name=field.name, dtype=dtype))

Expand Down Expand Up @@ -210,17 +219,14 @@ def online_write_batch(
int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0
)
for feature_name in values_dict:
for vector_list_type_name in numeric_vector_list_types:
vector_list = getattr(
values_dict[feature_name], vector_list_type_name, None
)
if vector_list:
vector_values = getattr(
values_dict[feature_name], vector_list_type_name
).val
if vector_values != []:
# Note here we are over-writing the feature and collapsing the list into a single value
values_dict[feature_name] = vector_values
feature_values = values_dict[feature_name]
for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP:
if feature_values.HasField(proto_val_type):
if proto_val_type in numeric_vector_list_types:
vector_values = getattr(feature_values, proto_val_type).val
else:
vector_values = getattr(feature_values, proto_val_type)
values_dict[feature_name] = vector_values

single_entity_record = {
composite_key_name: entity_key_str,
Expand All @@ -243,40 +249,7 @@ def online_read(
entity_keys: List[EntityKeyProto],
requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
collection = self._get_collection(config, table)
results = []

for entity_key in entity_keys:
entity_key_str = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
expr = f"entity_key == '{entity_key_str}'"
if requested_features:
features_str = ", ".join([f"'{f}'" for f in requested_features])
expr += f" && feature_name in [{features_str}]"

res = collection.query(
expr,
output_fields=["feature_name", "value", "event_ts"],
consistency_level="Strong",
)

res_dict = {}
res_ts = None
for r in res:
feature_name = r["feature_name"]
val_bin = r["value"]
val = ValueProto()
val.ParseFromString(val_bin)
res_dict[feature_name] = val
res_ts = datetime.fromtimestamp(r["event_ts"] / 1e6)
if not res_dict:
results.append((None, None))
else:
results.append((res_ts, res_dict))

return results
raise NotImplementedError

def update(
self,
Expand Down Expand Up @@ -320,6 +293,7 @@ def retrieve_online_documents(
config: RepoConfig,
table: FeatureView,
requested_feature: str,
requested_features: List[str],
embedding: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand All @@ -342,13 +316,22 @@ def retrieve_online_documents(
}
expr = f"feature_name == '{requested_feature}'"

composite_key_name = (
"_".join([str(value) for value in table.entity_columns]) + "_pk"
)
if requested_features:
features_str = ", ".join([f"'{f}'" for f in requested_features])
expr += f" && feature_name in [{features_str}]"

results = collection.search(
data=[embedding],
anns_field="vector_value",
param=search_params,
limit=top_k,
expr=expr,
output_fields=["entity_key", "value", "event_ts"],
output_fields=[composite_key_name]
+ requested_features
+ ["created_ts", "event_ts"],
consistency_level="Strong",
)

Expand Down
2 changes: 2 additions & 0 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def retrieve_online_documents(
config: RepoConfig,
table: FeatureView,
requested_feature: str,
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand All @@ -305,6 +306,7 @@ def retrieve_online_documents(
config,
table,
requested_feature,
requested_features,
query,
top_k,
distance_metric,
Expand Down
2 changes: 2 additions & 0 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ def retrieve_online_documents(
config: RepoConfig,
table: FeatureView,
requested_feature: str,
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand All @@ -440,6 +441,7 @@ def retrieve_online_documents(
config: The config for the current feature store.
table: The feature view whose embeddings should be searched.
requested_feature: the requested document feature name.
requested_features: the requested document feature names.
query: The query embedding to search for.
top_k: The number of documents to return.
Expand Down
5 changes: 1 addition & 4 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,6 @@
driver,
location,
)
from tests.integration.feature_repos.universal.online_store.milvus import (
MilvusOnlineStoreCreator,
)
from tests.utils.auth_permissions_util import default_store
from tests.utils.generate_self_signed_certifcate_util import generate_self_signed_cert
from tests.utils.http_server import check_port_open, free_port # noqa: E402
Expand Down Expand Up @@ -204,7 +201,6 @@ def environment(request, worker_id):
e.teardown()



@pytest.fixture
def vectordb_environment(request, worker_id):
db_config = IntegrationTestRepoConfig(
Expand All @@ -231,6 +227,7 @@ def vectordb_environment(request, worker_id):

e.teardown()


_config_cache: Any = {}


Expand Down
2 changes: 2 additions & 0 deletions sdk/python/tests/data/data_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def get_feature_values_for_dtype(
def create_document_dataset() -> pd.DataFrame:
data = {
"item_id": [1, 2, 3],
"string_feature": ["a", "b", "c"],
"float_feature": [1.0, 2.0, 3.0],
"embedding_float": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]],
"embedding_double": [[4.0, 5.0], [1.0, 2.0], [3.0, 4.0]],
"ts": [
Expand Down
1 change: 1 addition & 0 deletions sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def retrieve_online_documents(
config: RepoConfig,
table: FeatureView,
requested_feature: str,
requested_features: Optional[List[str]],
query: List[float],
top_k: int,
distance_metric: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from feast.data_source import DataSource, RequestSource
from feast.feature_view_projection import FeatureViewProjection
from feast.on_demand_feature_view import PandasTransformation, SubstraitTransformation
from feast.types import Array, FeastType, Float32, Float64, Int32, Int64
from feast.types import Array, FeastType, Float32, Float64, Int32, Int64, String
from tests.integration.feature_repos.universal.entities import (
customer,
driver,
Expand Down Expand Up @@ -160,8 +160,20 @@ def create_item_embeddings_feature_view(source, infer_features: bool = False):
schema=None
if infer_features
else [
Field(name="embedding_double", dtype=Array(Float64)),
Field(name="embedding_float", dtype=Array(Float32)),
Field(
name="embedding_double",
dtype=Array(Float64),
vector_index=True,
vector_search_metric="L2",
),
Field(
name="embedding_float",
dtype=Array(Float32),
vector_index=True,
vector_search_metric="L2",
),
Field(name="string_feature", dtype=String),
Field(name="float_feature", dtype=Float32),
],
source=source,
ttl=timedelta(hours=2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,12 @@ def test_retrieve_online_documents2(environment, fake_document_data):
fs.apply([item_embeddings_feature_view, item()])
fs.write_to_online_store("item_embeddings", df)
documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
feature=None,
features=[
"item_embeddings:embedding_float",
"item_embeddings:item_id",
"item_embeddings:string_feature",
],
query=[1.0, 2.0],
top_k=2,
distance_metric="L2",
Expand Down
3 changes: 0 additions & 3 deletions sdk/python/tests/unit/online_store/test_online_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from feast import FeatureStore, RepoConfig
from feast.errors import FeatureViewNotFoundException
from feast.infra.online_stores.milvus_online_store.milvus import MilvusOnlineStoreConfig
from feast.infra.provider import Provider
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import FloatList as FloatListProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
Expand Down Expand Up @@ -563,4 +561,3 @@ def test_sqlite_vec_import() -> None:
""").fetchall()
result = [(rowid, round(distance, 2)) for rowid, distance in result]
assert result == [(2, 2.39), (1, 2.39)]

0 comments on commit d05d601

Please sign in to comment.