Skip to content

Commit

Permalink
PR changes and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
swethakann committed Apr 1, 2024
1 parent 74f238c commit 25516c6
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 68 deletions.
21 changes: 19 additions & 2 deletions clientlib/src/main/proto/yelp/nrtsearch/search.proto
Original file line number Diff line number Diff line change
Expand Up @@ -315,22 +315,39 @@ message MultiFunctionScoreQuery {
oneof Function {
// Produce score with score script definition
Script script = 3;
// Produce score with a decay function
DecayFunction decayFunction = 4;
}
DecayFunction decayFunction = 4;
}

// Apply decay function to docs
message DecayFunction {
// Document field name to use
string fieldName = 1;
string decayType = 2;
// Type of decay function to apply
DecayType decayType = 2;
// Origin point to calculate the distance
oneof Origin {
int32 numeric = 3;
google.type.LatLng geoPoint = 4;
}
// Distance from origin + offset at which computed score will be equal to decay
string scale = 5;
// Compute decay function for docs with a distance greater than offset
string offset = 6;
// Defines how documents are scored at the distance
float decay = 7;
}

// Type of decay function a DecayFunction applies to the distance from its origin.
// NOTE(review): "GUASS" in DECAY_TYPE_GUASS looks like a misspelling of "GAUSS"; the
// identifier is part of the generated API, so confirm with clients before renaming it.
enum DecayType {
// Exponential decay function
DECAY_TYPE_EXPONENTIAL = 0;
// Linear decay function
DECAY_TYPE_LINEAR = 1;
// Gaussian decay function
DECAY_TYPE_GUASS = 2;
}


// How to combine multiple function scores to produce a final function score
enum FunctionScoreMode {
Expand Down
33 changes: 32 additions & 1 deletion docs/queries/multi_function_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,40 @@ Proto definition:
oneof Function {
// Produce score with score script definition
Script script = 3;
// Produce score with a decay function
DecayFunction decayFunction = 4;
}
}
// Apply decay function to docs
message DecayFunction {
// Document field name to use
string fieldName = 1;
// Type of decay function to apply
DecayType decayType = 2;
// Origin point to calculate the distance
oneof Origin {
int32 numeric = 3;
google.type.LatLng geoPoint = 4;
}
// Distance from origin + offset at which computed score will be equal to decay
string scale = 5;
// Compute decay function for docs with a distance greater than offset
string offset = 6;
// Defines how documents are scored at the distance
float decay = 7;
}
// Type of decay function a DecayFunction applies to the distance from its origin.
// NOTE(review): "GUASS" in DECAY_TYPE_GUASS looks like a misspelling of "GAUSS"; the
// identifier is part of the generated API, so confirm with clients before renaming it.
enum DecayType {
// Exponential decay function
DECAY_TYPE_EXPONENTIAL = 0;
// Linear decay function
DECAY_TYPE_LINEAR = 1;
// Gaussian decay function
DECAY_TYPE_GUASS = 2;
}
// How to combine multiple function scores to produce a final function score
enum FunctionScoreMode {
// Multiply weighted function scores together
Expand Down Expand Up @@ -55,4 +86,4 @@ Proto definition:
float min_score = 5;
// Determine minimal score is excluded or not. By default, it's false;
bool min_excluded = 6;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
import org.apache.lucene.search.Query;

public abstract class DecayFilterFunction extends FilterFunction {
private static final String EXP = "exp";
private static final String GUASS = "guass";
private static final String LINEAR = "linear";

/**
* Constructor.
Expand All @@ -39,17 +36,16 @@ public DecayFilterFunction(
}
}

protected DecayFunction getDecayFunction(String decayType) {
protected DecayFunction getDecayType(MultiFunctionScoreQuery.DecayType decayType) {
switch (decayType) {
case GUASS:
case DECAY_TYPE_GUASS:
return new GuassDecayFunction();
case EXP:
case DECAY_TYPE_EXPONENTIAL:
return new ExponentialDecayFunction();
case LINEAR:
case DECAY_TYPE_LINEAR:
return new LinearDecayFunction();
default:
throw new IllegalArgumentException(
decayType + " decay function type is not supported. Needs to be guass, exp or linear");
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import org.apache.lucene.search.Explanation;

public interface DecayFunction {
double computeScore(double distance, double scale);
double computeScore(double distance, double offset, double scale);

double computeScale(double scale, double decay);

Explanation explainComputeScore(String distanceString, double distance, double scale);
Explanation explainComputeScore(double distance, double offset, double scale);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

public class ExponentialDecayFunction implements DecayFunction {
@Override
public double computeScore(double distance, double scale) {
return Math.exp(distance * scale);
public double computeScore(double distance, double offset, double scale) {
return Math.exp(scale * Math.max(0.0, distance - offset));
}

@Override
Expand All @@ -29,8 +29,9 @@ public double computeScale(double scale, double decay) {
}

@Override
public Explanation explainComputeScore(String distanceString, double distance, double scale) {
public Explanation explainComputeScore(double distance, double offset, double scale) {
return Explanation.match(
(float) computeScore(distance, scale), "exp(" + distanceString + " * " + scale + ")");
(float) computeScore(distance, offset, scale),
"exp(" + scale + " * max(0.0, " + distance + " - " + offset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static FilterFunction build(
if (filterFunctionGrpc.hasDecayFunction()) {
MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction();
if (decayFunction.hasGeoPoint()) {
return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction);
return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState);
}
}
switch (filterFunctionGrpc.getFunctionCase()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import com.google.type.LatLng;
import com.yelp.nrtsearch.server.grpc.MultiFunctionScoreQuery;
import com.yelp.nrtsearch.server.luceneserver.IndexState;
import com.yelp.nrtsearch.server.luceneserver.doc.LoadedDocValues;
import com.yelp.nrtsearch.server.luceneserver.geo.GeoPoint;
import com.yelp.nrtsearch.server.luceneserver.doc.SegmentDocLookup;
import com.yelp.nrtsearch.server.luceneserver.field.LatLonFieldDef;
import com.yelp.nrtsearch.server.luceneserver.geo.GeoUtils;
import java.io.IOException;
import java.util.List;
Expand All @@ -31,11 +33,12 @@ public class GeoPointDecayFilterFunction extends DecayFilterFunction {

private final MultiFunctionScoreQuery.DecayFunction decayFunction;
private final String fieldName;
private final DecayFunction decayFunc;
private final DecayFunction decayType;
private final double scale;
private final double offset;
private final double decay;
private final LatLng origin;
private final IndexState indexState;

/**
* Constructor.
Expand All @@ -44,69 +47,81 @@ public class GeoPointDecayFilterFunction extends DecayFilterFunction {
* @param weight weight multiple to scale the function score
* @param decayFunction to score a document with a function that decays depending on the distance
* between an origin point and a numeric doc field value
* @param indexState indexState for validation and doc value lookup
*/
public GeoPointDecayFilterFunction(
Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) {
Query filterQuery,
float weight,
MultiFunctionScoreQuery.DecayFunction decayFunction,
IndexState indexState) {
super(filterQuery, weight, decayFunction);
this.decayFunction = decayFunction;
this.fieldName = decayFunction.getFieldName();
this.decayFunc = getDecayFunction(decayFunction.getDecayType());
this.decayType = getDecayType(decayFunction.getDecayType());
this.offset = GeoUtils.getDistance(decayFunction.getOffset());
this.decay = decayFunction.getDecay();
double userGivenScale = GeoUtils.getDistance(decayFunction.getScale());
this.scale = decayFunc.computeScale(userGivenScale, decay);
this.scale = decayType.computeScale(userGivenScale, decay);
this.indexState = indexState;
if (!decayFunction.hasGeoPoint()) {
throw new IllegalArgumentException("Decay Function should have a geoPoint for Origin field");
} else {
this.origin = decayFunction.getGeoPoint();
}
LatLonFieldDef latLonFieldDef = (LatLonFieldDef) indexState.getField(fieldName);
// TODO: Add support for multi-value fields
if (latLonFieldDef.isMultiValue()) {
throw new IllegalArgumentException(
"Multivalue fields are not supported for decay functions yet");
}
}

@Override
public LeafFunction getLeafFunction(LeafReaderContext leafContext) throws IOException {
public LeafFunction getLeafFunction(LeafReaderContext leafContext) {
return new GeoPointDecayLeafFunction(leafContext);
}

public final class GeoPointDecayLeafFunction implements LeafFunction {

private final LoadedDocValues.SingleLocation geoPointValue;
SegmentDocLookup segmentDocLookup;
LatLng latLng;

public GeoPointDecayLeafFunction(LeafReaderContext context) throws IOException {
this.geoPointValue =
new LoadedDocValues.SingleLocation(context.reader().getSortedNumericDocValues(fieldName));
public GeoPointDecayLeafFunction(LeafReaderContext context) {
segmentDocLookup = indexState.docLookup.getSegmentLookup(context);
}

@Override
public double score(int docId, float innerQueryScore) throws IOException {
GeoPoint geoPoint = getGeoPoint(docId);
segmentDocLookup.setDocId(docId);
LoadedDocValues<?> latLonValues = segmentDocLookup.get(fieldName);
this.latLng = latLonValues.toFieldValue(0).getLatLngValue();
double distance =
GeoUtils.arcDistance(
origin.getLatitude(), origin.getLongitude(), geoPoint.getLat(), geoPoint.getLon());
double score = decayFunc.computeScore(distance, scale);
return Math.max(0.0, score - offset);
}

public GeoPoint getGeoPoint(int docId) throws IOException {
geoPointValue.setDocId(docId);
return geoPointValue.getValue();
origin.getLatitude(),
origin.getLongitude(),
latLng.getLatitude(),
latLng.getLongitude());
double score = decayType.computeScore(distance, offset, scale);
return score * getWeight();
}

@Override
public Explanation explainScore(int docId, Explanation innerQueryScore) throws IOException {
GeoPoint geoPoint = getGeoPoint(docId);
public Explanation explainScore(int docId, Explanation innerQueryScore) {
double distance =
GeoUtils.arcDistance(
origin.getLatitude(), origin.getLongitude(), geoPoint.getLat(), geoPoint.getLon());
double score = Math.max(0.0, decayFunc.computeScore(distance, scale) - offset);
Explanation paramsExp =
origin.getLatitude(),
origin.getLongitude(),
latLng.getLatitude(),
latLng.getLongitude());
double score = decayType.computeScore(distance, offset, scale);
Explanation distanceExp =
Explanation.match(distance, "arc distance calculated between two geoPoints");
return Explanation.match(
score,
"final score with the provided decay function calculated by max(0.0, score - offset) with "
+ offset
+ " offset value",
List.of(
paramsExp, decayFunc.explainComputeScore(String.valueOf(distance), distance, scale)));
"final score with the provided decay function calculated by score * weight with "
+ getWeight()
+ " weight value",
List.of(distanceExp, decayType.explainComputeScore(distance, offset, scale)));
}
}

Expand All @@ -115,7 +130,7 @@ public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString()).append(", decayFunction:");
sb.append("fieldName: ").append(fieldName);
sb.append("decayFunc: ").append(decayFunc);
sb.append("decayType: ").append(decayType);
sb.append("origin: ").append(origin);
sb.append("scale: ").append(scale);
sb.append("offset: ").append(offset);
Expand All @@ -127,7 +142,8 @@ public String toString() {
protected FilterFunction doRewrite(
IndexReader reader, boolean filterQueryRewritten, Query rewrittenFilterQuery) {
if (filterQueryRewritten) {
return new GeoPointDecayFilterFunction(rewrittenFilterQuery, getWeight(), decayFunction);
return new GeoPointDecayFilterFunction(
rewrittenFilterQuery, getWeight(), decayFunction, indexState);
} else {
return this;
}
Expand All @@ -143,16 +159,16 @@ protected boolean doEquals(FilterFunction other) {
}
GeoPointDecayFilterFunction otherGeoPointDecayFilterFunction =
(GeoPointDecayFilterFunction) other;
return Objects.equals(origin, otherGeoPointDecayFilterFunction.origin)
&& decayFunc.equals(otherGeoPointDecayFilterFunction.decayFunc)
&& Objects.equals(fieldName, otherGeoPointDecayFilterFunction.fieldName)
return Objects.equals(fieldName, otherGeoPointDecayFilterFunction.fieldName)
&& Objects.equals(decayType, otherGeoPointDecayFilterFunction.decayType)
&& Objects.equals(origin, otherGeoPointDecayFilterFunction.origin)
&& Double.compare(scale, otherGeoPointDecayFilterFunction.scale) == 0
&& Double.compare(offset, otherGeoPointDecayFilterFunction.offset) == 0
&& Double.compare(decay, otherGeoPointDecayFilterFunction.decay) == 0;
}

@Override
protected int doHashCode() {
return Objects.hash(fieldName, decayFunc, scale, offset, decay, origin);
return Objects.hash(fieldName, decayType, origin, scale, offset, decay);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@

public class GuassDecayFunction implements DecayFunction {
@Override
public double computeScore(double distance, double scale) {
return Math.exp(0.5 * Math.pow(distance, 2.0) / scale);
public double computeScore(double distance, double offset, double scale) {
return Math.exp((-1.0 * Math.pow(Math.max(0.0, distance - offset), 2.0)) / 2.0 * scale);
}

@Override
public double computeScale(double scale, double decay) {
return 0.5 * Math.pow(scale, 2.0) / Math.log(decay);
return (-1.0 * Math.pow(scale, 2.0)) / (2.0 * Math.log(decay));
}

@Override
public Explanation explainComputeScore(String distanceString, double distance, double scale) {
public Explanation explainComputeScore(double distance, double offset, double scale) {
return Explanation.match(
(float) computeScore(distance, scale),
"exp(0.5 * pow(" + distanceString + ", 2.0) / " + scale + ")");
(float) computeScore(distance, offset, scale),
"exp(- pow(max(0.0, |" + distance + " - " + offset + "), 2.0)/ 2.0 * " + scale);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

public class LinearDecayFunction implements DecayFunction {
@Override
public double computeScore(double distance, double scale) {
return Math.max(0.0, (scale - distance) / scale);
public double computeScore(double distance, double offset, double scale) {
return Math.max(0.0, (scale - Math.max(0.0, distance - offset)) / scale);
}

@Override
Expand All @@ -29,9 +29,9 @@ public double computeScale(double scale, double decay) {
}

@Override
public Explanation explainComputeScore(String distanceString, double distance, double scale) {
public Explanation explainComputeScore(double distance, double offset, double scale) {
return Explanation.match(
(float) computeScore(distance, scale),
"max(0.0, (" + scale + " - " + distanceString + ") / " + scale + ")");
(float) computeScore(distance, offset, scale),
"max(0.0, (" + scale + " - max(0.0, " + distance + " - " + offset + ")) / " + scale + ")");
}
}
Loading

0 comments on commit 25516c6

Please sign in to comment.