Commit

[Backport 2.x] Pagination in hybrid query (#1099)
* Pagination in Hybrid query (#1048)

* Pagination in Hybrid query

Signed-off-by: Varun Jain <[email protected]>

* Remove unwanted code

Signed-off-by: Varun Jain <[email protected]>

* Adding hybrid query context dto

Signed-off-by: Varun Jain <[email protected]>

* Adding javadoc in HybridQueryContext and addressing a few comments from review

Signed-off-by: Varun Jain <[email protected]>

* rename hybrid query extraction method

Signed-off-by: Varun Jain <[email protected]>

* Refactoring to optimize extractHybridQuery method calls

Signed-off-by: Varun Jain <[email protected]>

* Changes in tests to adapt to the builder pattern in the query builder

Signed-off-by: Varun Jain <[email protected]>

* Add mapper service mock in tests

Signed-off-by: Varun Jain <[email protected]>

* Fix error message of index.max_result_window setting

Signed-off-by: Varun Jain <[email protected]>

* Fix error message of index.max_result_window setting

Signed-off-by: Varun Jain <[email protected]>

* Fixing validation condition for lower bound

Signed-off-by: Varun Jain <[email protected]>

* fix tests

Signed-off-by: Varun Jain <[email protected]>

* Removing version check from doEquals and doHashCode methods

Signed-off-by: Varun Jain <[email protected]>

---------

Signed-off-by: Varun Jain <[email protected]>

* Update pagination_depth datatype from int to Integer (#1094)

* Update pagination_depth datatype from int to Integer

Signed-off-by: Varun Jain <[email protected]>

---------

Signed-off-by: Varun Jain <[email protected]>
vibrantvarun authored Jan 14, 2025
1 parent a349cf8 commit 96fa384
Showing 24 changed files with 884 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048))
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
@@ -24,6 +24,7 @@ public final class MinClusterVersionUtil {
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH = Version.V_2_14_0;
private static final Version MINIMAL_SUPPORTED_VERSION_QUERY_IMAGE_FIX = Version.V_2_19_0;
private static final Version MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY = Version.V_2_19_0;

// Note this minimal version will act as an override
private static final Map<String, Version> MINIMAL_VERSION_NEURAL = ImmutableMap.<String, Version>builder()
@@ -41,6 +42,10 @@ public static boolean isClusterOnOrAfterMinReqVersionForRadialSearch() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_RADIAL_SEARCH);
}

public static boolean isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_PAGINATION_IN_HYBRID_QUERY);
}

public static boolean isClusterOnOrAfterMinReqVersion(String key) {
Version version;
if (MINIMAL_VERSION_NEURAL.containsKey(key)) {
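Like the existing radial-search and default-model-id constants in this class, the new 2.19 gate presumably lets the query layer stay compatible in mixed-version clusters by only serializing pagination_depth when every node supports it. A minimal sketch of that pattern, assuming a StreamOutput-based writer and the plugin's util package path (neither appears in this commit):

```java
import java.io.IOException;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.neuralsearch.util.MinClusterVersionUtil; // assumed package path

final class PaginationDepthBwcSketch {
    // Write the optional pagination_depth only when the whole cluster understands it,
    // so pre-2.19 nodes never receive a field they cannot deserialize.
    static void writePaginationDepth(StreamOutput out, Integer paginationDepth) throws IOException {
        if (MinClusterVersionUtil.isClusterOnOrAfterMinReqVersionForPaginationInHybridQuery()) {
            out.writeOptionalVInt(paginationDepth);
        }
    }
}
```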
@@ -93,6 +93,7 @@ private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWo
.combinationTechnique(combinationTechnique)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.searchPhaseContext(searchPhaseContext)
.build();
normalizationWorkflow.execute(request);
}
@@ -19,6 +19,7 @@
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
@@ -64,25 +65,30 @@ public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
final ScoreCombinationTechnique combinationTechnique,
final SearchPhaseContext searchPhaseContext
) {
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResultOptional)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.searchPhaseContext(searchPhaseContext)
.build();
execute(request);
}

public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();

// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(request.getQuerySearchResults());
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

explain(request, queryTopDocs);

@@ -93,8 +99,9 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
.queryTopDocs(queryTopDocs)
.scoreCombinationTechnique(request.getCombinationTechnique())
.querySearchResults(request.getQuerySearchResults())
.sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs))
.querySearchResults(querySearchResults)
.sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
.fromValueForSingleShard(getFromValueIfSingleShard(request))
.build();

// combine
@@ -103,8 +110,26 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)

// post-process data
log.debug("Post-process query results after score normalization and combination");
updateOriginalQueryResults(combineScoresDTO);
updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
updateOriginalQueryResults(combineScoresDTO, fetchSearchResultOptional.isPresent());
updateOriginalFetchResults(
querySearchResults,
fetchSearchResultOptional,
unprocessedDocIds,
combineScoresDTO.getFromValueForSingleShard()
);
}

/**
* Get value of from parameter when there is a single shard
* and fetch phase is already executed
* Ref https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchService.java#L715
*/
private int getFromValueIfSingleShard(final NormalizationProcessorWorkflowExecuteRequest request) {
final SearchPhaseContext searchPhaseContext = request.getSearchPhaseContext();
if (searchPhaseContext.getNumShards() > 1 || request.fetchSearchResultOptional.isEmpty()) {
return -1;
}
return searchPhaseContext.getRequest().source().from();
}

/**
@@ -173,19 +198,33 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
return queryTopDocs;
}

private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO) {
private void updateOriginalQueryResults(final CombineScoresDto combineScoresDTO, final boolean isFetchPhaseExecuted) {
final List<QuerySearchResult> querySearchResults = combineScoresDTO.getQuerySearchResults();
final List<CompoundTopDocs> queryTopDocs = getCompoundTopDocs(combineScoresDTO, querySearchResults);
final Sort sort = combineScoresDTO.getSort();
int totalScoreDocsCount = 0;
for (int index = 0; index < querySearchResults.size(); index++) {
QuerySearchResult querySearchResult = querySearchResults.get(index);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(index);
totalScoreDocsCount += updatedTopDocs.getScoreDocs().size();
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(
buildTopDocs(updatedTopDocs, sort),
maxScoreForShard(updatedTopDocs, sort != null)
);
// Fetch Phase had run before the normalization phase, therefore update the from value in the result of each shard.
// This will ensure the trimming of the search results.
if (isFetchPhaseExecuted) {
querySearchResult.from(combineScoresDTO.getFromValueForSingleShard());
}
querySearchResult.topDocs(updatedTopDocsAndMaxScore, querySearchResult.sortValueFormats());
}

final int from = querySearchResults.get(0).from();
if (from > totalScoreDocsCount) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Reached end of search result, increase pagination_depth value to see more results")
);
}
}

private List<CompoundTopDocs> getCompoundTopDocs(CombineScoresDto combineScoresDTO, List<QuerySearchResult> querySearchResults) {
@@ -244,7 +283,8 @@ private TopDocs buildTopDocs(CompoundTopDocs updatedTopDocs, Sort sort) {
private void updateOriginalFetchResults(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final List<Integer> docIds
final List<Integer> docIds,
final int fromValueForSingleShard
) {
if (fetchSearchResultOptional.isEmpty()) {
return;
@@ -276,14 +316,21 @@ private void updateOriginalFetchResults(

QuerySearchResult querySearchResult = querySearchResults.get(0);
TopDocs topDocs = querySearchResult.topDocs().topDocs;
// Scenario to handle when calculating the trimmed length of updated search hits
// When the normalization process runs after the fetch phase, the search hits are already fetched. Therefore, use the from value
// sent in the search request to calculate the effective length of the updated search hits array.
int trimmedLengthOfSearchHits = topDocs.scoreDocs.length - fromValueForSingleShard;
// iterate over the normalized/combined scores, that solves (1) and (3)
SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> {
SearchHit[] updatedSearchHitArray = new SearchHit[trimmedLengthOfSearchHits];
for (int i = 0; i < trimmedLengthOfSearchHits; i++) {
// Read topDocs after the desired from length
ScoreDoc scoreDoc = topDocs.scoreDocs[i + fromValueForSingleShard];
// get fetched hit content by doc_id
SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc);
// update score to normalized/combined value (3)
searchHit.score(scoreDoc.score);
return searchHit;
}).toArray(SearchHit[]::new);
updatedSearchHitArray[i] = searchHit;
}
SearchHits updatedSearchHits = new SearchHits(
updatedSearchHitArray,
querySearchResult.getTotalHits(),
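On a single shard the fetch phase runs before normalization, so the already-fetched hits still include the first `from` documents; the code above drops them so the response honors the requested offset, and updateOriginalQueryResults rejects a `from` that points past the combined results with a hint to raise pagination_depth. A minimal sketch of the trimming arithmetic, using plain int arrays in place of ScoreDoc/SearchHit objects:

```java
final class SingleShardTrimSketch {
    // Mirrors trimmedLengthOfSearchHits = scoreDocs.length - fromValueForSingleShard and the
    // copy loop above; assumes `from` has already been validated against the result size.
    static int[] trimToPage(int[] combinedDocIds, int from) {
        int trimmedLength = combinedDocIds.length - from;
        int[] page = new int[trimmedLength];
        for (int i = 0; i < trimmedLength; i++) {
            page[i] = combinedDocIds[i + from]; // skip the first `from` combined hits
        }
        return page;
    }
}
```

With ten combined hits and from = 4, for example, the page keeps only the last six hits.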
@@ -7,6 +7,7 @@
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
@@ -29,4 +30,5 @@ public class NormalizationProcessorWorkflowExecuteRequest {
final ScoreCombinationTechnique combinationTechnique;
boolean explain;
final PipelineProcessingContext pipelineProcessingContext;
final SearchPhaseContext searchPhaseContext;
}
@@ -29,4 +29,5 @@ public class CombineScoresDto {
private List<QuerySearchResult> querySearchResults;
@Nullable
private Sort sort;
private int fromValueForSingleShard;
}
@@ -70,14 +70,10 @@ public class ScoreCombiner {
public void combineScores(final CombineScoresDto combineScoresDTO) {
// iterate over results from each shard. Every CompoundTopDocs object has results from
// multiple sub queries, doc ids may repeat for each sub query results
ScoreCombinationTechnique scoreCombinationTechnique = combineScoresDTO.getScoreCombinationTechnique();
Sort sort = combineScoresDTO.getSort();
combineScoresDTO.getQueryTopDocs()
.forEach(
compoundQueryTopDocs -> combineShardScores(
combineScoresDTO.getScoreCombinationTechnique(),
compoundQueryTopDocs,
combineScoresDTO.getSort()
)
);
.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs, sort));
}

private void combineShardScores(
18 changes: 14 additions & 4 deletions src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java
@@ -34,17 +34,22 @@
public final class HybridQuery extends Query implements Iterable<Query> {

private final List<Query> subQueries;
private final HybridQueryContext queryContext;

/**
* Create new instance of hybrid query object based on collection of sub queries and filter query
* @param subQueries collection of queries that are executed individually and contribute to a final list of combined scores
* @param filterQueries list of filters that will be applied to each sub query. Each filter from the list is added as bool "filter" clause. If this is null sub queries will be executed as is
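* @param hybridQueryContext query-level context that carries settings such as pagination depth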
*/
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries) {
public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQueries, final HybridQueryContext hybridQueryContext) {
Objects.requireNonNull(subQueries, "collection of queries must not be null");
if (subQueries.isEmpty()) {
throw new IllegalArgumentException("collection of queries must not be empty");
}
Integer paginationDepth = hybridQueryContext.getPaginationDepth();
if (Objects.nonNull(paginationDepth) && paginationDepth == 0) {
throw new IllegalArgumentException("pagination_depth must not be zero");
}
if (Objects.isNull(filterQueries) || filterQueries.isEmpty()) {
this.subQueries = new ArrayList<>(subQueries);
} else {
@@ -57,10 +62,11 @@ public HybridQuery(final Collection<Query> subQueries, final List<Query> filterQ
}
this.subQueries = modifiedSubQueries;
}
this.queryContext = hybridQueryContext;
}

public HybridQuery(final Collection<Query> subQueries) {
this(subQueries, List.of());
public HybridQuery(final Collection<Query> subQueries, final HybridQueryContext hybridQueryContext) {
this(subQueries, List.of(), hybridQueryContext);
}

/**
@@ -128,7 +134,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
return super.rewrite(indexSearcher);
}
final List<Query> rewrittenSubQueries = manager.getQueriesAfterRewrite(collectors);
return new HybridQuery(rewrittenSubQueries);
return new HybridQuery(rewrittenSubQueries, queryContext);
}

private Void rewriteQuery(Query query, HybridQueryExecutorCollector<IndexSearcher, Map.Entry<Query, Boolean>> collector) {
@@ -190,6 +196,10 @@ public Collection<Query> getSubQueries() {
return Collections.unmodifiableCollection(subQueries);
}

public HybridQueryContext getQueryContext() {
return queryContext;
}

/**
* Create the Weight used to score this query
*
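With these changes every HybridQuery carries a HybridQueryContext, and a pagination_depth of zero is rejected at construction time. A rough usage sketch, assuming it sits in the same org.opensearch.neuralsearch.query package and that HybridQueryContext exposes a builder (only getPaginationDepth() is visible in this diff):

```java
import java.util.List;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

final class HybridQueryUsageSketch {
    static HybridQuery buildExample() {
        List<Query> subQueries = List.of(
            new TermQuery(new Term("title", "wind")),
            new TermQuery(new Term("title", "solar"))
        );
        // Assumed construction API for the context DTO.
        HybridQueryContext context = HybridQueryContext.builder().paginationDepth(50).build();
        // A paginationDepth of 0 here would make the constructor throw IllegalArgumentException.
        return new HybridQuery(subQueries, List.of(), context);
    }
}
```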