diff --git a/clientlib/src/main/proto/yelp/nrtsearch/search.proto b/clientlib/src/main/proto/yelp/nrtsearch/search.proto index 8e27f1744..54e9deb95 100644 --- a/clientlib/src/main/proto/yelp/nrtsearch/search.proto +++ b/clientlib/src/main/proto/yelp/nrtsearch/search.proto @@ -328,15 +328,15 @@ message MultiFunctionScoreQuery { DecayType decayType = 2; // Origin point to calculate the distance oneof Origin { - int32 numeric = 3; - google.type.LatLng geoPoint = 4; + google.type.LatLng geoPoint = 3; } - // Distance from origin + offset at which computed score will be equal to decay - string scale = 5; - // Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set - string offset = 6; - // Defines how documents are scored at the distance - float decay = 7; + // Currently only distance based scale and offset units are supported + // Distance from origin + offset at which computed score will be equal to decay. Scale should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi" + string scale = 4; + // Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set. Offset should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi" + string offset = 5; + // Defines decay rate for scoring. 
Should be between (0, 1) + float decay = 6; } enum DecayType { @@ -345,10 +345,9 @@ message MultiFunctionScoreQuery { // Linear decay function DECAY_TYPE_LINEAR = 1; // Gaussian decay function - DECAY_TYPE_GUASS = 2; + DECAY_TYPE_GUASSIAN = 2; } - // How to combine multiple function scores to produce a final function score enum FunctionScoreMode { // Multiply weighted function scores together diff --git a/docs/queries/multi_function_score.rst b/docs/queries/multi_function_score.rst index 60ba779fe..782ed2cc3 100644 --- a/docs/queries/multi_function_score.rst +++ b/docs/queries/multi_function_score.rst @@ -35,15 +35,15 @@ Proto definition: DecayType decayType = 2; // Origin point to calculate the distance oneof Origin { - int32 numeric = 3; - google.type.LatLng geoPoint = 4; + google.type.LatLng geoPoint = 3; } - // Distance from origin + offset at which computed score will be equal to decay - string scale = 5; - // Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set - string offset = 6; - // Defines how documents are scored at the distance - float decay = 7; + // Currently only distance based scale and offset units are supported + // Distance from origin + offset at which computed score will be equal to decay. Scale should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi" + string scale = 4; + // Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set. Offset should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi" + string offset = 5; + // Defines decay rate for scoring. 
Should be between (0, 1) + float decay = 6; } enum DecayType { @@ -52,7 +52,7 @@ Proto definition: // Linear decay function DECAY_TYPE_LINEAR = 1; // Gaussian decay function - DECAY_TYPE_GUASS = 2; + DECAY_TYPE_GUASSIAN = 2; } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java index 43d1deca7..2791fc52a 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java @@ -32,20 +32,19 @@ public DecayFilterFunction( Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) { super(filterQuery, weight); if (decayFunction.getDecay() <= 0 || decayFunction.getDecay() >= 1) { - throw new IllegalArgumentException("decay rate should be between (0, 1)"); + throw new IllegalArgumentException( + "decay rate should be between (0, 1) but is " + decayFunction.getDecay()); } } protected DecayFunction getDecayType(MultiFunctionScoreQuery.DecayType decayType) { - switch (decayType) { - case DECAY_TYPE_GUASS: - return new GuassDecayFunction(); - case DECAY_TYPE_EXPONENTIAL: - return new ExponentialDecayFunction(); - case DECAY_TYPE_LINEAR: - return new LinearDecayFunction(); - default: - return null; - } + return switch (decayType) { + case DECAY_TYPE_GUASSIAN -> new GuassianDecayFunction(); + case DECAY_TYPE_EXPONENTIAL -> new ExponentialDecayFunction(); + case DECAY_TYPE_LINEAR -> new LinearDecayFunction(); + default -> throw new IllegalArgumentException( + decayType + + " not supported. 
Only exponential, guassian and linear decay functions are supported"); + }; } } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java index 8ab725ba3..af0205e71 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java @@ -18,9 +18,35 @@ import org.apache.lucene.search.Explanation; public interface DecayFunction { + /** + * Computes the decayed score based on the provided distance, offset, and scale. + * + * @param distance the distance from a given origin point. + * @param offset the point after which the decay starts. + * @param scale scale factor that influences the rate of decay. This scale value is computed from + * the user given scale using the computeScale() method. + * @return the decayed score after applying the decay function + */ double computeScore(double distance, double offset, double scale); + /** + * Computes the adjusted scale based on a user given scale and decay rate. + * + * @param scale user given scale. + * @param decay decay rate that decides how the score decreases. + * @return adjusted scale which will be used by the computeScore() method. + */ double computeScale(double scale, double decay); + /** + * Provides an explanation for the computed score based on the given distance, offset, and scale. + * + * @param distance the distance from a given origin point. + * @param offset the point after which the decay starts. + * @param scale scale factor that influences the rate of decay. This scale value is computed from + * the user given scale using the computeScale() method. + * @return Explanation object that details the calculations involved in computing the decayed + * score. 
+ */ Explanation explainComputeScore(double distance, double offset, double scale); } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java index 77b4e4681..5ea1fde4b 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java @@ -77,12 +77,6 @@ public static FilterFunction build( ? QueryNodeMapper.getInstance().getQuery(filterFunctionGrpc.getFilter(), indexState) : null; float weight = filterFunctionGrpc.getWeight() != 0.0f ? filterFunctionGrpc.getWeight() : 1.0f; - if (filterFunctionGrpc.hasDecayFunction()) { - MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction(); - if (decayFunction.hasGeoPoint()) { - return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState); - } - } switch (filterFunctionGrpc.getFunctionCase()) { case SCRIPT: ScoreScript.Factory factory = @@ -94,6 +88,14 @@ public static FilterFunction build( indexState.docLookup); return new ScriptFilterFunction( filterQuery, weight, filterFunctionGrpc.getScript(), scriptSource); + case DECAYFUNCTION: + MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction(); + if (decayFunction.hasGeoPoint()) { + return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState); + } else { + throw new IllegalArgumentException( + "Decay Function should contain a geoPoint for Origin field"); + } case FUNCTION_NOT_SET: return new WeightFilterFunction(filterQuery, weight); default: diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java 
b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java index c30529718..be827c2f6 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java @@ -20,6 +20,7 @@ import com.yelp.nrtsearch.server.luceneserver.IndexState; import com.yelp.nrtsearch.server.luceneserver.doc.LoadedDocValues; import com.yelp.nrtsearch.server.luceneserver.doc.SegmentDocLookup; +import com.yelp.nrtsearch.server.luceneserver.field.FieldDef; import com.yelp.nrtsearch.server.luceneserver.field.LatLonFieldDef; import com.yelp.nrtsearch.server.luceneserver.geo.GeoPoint; import com.yelp.nrtsearch.server.luceneserver.geo.GeoUtils; @@ -31,6 +32,7 @@ import org.apache.lucene.search.Query; public class GeoPointDecayFilterFunction extends DecayFilterFunction { + private static final String LAT_LON = "LAT_LON"; private final MultiFunctionScoreQuery.DecayFunction decayFunction; private final String fieldName; @@ -47,7 +49,7 @@ public class GeoPointDecayFilterFunction extends DecayFilterFunction { * @param filterQuery filter to use when applying this function, or null if none * @param weight weight multiple to scale the function score * @param decayFunction to score a document with a function that decays depending on the distance - * between an origin point and a numeric doc field value + * between an origin point and a geoPoint doc field value * @param indexState indexState for validation and doc value lookup */ public GeoPointDecayFilterFunction( @@ -59,24 +61,33 @@ public GeoPointDecayFilterFunction( this.decayFunction = decayFunction; this.fieldName = decayFunction.getFieldName(); this.decayType = getDecayType(decayFunction.getDecayType()); + this.origin = decayFunction.getGeoPoint(); + this.decay = decayFunction.getDecay(); + double userGivenScale = 
GeoUtils.getDistance(decayFunction.getScale()); + this.scale = decayType.computeScale(userGivenScale, decay); this.offset = !decayFunction.getOffset().isEmpty() ? GeoUtils.getDistance(decayFunction.getOffset()) : 0.0; - this.decay = decayFunction.getDecay(); - double userGivenScale = GeoUtils.getDistance(decayFunction.getScale()); - this.scale = decayType.computeScale(userGivenScale, decay); this.indexState = indexState; - if (!decayFunction.hasGeoPoint()) { - throw new IllegalArgumentException("Decay Function should have a geoPoint for Origin field"); - } else { - this.origin = decayFunction.getGeoPoint(); + validateLatLonField(indexState.getField(fieldName)); + } + + public void validateLatLonField(FieldDef fieldDef) { + if (!LAT_LON.equals(fieldDef.getType())) { + throw new IllegalArgumentException( + fieldName + + " should be a LAT_LON to apply geoPoint decay function but it is: " + + fieldDef.getType()); } - LatLonFieldDef latLonFieldDef = (LatLonFieldDef) indexState.getField(fieldName); + LatLonFieldDef latLonFieldDef = (LatLonFieldDef) fieldDef; // TODO: Add support for multi-value fields if (latLonFieldDef.isMultiValue()) { throw new IllegalArgumentException( - "Multivalue fields are not supported for decay functions yet"); + "Multivalued fields are not supported for decay functions yet"); + } + if (!latLonFieldDef.hasDocValues()) { + throw new IllegalStateException("No doc values present for LAT_LON field: " + fieldName); } } @@ -88,7 +99,6 @@ public LeafFunction getLeafFunction(LeafReaderContext leafContext) { public final class GeoPointDecayLeafFunction implements LeafFunction { SegmentDocLookup segmentDocLookup; - LatLng latLng; public GeoPointDecayLeafFunction(LeafReaderContext context) { segmentDocLookup = indexState.docLookup.getSegmentLookup(context); @@ -97,36 +107,53 @@ public GeoPointDecayLeafFunction(LeafReaderContext context) { @Override public double score(int docId, float innerQueryScore) throws IOException { segmentDocLookup.setDocId(docId); 
- LoadedDocValues latLonValues = - (LoadedDocValues) segmentDocLookup.get(fieldName); - this.latLng = latLonValues.toFieldValue(0).getLatLngValue(); - double distance = - GeoUtils.arcDistance( - origin.getLatitude(), - origin.getLongitude(), - latLng.getLatitude(), - latLng.getLongitude()); - double score = decayType.computeScore(distance, offset, scale); - return score * getWeight(); + if (!validateDocValuesPresent()) { + return 0.0; + } else { + LoadedDocValues geoPointLoadedDocValues = + (LoadedDocValues) segmentDocLookup.get(fieldName); + GeoPoint latLng = geoPointLoadedDocValues.get(0); + double distance = + GeoUtils.arcDistance( + origin.getLatitude(), origin.getLongitude(), latLng.getLat(), latLng.getLon()); + double score = decayType.computeScore(distance, offset, scale); + return score * getWeight(); + } + } + + public boolean validateDocValuesPresent() { + return segmentDocLookup.containsKey(fieldName) && !segmentDocLookup.get(fieldName).isEmpty(); } @Override public Explanation explainScore(int docId, Explanation innerQueryScore) { - double distance = - GeoUtils.arcDistance( - origin.getLatitude(), - origin.getLongitude(), - latLng.getLatitude(), - latLng.getLongitude()); - double score = decayType.computeScore(distance, offset, scale); - Explanation distanceExp = - Explanation.match(distance, "arc distance calculated between two geoPoints"); - return Explanation.match( - score, - "final score with the provided decay function calculated by score * weight with " - + getWeight() - + " weight value", - List.of(distanceExp, decayType.explainComputeScore(distance, offset, scale))); + double score; + if (validateDocValuesPresent()) { + LoadedDocValues geoPointLoadedDocValues = + (LoadedDocValues) segmentDocLookup.get(fieldName); + GeoPoint latLng = geoPointLoadedDocValues.get(0); + double distance = + GeoUtils.arcDistance( + origin.getLatitude(), origin.getLongitude(), latLng.getLat(), latLng.getLon()); + + Explanation distanceExp = + 
Explanation.match(distance, "arc distance calculated between two geoPoints"); + + score = decayType.computeScore(distance, offset, scale); + double finalScore = score * getWeight(); + return Explanation.match( + finalScore, + "final score with the provided decay function calculated by score * weight with " + + getWeight() + + " weight value and " + + score + + " score", + List.of(distanceExp, decayType.explainComputeScore(distance, offset, scale))); + } else { + score = 0.0; + return Explanation.match( + score, "score is 0.0 since no doc values were present for " + fieldName); + } } } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassDecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassianDecayFunction.java similarity index 95% rename from src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassDecayFunction.java rename to src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassianDecayFunction.java index 51632c3bc..122384e7d 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassDecayFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassianDecayFunction.java @@ -17,7 +17,7 @@ import org.apache.lucene.search.Explanation; -public class GuassDecayFunction implements DecayFunction { +public class GuassianDecayFunction implements DecayFunction { @Override public double computeScore(double distance, double offset, double scale) { return Math.exp((-1.0 * Math.pow(Math.max(0.0, distance - offset), 2.0)) / 2.0 * scale); diff --git a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java index cf4485a32..ebf5db847 100644 --- 
a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java +++ b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java @@ -16,6 +16,7 @@ package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import com.google.type.LatLng; import com.yelp.nrtsearch.server.grpc.AddDocumentRequest; @@ -67,9 +68,6 @@ protected void initIndex(String name) throws Exception { MultiValuedField.newBuilder() .addValue("Document1 with none of filter terms") .build()) - .putFields( - "lat_lon_field", - MultiValuedField.newBuilder().addValue("34.0522").addValue("-118.2437").build()) .build(); docs.add(request); request = @@ -618,7 +616,33 @@ public void testExpDecayFunctionGeoPoint() { FunctionScoreMode.SCORE_MODE_MULTIPLY, BoostMode.BOOST_MODE_MULTIPLY); verifyResponseHitsWithDelta( - response, List.of(2, 4), List.of(3.4234246868436458E-6, 2.034676950471705E-18), 0.00000001); + response, List.of(2, 4), List.of(3.4234246868436458E-6, 2.034676950471705E-18), 0.0); + } + + @Test + public void testExpDecayFunctionNoDocValue() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("none").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.99f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) + .setOffset("0 m") + .setScale("1 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta(response, List.of(1), List.of(0.0), 0.0); } 
@Test @@ -661,7 +685,7 @@ public void testGuassDecayFunctionGeoPoint() { .setDecayFunction( MultiFunctionScoreQuery.DecayFunction.newBuilder() .setDecay(0.5f) - .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_GUASS) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_GUASSIAN) .setOffset("10000 km") .setScale("100 km") .setGeoPoint(latLng) @@ -673,6 +697,39 @@ public void testGuassDecayFunctionGeoPoint() { verifyResponseHitsWithDelta(response, List.of(2, 4), List.of(0.3381, 0.2772), 0.0001); } + @Test + public void testInvalidGeoPointField() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + assertThrows( + Exception.class, + () -> { + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder() + .setField("text_field") + .setQuery("Document2") + .build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.99f) + .setDecayType( + MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) + .setOffset("0 m") + .setScale("1 km") + .setGeoPoint(latLng) + .setFieldName("text_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + }); + } + @Test public void testMultiMatchAll_multiply_multiply() { multiFunctionAndVerify(