From 72c6a1e1ba0c2eb420bc25141cae5f77646632d2 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 08:03:47 -0800 Subject: [PATCH] Ensure consistent filter query in KNNQueryBuilder across multiple shards (#2362) Signed-off-by: Sahil Buddharaju Signed-off-by: sahil <61558528+buddharajusahil@users.noreply.github.com> Co-authored-by: Sahil Buddharaju (cherry picked from commit c969f1d0c0df66791b06b561897a3c327c1cdeb7) --- CHANGELOG.md | 1 + .../knn/index/query/KNNQueryBuilder.java | 19 +++++++- .../org/opensearch/knn/index/FaissIT.java | 45 ++++++++++++++++++ .../opensearch/knn/index/LuceneEngineIT.java | 46 +++++++++++++++++++ .../knn/index/query/KNNQueryBuilderTests.java | 3 ++ 5 files changed, 112 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 519a077f4..5f95af22d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Allow validation for non knn index only after 2.17.0 (#2315)[https://github.com/opensearch-project/k-NN/pull/2315] * Release query vector memory after execution (#2346)[https://github.com/opensearch-project/k-NN/pull/2346] * Fix shard level rescoring disabled setting flag (#2352)[https://github.com/opensearch-project/k-NN/pull/2352] +* Fix filter rewrite logic which was resulting in getting inconsistent / incorrect results for cases where filter was getting rewritten for shards (#2359)[https://github.com/opensearch-project/k-NN/pull/2359] ### Infrastructure * Updated C++ version in JNI from c++11 to c++17 [#2259](https://github.com/opensearch-project/k-NN/pull/2259) * Upgrade bytebuddy and objenesis version to match OpenSearch core and, update github ci runner for macos [#2279](https://github.com/opensearch-project/k-NN/pull/2279) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index d2b169b2a..ee18394f6 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -662,9 +662,24 @@ public String getWriteableName() { @Override protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException { - // rewrite filter query if it exists to avoid runtime errors in next steps of query phase + QueryBuilder rewrittenFilter; if (Objects.nonNull(filter)) { - filter = filter.rewrite(queryShardContext); + rewrittenFilter = filter.rewrite(queryShardContext); + if (rewrittenFilter != filter) { + KNNQueryBuilder rewrittenQueryBuilder = KNNQueryBuilder.builder() + .fieldName(this.fieldName) + .vector(this.vector) + .k(this.k) + .maxDistance(this.maxDistance) + .minScore(this.minScore) + .methodParameters(this.methodParameters) + .filter(rewrittenFilter) + .ignoreUnmapped(this.ignoreUnmapped) + .rescoreContext(this.rescoreContext) + .expandNested(this.expandNested) + .build(); + return rewrittenQueryBuilder; + } } return super.doRewrite(queryShardContext); } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 317cff4d4..66a1c8e99 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -412,6 +412,51 @@ public void testRadialQuery_withFilter_thenSuccess() { deleteKNNIndex(INDEX_NAME); } + @SneakyThrows + public void testQueryWithFilterMultipleShards() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, "3") + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .endObject() + .endObject() + .startObject(INTEGER_FIELD_NAME) + .field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + + createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build()); + putMappingRequest(INDEX_NAME, mapping); + + addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01")); + + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 7.0f, 3.0f }; + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + searchVector, + 1, + QueryBuilders.boolQuery().must(QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01")) + ), + 10 + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(1, knnResults.size()); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { String indexName = "test-index"; diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 64e5af1e7..2afe376e4 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -15,6 +15,7 @@ import org.opensearch.client.Response; import org.opensearch.client.ResponseException; import org.opensearch.common.Nullable; +import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.query.QueryBuilder; @@ -282,6 +283,51 @@ public void testQueryWithFilterUsingByteVectorDataType() { validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult); } + @SneakyThrows + public void testQueryWithFilterMultipleShards() { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD_NAME) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION_FIELD_NAME, DIMENSION) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.LUCENE.getName()) + .endObject() + .endObject() + .startObject(INTEGER_FIELD_NAME) + .field(TYPE_FIELD_NAME, FILED_TYPE_INTEGER) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + + createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", 10).put("number_of_replicas", 0).put("index.knn", true).build()); + putMappingRequest(INDEX_NAME, mapping); + + addKnnDocWithAttributes("doc1", new float[] { 7.0f, 7.0f, 3.0f }, ImmutableMap.of("dateReceived", "2024-10-01")); + + refreshIndex(INDEX_NAME); + + final float[] searchVector = { 6.0f, 7.0f, 3.0f }; + final Response response = searchKNNIndex( + INDEX_NAME, + new KNNQueryBuilder( + FIELD_NAME, + searchVector, + 1, + QueryBuilders.boolQuery().must(QueryBuilders.rangeQuery("dateReceived").gte("2023-11-01")) + ), + 10 + ); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + + assertEquals(1, knnResults.size()); + } + @SneakyThrows public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 30c8007e6..1333d616e 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -1066,11 +1066,14 @@ public void testDoRewrite_whenFilterSet_thenSuccessful() { .filter(rewrittenFilter) .k(K) .build(); + // When KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(filter).k(K).build(); QueryBuilder actual = knnQueryBuilder.rewrite(context); + assertEquals(knnQueryBuilder, KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).filter(filter).k(K).build()); + // Then assertEquals(expected, actual); }