From 25516c66b2cd8d478b13b859a412c3a817c8c1de Mon Sep 17 00:00:00 2001 From: swethakann Date: Mon, 1 Apr 2024 11:55:26 -0700 Subject: [PATCH] PR changes and more tests --- .../main/proto/yelp/nrtsearch/search.proto | 21 ++++- docs/queries/multi_function_score.rst | 33 ++++++- .../multifunction/DecayFilterFunction.java | 14 ++- .../query/multifunction/DecayFunction.java | 4 +- .../ExponentialDecayFunction.java | 9 +- .../query/multifunction/FilterFunction.java | 2 +- .../GeoPointDecayFilterFunction.java | 86 +++++++++++-------- .../multifunction/GuassDecayFunction.java | 12 +-- .../multifunction/LinearDecayFunction.java | 10 +-- .../MultiFunctionScoreQueryTest.java | 86 ++++++++++++++++++- 10 files changed, 209 insertions(+), 68 deletions(-) diff --git a/clientlib/src/main/proto/yelp/nrtsearch/search.proto b/clientlib/src/main/proto/yelp/nrtsearch/search.proto index 6ea5d20f4..907b9041a 100644 --- a/clientlib/src/main/proto/yelp/nrtsearch/search.proto +++ b/clientlib/src/main/proto/yelp/nrtsearch/search.proto @@ -315,22 +315,39 @@ message MultiFunctionScoreQuery { oneof Function { // Produce score with score script definition Script script = 3; + // Produce score with a decay function + DecayFunction decayFunction = 4; } - DecayFunction decayFunction = 4; } + // Apply decay function to docs message DecayFunction { + // Document field name to use string fieldName = 1; - string decayType = 2; + // Type of decay function to apply + DecayType decayType = 2; + // Origin point to calculate the distance oneof Origin { int32 numeric = 3; google.type.LatLng geoPoint = 4; } + // Distance from origin + offset at which computed score will be equal to decay string scale = 5; + // Compute decay function for docs with a distance greater than offset string offset = 6; + // Defines how documents are scored at the distance float decay = 7; } + enum DecayType { + // Exponential decay function + DECAY_TYPE_EXPONENTIAL = 0; + // Linear decay function + DECAY_TYPE_LINEAR = 1; + // Gaussian decay function + DECAY_TYPE_GUASS = 2; + } + // How to combine multiple function scores to produce a final function score enum FunctionScoreMode { diff --git a/docs/queries/multi_function_score.rst b/docs/queries/multi_function_score.rst index 8b2839b37..ba559885b 100644 --- a/docs/queries/multi_function_score.rst +++ b/docs/queries/multi_function_score.rst @@ -22,9 +22,40 @@ Proto definition: oneof Function { // Produce score with score script definition Script script = 3; + // Produce score with a decay function + DecayFunction decayFunction = 4; } } + // Apply decay function to docs + message DecayFunction { + // Document field name to use + string fieldName = 1; + // Type of decay function to apply + DecayType decayType = 2; + // Origin point to calculate the distance + oneof Origin { + int32 numeric = 3; + google.type.LatLng geoPoint = 4; + } + // Distance from origin + offset at which computed score will be equal to decay + string scale = 5; + // Compute decay function for docs with a distance greater than offset + string offset = 6; + // Defines how documents are scored at the distance + float decay = 7; + } + + enum DecayType { + // Exponential decay function + DECAY_TYPE_EXPONENTIAL = 0; + // Linear decay function + DECAY_TYPE_LINEAR = 1; + // Gaussian decay function + DECAY_TYPE_GUASS = 2; + } + + // How to combine multiple function scores to produce a final function score enum FunctionScoreMode { // Multiply weighted function scores together @@ -55,4 +86,4 @@ Proto definition: float min_score = 5; // Determine minimal score is excluded or not. By default, it's false; bool min_excluded = 6; - } \ No newline at end of file + } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java index dd44c060d..43d1deca7 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFilterFunction.java @@ -19,9 +19,6 @@ import org.apache.lucene.search.Query; public abstract class DecayFilterFunction extends FilterFunction { - private static final String EXP = "exp"; - private static final String GUASS = "guass"; - private static final String LINEAR = "linear"; /** * Constructor. @@ -39,17 +36,16 @@ public DecayFilterFunction( } } - protected DecayFunction getDecayFunction(String decayType) { + protected DecayFunction getDecayType(MultiFunctionScoreQuery.DecayType decayType) { switch (decayType) { - case GUASS: + case DECAY_TYPE_GUASS: return new GuassDecayFunction(); - case EXP: + case DECAY_TYPE_EXPONENTIAL: return new ExponentialDecayFunction(); - case LINEAR: + case DECAY_TYPE_LINEAR: return new LinearDecayFunction(); default: - throw new IllegalArgumentException( - decayType + " decay function type is not supported. Needs to be guass, exp or linear"); + return null; } } } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java index 1d2bb6a29..8ab725ba3 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/DecayFunction.java @@ -18,9 +18,9 @@ import org.apache.lucene.search.Explanation; public interface DecayFunction { - double computeScore(double distance, double scale); + double computeScore(double distance, double offset, double scale); double computeScale(double scale, double decay); - Explanation explainComputeScore(String distanceString, double distance, double scale); + Explanation explainComputeScore(double distance, double offset, double scale); } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java index 2f6dbcbe9..e9556d648 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/ExponentialDecayFunction.java @@ -19,8 +19,8 @@ public class ExponentialDecayFunction implements DecayFunction { @Override - public double computeScore(double distance, double scale) { - return Math.exp(distance * scale); + public double computeScore(double distance, double offset, double scale) { + return Math.exp(scale * Math.max(0.0, distance - offset)); } @Override @@ -29,8 +29,9 @@ public double computeScale(double scale, double decay) { } @Override - public Explanation explainComputeScore(String distanceString, double distance, double scale) { + public Explanation explainComputeScore(double distance, double offset, double scale) { return Explanation.match( - (float) computeScore(distance, scale), "exp(" + distanceString + " * " + scale + ")"); + (float) computeScore(distance, offset, scale), + "exp(" + scale + " * max(0.0, " + distance + " - " + offset); } } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java index 571fc80ed..77b4e4681 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/FilterFunction.java @@ -80,7 +80,7 @@ public static FilterFunction build( if (filterFunctionGrpc.hasDecayFunction()) { MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction(); if (decayFunction.hasGeoPoint()) { - return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction); + return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState); } } switch (filterFunctionGrpc.getFunctionCase()) { diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java index 3aa88ba15..c39f26579 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GeoPointDecayFilterFunction.java @@ -17,8 +17,10 @@ import com.google.type.LatLng; import com.yelp.nrtsearch.server.grpc.MultiFunctionScoreQuery; +import com.yelp.nrtsearch.server.luceneserver.IndexState; import com.yelp.nrtsearch.server.luceneserver.doc.LoadedDocValues; -import com.yelp.nrtsearch.server.luceneserver.geo.GeoPoint; +import com.yelp.nrtsearch.server.luceneserver.doc.SegmentDocLookup; +import com.yelp.nrtsearch.server.luceneserver.field.LatLonFieldDef; import com.yelp.nrtsearch.server.luceneserver.geo.GeoUtils; import java.io.IOException; import java.util.List; @@ -31,11 +33,12 @@ public class GeoPointDecayFilterFunction extends DecayFilterFunction { private final MultiFunctionScoreQuery.DecayFunction decayFunction; private final String fieldName; - private final DecayFunction decayFunc; + private final DecayFunction decayType; private final double scale; private final double offset; private final double decay; private final LatLng origin; + private final IndexState indexState; /** * Constructor. @@ -44,69 +47,81 @@ public class GeoPointDecayFilterFunction extends DecayFilterFunction { * @param weight weight multiple to scale the function score * @param decayFunction to score a document with a function that decays depending on the distance * between an origin point and a numeric doc field value + * @param indexState indexState for validation and doc value lookup */ public GeoPointDecayFilterFunction( - Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) { + Query filterQuery, + float weight, + MultiFunctionScoreQuery.DecayFunction decayFunction, + IndexState indexState) { super(filterQuery, weight, decayFunction); this.decayFunction = decayFunction; this.fieldName = decayFunction.getFieldName(); - this.decayFunc = getDecayFunction(decayFunction.getDecayType()); + this.decayType = getDecayType(decayFunction.getDecayType()); this.offset = GeoUtils.getDistance(decayFunction.getOffset()); this.decay = decayFunction.getDecay(); double userGivenScale = GeoUtils.getDistance(decayFunction.getScale()); - this.scale = decayFunc.computeScale(userGivenScale, decay); + this.scale = decayType.computeScale(userGivenScale, decay); + this.indexState = indexState; if (!decayFunction.hasGeoPoint()) { throw new IllegalArgumentException("Decay Function should have a geoPoint for Origin field"); } else { this.origin = decayFunction.getGeoPoint(); } + LatLonFieldDef latLonFieldDef = (LatLonFieldDef) indexState.getField(fieldName); + // TODO: Add support for multi-value fields + if (latLonFieldDef.isMultiValue()) { + throw new IllegalArgumentException( + "Multivalue fields are not supported for decay functions yet"); + } } @Override - public LeafFunction getLeafFunction(LeafReaderContext leafContext) throws IOException { + public LeafFunction getLeafFunction(LeafReaderContext leafContext) { return new GeoPointDecayLeafFunction(leafContext); } public final class GeoPointDecayLeafFunction implements LeafFunction { - private final LoadedDocValues.SingleLocation geoPointValue; + SegmentDocLookup segmentDocLookup; + LatLng latLng; - public GeoPointDecayLeafFunction(LeafReaderContext context) throws IOException { - this.geoPointValue = - new LoadedDocValues.SingleLocation(context.reader().getSortedNumericDocValues(fieldName)); + public GeoPointDecayLeafFunction(LeafReaderContext context) { + segmentDocLookup = indexState.docLookup.getSegmentLookup(context); } @Override public double score(int docId, float innerQueryScore) throws IOException { - GeoPoint geoPoint = getGeoPoint(docId); + segmentDocLookup.setDocId(docId); + LoadedDocValues latLonValues = segmentDocLookup.get(fieldName); + this.latLng = latLonValues.toFieldValue(0).getLatLngValue(); double distance = GeoUtils.arcDistance( - origin.getLatitude(), origin.getLongitude(), geoPoint.getLat(), geoPoint.getLon()); - double score = decayFunc.computeScore(distance, scale); - return Math.max(0.0, score - offset); - } - - public GeoPoint getGeoPoint(int docId) throws IOException { - geoPointValue.setDocId(docId); - return geoPointValue.getValue(); + origin.getLatitude(), + origin.getLongitude(), + latLng.getLatitude(), + latLng.getLongitude()); + double score = decayType.computeScore(distance, offset, scale); + return score * getWeight(); } @Override - public Explanation explainScore(int docId, Explanation innerQueryScore) throws IOException { - GeoPoint geoPoint = getGeoPoint(docId); + public Explanation explainScore(int docId, Explanation innerQueryScore) { double distance = GeoUtils.arcDistance( - origin.getLatitude(), origin.getLongitude(), geoPoint.getLat(), geoPoint.getLon()); - double score = Math.max(0.0, decayFunc.computeScore(distance, scale) - offset); - Explanation paramsExp = + origin.getLatitude(), + origin.getLongitude(), + latLng.getLatitude(), + latLng.getLongitude()); + double score = decayType.computeScore(distance, offset, scale); + Explanation distanceExp = Explanation.match(distance, "arc distance calculated between two geoPoints"); return Explanation.match( score, - "final score with the provided decay function calculated by max(0.0, score - offset) with " - + offset - + " offset value", - List.of( - paramsExp, decayFunc.explainComputeScore(String.valueOf(distance), distance, scale))); + "final score with the provided decay function calculated by score * weight with " + + getWeight() + + " weight value", + List.of(distanceExp, decayType.explainComputeScore(distance, offset, scale))); } } @@ -115,7 +130,7 @@ public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()).append(", decayFunction:"); sb.append("fieldName: ").append(fieldName); - sb.append("decayFunc: ").append(decayFunc); + sb.append("decayType: ").append(decayType); sb.append("origin: ").append(origin); sb.append("scale: ").append(scale); sb.append("offset: ").append(offset); @@ -127,7 +142,8 @@ public String toString() { protected FilterFunction doRewrite( IndexReader reader, boolean filterQueryRewritten, Query rewrittenFilterQuery) { if (filterQueryRewritten) { - return new GeoPointDecayFilterFunction(rewrittenFilterQuery, getWeight(), decayFunction); + return new GeoPointDecayFilterFunction( + rewrittenFilterQuery, getWeight(), decayFunction, indexState); } else { return this; } @@ -143,9 +159,9 @@ protected boolean doEquals(FilterFunction other) { } GeoPointDecayFilterFunction otherGeoPointDecayFilterFunction = (GeoPointDecayFilterFunction) other; - return Objects.equals(origin, otherGeoPointDecayFilterFunction.origin) - && decayFunc.equals(otherGeoPointDecayFilterFunction.decayFunc) - && Objects.equals(fieldName, otherGeoPointDecayFilterFunction.fieldName) + return Objects.equals(fieldName, otherGeoPointDecayFilterFunction.fieldName) + && Objects.equals(decayType, otherGeoPointDecayFilterFunction.decayType) + && Objects.equals(origin, otherGeoPointDecayFilterFunction.origin) && Double.compare(scale, otherGeoPointDecayFilterFunction.scale) == 0 && Double.compare(offset, otherGeoPointDecayFilterFunction.offset) == 0 && Double.compare(decay, otherGeoPointDecayFilterFunction.decay) == 0; @@ -153,6 +169,6 @@ protected boolean doEquals(FilterFunction other) { @Override protected int doHashCode() { - return Objects.hash(fieldName, decayFunc, scale, offset, decay, origin); + return Objects.hash(fieldName, decayType, origin, scale, offset, decay); } } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassDecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassDecayFunction.java index 8e02fbd33..51632c3bc 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassDecayFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/GuassDecayFunction.java @@ -19,19 +19,19 @@ public class GuassDecayFunction implements DecayFunction { @Override - public double computeScore(double distance, double scale) { - return Math.exp(0.5 * Math.pow(distance, 2.0) / scale); + public double computeScore(double distance, double offset, double scale) { + return Math.exp((-1.0 * Math.pow(Math.max(0.0, distance - offset), 2.0)) / 2.0 * scale); } @Override public double computeScale(double scale, double decay) { - return 0.5 * Math.pow(scale, 2.0) / Math.log(decay); + return (-1.0 * Math.pow(scale, 2.0)) / (2.0 * Math.log(decay)); } @Override - public Explanation explainComputeScore(String distanceString, double distance, double scale) { + public Explanation explainComputeScore(double distance, double offset, double scale) { return Explanation.match( - (float) computeScore(distance, scale), - "exp(0.5 * pow(" + distanceString + ", 2.0) / " + scale + ")"); + (float) computeScore(distance, offset, scale), + "exp(- pow(max(0.0, |" + distance + " - " + offset + "), 2.0)/ 2.0 * " + scale); } } diff --git a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java index 796a4d03e..e362cb424 100644 --- a/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java +++ b/src/main/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/LinearDecayFunction.java @@ -19,8 +19,8 @@ public class LinearDecayFunction implements DecayFunction { @Override - public double computeScore(double distance, double scale) { - return Math.max(0.0, (scale - distance) / scale); + public double computeScore(double distance, double offset, double scale) { + return Math.max(0.0, (scale - Math.max(0.0, distance - offset)) / scale); } @Override @@ -29,9 +29,9 @@ public double computeScale(double scale, double decay) { } @Override - public Explanation explainComputeScore(String distanceString, double distance, double scale) { + public Explanation explainComputeScore(double distance, double offset, double scale) { return Explanation.match( - (float) computeScore(distance, scale), - "max(0.0, (" + scale + " - " + distanceString + ") / " + scale + ")"); + (float) computeScore(distance, offset, scale), + "max(0.0, (" + scale + " - max(0.0, " + distance + " - " + offset + ")) / " + scale + ")"); } } diff --git a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java index a0dc1df05..cf4485a32 100644 --- a/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java +++ b/src/test/java/com/yelp/nrtsearch/server/luceneserver/search/query/multifunction/MultiFunctionScoreQueryTest.java @@ -567,7 +567,7 @@ public void testScriptWithScore() { } @Test - public void testDecayFunctionGeoPoint() { + public void testExpDecayFunctionGeoPointWithWeight() { LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); SearchResponse response = doQuery( @@ -580,17 +580,97 @@ public void testDecayFunctionGeoPoint() { .setDecayFunction( MultiFunctionScoreQuery.DecayFunction.newBuilder() .setDecay(0.99f) - .setDecayType("exp") + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) .setOffset("0 m") .setScale("1 km") .setGeoPoint(latLng) .setFieldName("lat_lon_field") .build()) + .setWeight(0.7f) .build()), FunctionScoreMode.SCORE_MODE_MULTIPLY, BoostMode.BOOST_MODE_MULTIPLY); verifyResponseHitsWithDelta( - response, List.of(2, 4), List.of(3.4234246868436458E-6, 2.034676950471705E-18), 0.0); + response, List.of(2, 4), List.of(2.3963971216289792E-6, 2.034676950471705E-18), 0.00000001); + } + + @Test + public void testExpDecayFunctionGeoPoint() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("Document2").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.99f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_EXPONENTIAL) + .setOffset("0 m") + .setScale("1 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta( + response, List.of(2, 4), List.of(3.4234246868436458E-6, 2.034676950471705E-18), 0.00000001); + } + + @Test + public void testLinearDecayFunctionGeoPoint() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("Document2").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.2f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_LINEAR) + .setOffset("100 km") + .setScale("6000 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta(response, List.of(2, 4), List.of(0.2910, 0.1358), 0.0001); + } + + @Test + public void testGuassDecayFunctionGeoPoint() { + LatLng latLng = LatLng.newBuilder().setLatitude(40.7128).setLongitude(-74.0060).build(); + SearchResponse response = + doQuery( + Query.newBuilder() + .setMatchQuery( + MatchQuery.newBuilder().setField("text_field").setQuery("Document2").build()) + .build(), + List.of( + MultiFunctionScoreQuery.FilterFunction.newBuilder() + .setDecayFunction( + MultiFunctionScoreQuery.DecayFunction.newBuilder() + .setDecay(0.5f) + .setDecayType(MultiFunctionScoreQuery.DecayType.DECAY_TYPE_GUASS) + .setOffset("10000 km") + .setScale("100 km") + .setGeoPoint(latLng) + .setFieldName("lat_lon_field") + .build()) + .build()), + FunctionScoreMode.SCORE_MODE_MULTIPLY, + BoostMode.BOOST_MODE_MULTIPLY); + verifyResponseHitsWithDelta(response, List.of(2, 4), List.of(0.3381, 0.2772), 0.0001); } @Test