Skip to content

Commit

Permalink
Ensure consistent filter query in KNNQueryBuilder across multiple sha…
Browse files Browse the repository at this point in the history
…rds (#2362)

Signed-off-by: Sahil Buddharaju <[email protected]>
Signed-off-by: sahil <[email protected]>
Co-authored-by: Sahil Buddharaju <[email protected]>
(cherry picked from commit c969f1d)
  • Loading branch information
opensearch-trigger-bot[bot] authored Jan 6, 2025
1 parent 3cd212a commit 72c6a1e
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315]
* Release query vector memory after execution (#2346)[https://github.com/opensearch-project/k-NN/pull/2346]
* Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352]
* Fix filter rewrite logic which was resulting in getting inconsistent / incorrect results for cases where filter was getting rewritten for shards (#2359)[https://github.com/opensearch-project/k-NN/pull/2359]
### Infrastructure
* Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259)
* Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279)
Expand Down
19 changes: 17 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -662,9 +662,24 @@ public String getWriteableName() {

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException {
// rewrite filter query if it exists to avoid runtime errors in next steps of query phase
QueryBuilder rewrittenFilter;
if (Objects.nonNull(filter)) {
filter = filter.rewrite(queryShardContext);
rewrittenFilter = filter.rewrite(queryShardContext);
if (rewrittenFilter != filter) {
KNNQueryBuilder rewrittenQueryBuilder = KNNQueryBuilder.builder()
.fieldName(this.fieldName)
.vector(this.vector)
.k(this.k)
.maxDistance(this.maxDistance)
.minScore(this.minScore)
.methodParameters(this.methodParameters)
.filter(rewrittenFilter)
.ignoreUnmapped(this.ignoreUnmapped)
.rescoreContext(this.rescoreContext)
.expandNested(this.expandNested)
.build();
return rewrittenQueryBuilder;
}
}
return super.doRewrite(queryShardContext);
}
Expand Down
45 changes: 45 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,51 @@ public void testRadialQuery_withFilter_thenSuccess() {
deleteKNNIndex(INDEX_NAME);
}

@SneakyThrows
public void testQueryWithFilterMultipleShards() {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, "3")
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName())
.endObject()
.endObject()
.startObject(INTEGER_FIELD_NAME)
.field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER)
.endObject()
.endObject()
.endObject();
String mapping = builder.toString();

createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build());
putMappingRequest(INDEX_NAME, mapping);

addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01"));

refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 7.0f, 3.0f };
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(
FIELD_NAME,
searchVector,
1,
QueryBuilders.boolQuery().must(QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01"))
),
10
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(1, knnResults.size());
}

@SneakyThrows
public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {
String indexName = "test-index";
Expand Down
46 changes: 46 additions & 0 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -282,6 +283,51 @@ public void testQueryWithFilterUsingByteVectorDataType() {
validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
}

@SneakyThrows
public void testQueryWithFilterMultipleShards() {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject(PROPERTIES_FIELD_NAME)
.startObject(FIELD_NAME)
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, DIMENSION)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, METHOD_HNSW)
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName())
.endObject()
.endObject()
.startObject(INTEGER_FIELD_NAME)
.field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER)
.endObject()
.endObject()
.endObject();
String mapping = builder.toString();

createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build());
putMappingRequest(INDEX_NAME, mapping);

addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01"));

refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 7.0f, 3.0f };
final Response response = searchKNNIndex(
INDEX_NAME,
new KNNQueryBuilder(
FIELD_NAME,
searchVector,
1,
QueryBuilders.boolQuery().must(QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01"))
),
10
);
final String responseBody = EntityUtils.toString(response.getEntity());
final List<KNNResult> knnResults = parseSearchResponse(responseBody, FIELD_NAME);

assertEquals(1, knnResults.size());
}

@SneakyThrows
public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() {
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1066,11 +1066,14 @@ public void testDoRewrite_whenFilterSet_thenSuccessful() {
.filter(rewrittenFilter)
.k(K)
.build();

// When
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(filter).k(K).build();

QueryBuilder actual = knnQueryBuilder.rewrite(context);

assertEquals(knnQueryBuilder, KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(filter).k(K).build());

// Then
assertEquals(expected, actual);
}
Expand Down

0 comments on commit 72c6a1e

Please sign in to comment.