Skip to content

Commit

Permalink
Support unit in SortType (#590)
Browse files Browse the repository at this point in the history
create sort context and support unit in lat_lon distance sort
  • Loading branch information
waziqi89 authored Sep 11, 2023
1 parent 8caaeab commit 100c28b
Show file tree
Hide file tree
Showing 16 changed files with 1,097 additions and 771 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ sourceCompatibility = 1.14
targetCompatibility = 1.14

allprojects {
version = '0.26.1'
version = '0.27.0'
group = 'com.yelp.nrtsearch'
}

Expand Down
2 changes: 2 additions & 0 deletions clientlib/src/main/proto/yelp/nrtsearch/search.proto
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ message SortType {
bool missingLat = 4;
// Sort in reverse of the field's natural order
bool reverse = 5;
// The unit used for the distance sort. Supported options are m, km and mi, default is m
string unit = 6;
}

/* For multi valued fields, how to select which value is used for sorting */
Expand Down
4 changes: 4 additions & 0 deletions grpc-gateway/luceneserver.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -4880,6 +4880,10 @@
"reverse": {
"type": "boolean",
"title": "Sort in reverse of the field's natural order"
},
"unit": {
"type": "string",
"title": "The unit used for the distance sort. Supported options are m, km and mi, default is m"
}
},
"description": "\"The field to sort on. Pass \u003ccode\u003edocid\u003c/code\u003e for index order and \u003ccode\u003escore\u003c/code\u003e for relevance sort."
Expand Down
1,314 changes: 662 additions & 652 deletions grpc-gateway/search.pb.go

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
import com.yelp.nrtsearch.server.grpc.GeoBoundingBoxQuery;
import com.yelp.nrtsearch.server.grpc.GeoRadiusQuery;
import com.yelp.nrtsearch.server.grpc.Point;
import com.yelp.nrtsearch.server.grpc.SearchResponse.Hit.CompositeFieldValue;
import com.yelp.nrtsearch.server.grpc.SearchResponse.Hit.FieldValue;
import com.yelp.nrtsearch.server.grpc.SortType;
import com.yelp.nrtsearch.server.luceneserver.doc.LoadedDocValues;
import com.yelp.nrtsearch.server.luceneserver.field.properties.GeoQueryable;
import com.yelp.nrtsearch.server.luceneserver.field.properties.Sortable;
import com.yelp.nrtsearch.server.luceneserver.geo.GeoUtils;
import java.io.IOException;
import java.util.List;
import java.util.function.BiFunction;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.LatLonDocValuesField;
Expand Down Expand Up @@ -156,4 +159,13 @@ public Query getGeoRadiusQuery(GeoRadiusQuery geoRadiusQuery) {
geoRadiusQuery.getCenter().getLongitude(),
radius);
}

@Override
public BiFunction<SortField, Object, CompositeFieldValue> sortValueExtractor(SortType sortType) {
double multiplier = GeoUtils.convertDistanceToADifferentUnit(1.0, sortType.getUnit());
return (sortField, value) ->
CompositeFieldValue.newBuilder()
.addFieldValue(FieldValue.newBuilder().setDoubleValue(multiplier * (double) value))
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
*/
package com.yelp.nrtsearch.server.luceneserver.field.properties;

import com.yelp.nrtsearch.server.grpc.SearchResponse.Hit.CompositeFieldValue;
import com.yelp.nrtsearch.server.grpc.Selector;
import com.yelp.nrtsearch.server.grpc.SortType;
import com.yelp.nrtsearch.server.luceneserver.search.sort.SortParser;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.SortedNumericSelector;
Expand Down Expand Up @@ -62,4 +65,16 @@ public interface Sortable {
* @return sort field for this type
*/
SortField getSortField(SortType type);

/**
* Allow customized sorted value processing before return per fieldDef. The validation must be
* completed here, and throw an exception if it is failed.
*
* @param sortType settings for this sort
* @return Extractor method to process the value
* @throws IllegalArgumentException if validation fails
*/
default BiFunction<SortField, Object, CompositeFieldValue> sortValueExtractor(SortType sortType) {
return SortParser.DEFAULT_SORT_VALUE_EXTRACTOR;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,17 @@ public static double getDistance(String rawDistance) {
throw new IllegalArgumentException("Invalid distance " + rawDistance);
}
}

public static double convertDistanceToADifferentUnit(double distanceNumber, String unit) {
String distanceUnit = unit.strip().toLowerCase();
if (distanceUnit.isEmpty() || distanceUnit.equals("m")) {
return distanceNumber;
} else if (distanceUnit.equals("km")) {
return distanceNumber / KM_TO_M;
} else if (distanceUnit.equals("mi")) {
return distanceNumber / MI_TO_M;
} else {
throw new IllegalArgumentException("Invalid unit " + unit);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
import com.yelp.nrtsearch.server.luceneserver.field.FieldDef;
import com.yelp.nrtsearch.server.luceneserver.field.IdFieldDef;
import com.yelp.nrtsearch.server.luceneserver.field.properties.GlobalOrdinalable;
import com.yelp.nrtsearch.server.luceneserver.search.SortParser;
import com.yelp.nrtsearch.server.luceneserver.search.sort.SortParser;
import com.yelp.nrtsearch.server.luceneserver.state.StateUtils;
import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -235,7 +235,6 @@ public ImmutableIndexState(
indexSort =
SortParser.parseSort(
mergedSettings.getIndexSort().getSortedFieldsList(),
null,
fieldAndFacetState.getFields());
validateIndexSort(indexSort);
} catch (SearchHandlerException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,21 @@
import com.yelp.nrtsearch.server.grpc.QuerySortField;
import com.yelp.nrtsearch.server.luceneserver.IndexState;
import com.yelp.nrtsearch.server.luceneserver.QueryNodeMapper;
import com.yelp.nrtsearch.server.luceneserver.SearchHandler.SearchHandlerException;
import com.yelp.nrtsearch.server.luceneserver.ShardState;
import com.yelp.nrtsearch.server.luceneserver.field.FieldDef;
import com.yelp.nrtsearch.server.luceneserver.highlights.HighlightFetchTask;
import com.yelp.nrtsearch.server.luceneserver.search.FetchTasks;
import com.yelp.nrtsearch.server.luceneserver.search.FieldFetchContext;
import com.yelp.nrtsearch.server.luceneserver.search.SearchContext;
import com.yelp.nrtsearch.server.luceneserver.search.SortParser;
import com.yelp.nrtsearch.server.luceneserver.search.sort.SortContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.facet.taxonomy.SearcherTaxonomyManager;
import org.apache.lucene.facet.taxonomy.SearcherTaxonomyManager.SearcherAndTaxonomy;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
Expand All @@ -62,8 +58,8 @@ public class InnerHitContext implements FieldFetchContext {
private final int topHits;
private final Map<String, FieldDef> queryFields;
private final Map<String, FieldDef> retrieveFields;
private final List<String> sortedFieldNames;
private final Sort sort;
private final SortContext sortContext;

private final CollectorManager<? extends TopDocsCollector, ? extends TopDocs>
topDocsCollectorManager;
private final FetchTasks fetchTasks;
Expand Down Expand Up @@ -93,23 +89,15 @@ private InnerHitContext(InnerHitContextBuilder builder, boolean needValidation)

if (builder.querySort == null) {
// relevance collector
this.sortedFieldNames = Collections.EMPTY_LIST;
this.sort = null;
this.sortContext = null;
this.topDocsCollectorManager =
TopScoreDocCollector.createSharedManager(topHits, null, Integer.MAX_VALUE);
} else {
// sortedField collector
this.sortedFieldNames =
new ArrayList<>(builder.querySort.getFields().getSortedFieldsList().size());
try {
this.sort =
SortParser.parseSort(
builder.querySort.getFields().getSortedFieldsList(), sortedFieldNames, queryFields);
this.topDocsCollectorManager =
TopFieldCollector.createSharedManager(sort, topHits, null, Integer.MAX_VALUE);
} catch (SearchHandlerException e) {
throw new IllegalArgumentException(e);
}
this.sortContext = new SortContext(builder.querySort, queryFields);
this.topDocsCollectorManager =
TopFieldCollector.createSharedManager(
sortContext.getSort(), topHits, null, Integer.MAX_VALUE);
}

if (needValidation) {
Expand Down Expand Up @@ -239,17 +227,9 @@ public Map<String, FieldDef> getRetrieveFields() {
return retrieveFields;
}

/**
* Get the field names used in sort if {@link QuerySortField} is in use, otherwise returns an
* empty list.
*/
public List<String> getSortedFieldNames() {
return sortedFieldNames;
}

/** Get the sort object if {@link QuerySortField} is in use, otherwise returns null. */
public Sort getSort() {
return sort;
public SortContext getSortContext() {
return sortContext;
}

/** Get the topDocsCollectorManager to collect the search results. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import com.yelp.nrtsearch.server.luceneserver.SearchHandler;
import com.yelp.nrtsearch.server.luceneserver.search.FetchTasks.FetchTask;
import com.yelp.nrtsearch.server.luceneserver.search.SearchContext;
import com.yelp.nrtsearch.server.luceneserver.search.SortParser;
import com.yelp.nrtsearch.server.luceneserver.search.sort.SortParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -42,7 +42,6 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.Weight;
Expand Down Expand Up @@ -74,7 +73,8 @@ public InnerHitFetchTask(InnerHitContext innerHitContext) throws IOException {
this.searcher = innerHitContext.getSearcherAndTaxonomy().searcher;
boolean needScore =
innerHitContext.getTopHits() >= innerHitContext.getStartHit()
&& (innerHitContext.getSort() == null || innerHitContext.getSort().needsScores());
&& (innerHitContext.getSortContext() == null
|| innerHitContext.getSortContext().getSort().needsScores());
// We support TopDocsCollector only, so top_scores is good enough
this.innerHitWeight =
searcher
Expand Down Expand Up @@ -132,15 +132,11 @@ public void processHit(
SearchResponse.Hit.Builder innerHitResponse = innerHitResultBuilder.addHitsBuilder();
ScoreDoc innerHit = topDocs.scoreDocs[innerHitIndex];
innerHitResponse.setLuceneDocId(innerHit.doc);
if (!innerHitContext.getSortedFieldNames().isEmpty()) {
if (innerHitContext.getSortContext() != null) {
// fill the sortedFields
FieldDoc fd = (FieldDoc) innerHit;
for (int i = 0; i < fd.fields.length; ++i) {
SortField sortField = innerHitContext.getSort().getSort()[i];
innerHitResponse.putSortedFields(
innerHitContext.getSortedFieldNames().get(i),
SortParser.getValueForSortField(sortField, fd.fields[i]));
}
innerHitResponse.putAllSortedFields(
SortParser.getAllSortedValues(fd, innerHitContext.getSortContext()));
innerHitResponse.setScore(Double.NaN);
} else {
innerHitResponse.setScore(innerHit.score);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,13 @@

import com.yelp.nrtsearch.server.grpc.CollectorResult;
import com.yelp.nrtsearch.server.grpc.SearchResponse;
import com.yelp.nrtsearch.server.luceneserver.SearchHandler;
import com.yelp.nrtsearch.server.luceneserver.search.SortParser;
import java.util.ArrayList;
import com.yelp.nrtsearch.server.luceneserver.search.sort.SortContext;
import com.yelp.nrtsearch.server.luceneserver.search.sort.SortParser;
import java.util.List;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopFieldDocs;
Expand All @@ -40,8 +37,7 @@ public class SortFieldCollector extends DocCollector {
private static final Logger logger = LoggerFactory.getLogger(SortFieldCollector.class);

private final CollectorManager<TopFieldCollector, TopFieldDocs> manager;
private final Sort sort;
private final List<String> sortNames;
private final SortContext sortContext;

public SortFieldCollector(
CollectorCreatorContext context,
Expand All @@ -61,18 +57,10 @@ public SortFieldCollector(
totalHitsThreshold = context.getRequest().getTotalHitsThreshold();
}

sortNames =
new ArrayList<>(context.getRequest().getQuerySort().getFields().getSortedFieldsCount());
try {
sort =
SortParser.parseSort(
context.getRequest().getQuerySort().getFields().getSortedFieldsList(),
sortNames,
context.getQueryFields());
} catch (SearchHandler.SearchHandlerException e) {
throw new IllegalArgumentException(e);
}
manager = TopFieldCollector.createSharedManager(sort, topHits, searchAfter, totalHitsThreshold);
sortContext = new SortContext(context.getRequest().getQuerySort(), context.getQueryFields());
manager =
TopFieldCollector.createSharedManager(
sortContext.getSort(), topHits, searchAfter, totalHitsThreshold);
}

@Override
Expand All @@ -83,26 +71,7 @@ public SortFieldCollector(
@Override
public void fillHitRanking(SearchResponse.Hit.Builder hitResponse, ScoreDoc scoreDoc) {
FieldDoc fd = (FieldDoc) scoreDoc;
if (fd.fields.length != sort.getSort().length) {
throw new IllegalArgumentException(
"Size mismatch between Sort and ScoreDoc: "
+ sort.getSort().length
+ " != "
+ fd.fields.length);
}
if (fd.fields.length != sortNames.size()) {
throw new IllegalArgumentException(
"Size mismatch between Sort and Sort names: "
+ fd.fields.length
+ " != "
+ sortNames.size());
}

for (int i = 0; i < fd.fields.length; ++i) {
SortField sortField = sort.getSort()[i];
hitResponse.putSortedFields(
sortNames.get(i), SortParser.getValueForSortField(sortField, fd.fields[i]));
}
hitResponse.putAllSortedFields(SortParser.getAllSortedValues(fd, sortContext));
}

@Override
Expand Down
Loading

0 comments on commit 100c28b

Please sign in to comment.