Skip to content

Commit

Permalink
PR changes and some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
swethakann committed Apr 2, 2024
1 parent d59c35b commit 33fd26b
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 79 deletions.
19 changes: 9 additions & 10 deletions clientlib/src/main/proto/yelp/nrtsearch/search.proto
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,15 @@ message MultiFunctionScoreQuery {
DecayType decayType = 2;
// Origin point to calculate the distance
oneof Origin {
int32 numeric = 3;
google.type.LatLng geoPoint = 4;
google.type.LatLng geoPoint = 3;
}
// Distance from origin + offset at which computed score will be equal to decay
string scale = 5;
// Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set
string offset = 6;
// Defines how documents are scored at the distance
float decay = 7;
// Currently only distance based scale and offset units are supported
// Distance from origin + offset at which computed score will be equal to decay. Scale should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi"
string scale = 4;
// Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set. Offset should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi"
string offset = 5;
// Defines decay rate for scoring. Should be between (0, 1)
float decay = 6;
}

enum DecayType {
Expand All @@ -345,10 +345,9 @@ message MultiFunctionScoreQuery {
// Linear decay function
DECAY_TYPE_LINEAR = 1;
// Gaussian decay function
DECAY_TYPE_GUASS = 2;
DECAY_TYPE_GUASSIAN = 2;
}


// How to combine multiple function scores to produce a final function score
enum FunctionScoreMode {
// Multiply weighted function scores together
Expand Down
18 changes: 9 additions & 9 deletions docs/queries/multi_function_score.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ Proto definition:
DecayType decayType = 2;
// Origin point to calculate the distance
oneof Origin {
int32 numeric = 3;
google.type.LatLng geoPoint = 4;
google.type.LatLng geoPoint = 3;
}
// Distance from origin + offset at which computed score will be equal to decay
string scale = 5;
// Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set
string offset = 6;
// Defines how documents are scored at the distance
float decay = 7;
// Currently only distance based scale and offset units are supported
// Distance from origin + offset at which computed score will be equal to decay. Scale should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi"
string scale = 4;
// Compute decay function for docs with a distance greater than offset, will be 0.0 if none is set. Offset should be distance, unit (m, km, mi) with space is optional. Default unit will be meters. Ex: "10", "15 km", "5 m", "7 mi"
string offset = 5;
// Defines decay rate for scoring. Should be between (0, 1)
float decay = 6;
}
enum DecayType {
Expand All @@ -52,7 +52,7 @@ Proto definition:
// Linear decay function
DECAY_TYPE_LINEAR = 1;
// Gaussian decay function
DECAY_TYPE_GUASS = 2;
DECAY_TYPE_GUASSIAN = 2;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,19 @@ public DecayFilterFunction(
Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) {
super(filterQuery, weight);
if (decayFunction.getDecay() <= 0 || decayFunction.getDecay() >= 1) {
throw new IllegalArgumentException("decay rate should be between (0, 1)");
throw new IllegalArgumentException(
"decay rate should be between (0, 1) but is " + decayFunction.getDecay());
}
}

protected DecayFunction getDecayType(MultiFunctionScoreQuery.DecayType decayType) {
switch (decayType) {
case DECAY_TYPE_GUASS:
return new GuassDecayFunction();
case DECAY_TYPE_EXPONENTIAL:
return new ExponentialDecayFunction();
case DECAY_TYPE_LINEAR:
return new LinearDecayFunction();
default:
return null;
}
return switch (decayType) {
case DECAY_TYPE_GUASSIAN -> new GuassianDecayFunction();
case DECAY_TYPE_EXPONENTIAL -> new ExponentialDecayFunction();
case DECAY_TYPE_LINEAR -> new LinearDecayFunction();
default -> throw new IllegalArgumentException(
decayType
+ " not supported. Only exponential, guassian and linear decay functions are supported");
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,35 @@
import org.apache.lucene.search.Explanation;

/**
 * Contract for a decay curve that turns a distance from an origin into a score.
 *
 * <p>Intended call sequence: the caller first converts the user supplied scale and decay rate
 * into an adjusted scale with {@link #computeScale(double, double)}, and then passes that adjusted
 * scale to {@link #computeScore(double, double, double)} (and to {@link
 * #explainComputeScore(double, double, double)} when an explanation is requested).
 *
 * <p>NOTE(review): distance, offset and scale are presumably expressed in the same unit; this
 * interface does not enforce it, so confirm against the caller.
 */
public interface DecayFunction {
/**
 * Computes the decayed score based on the provided distance, offset, and scale.
 *
 * @param distance the distance from a given origin point.
 * @param offset the point after which the decay starts.
 * @param scale scale factor that influences the rate of decay. This must be the adjusted value
 *     returned by {@link #computeScale(double, double)}, not the raw user given scale.
 * @return the decayed score after applying the decay function
 */
double computeScore(double distance, double offset, double scale);

/**
 * Computes the adjusted scale based on a user given scale and decay rate.
 *
 * @param scale user given scale.
 * @param decay decay rate that decides how the score decreases.
 * @return adjusted scale which will be used by {@link #computeScore(double, double, double)} and
 *     {@link #explainComputeScore(double, double, double)}.
 */
double computeScale(double scale, double decay);

/**
 * Provides an explanation for the computed score based on the given distance, offset, and scale.
 * Takes the same arguments as {@link #computeScore(double, double, double)}.
 *
 * @param distance the distance from a given origin point.
 * @param offset the point after which the decay starts.
 * @param scale scale factor that influences the rate of decay. This must be the adjusted value
 *     returned by {@link #computeScale(double, double)}, not the raw user given scale.
 * @return Explanation object that details the calculations involved in computing the decayed
 *     score.
 */
Explanation explainComputeScore(double distance, double offset, double scale);
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,6 @@ public static FilterFunction build(
? QueryNodeMapper.getInstance().getQuery(filterFunctionGrpc.getFilter(), indexState)
: null;
float weight = filterFunctionGrpc.getWeight() != 0.0f ? filterFunctionGrpc.getWeight() : 1.0f;
if (filterFunctionGrpc.hasDecayFunction()) {
MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction();
if (decayFunction.hasGeoPoint()) {
return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState);
}
}
switch (filterFunctionGrpc.getFunctionCase()) {
case SCRIPT:
ScoreScript.Factory factory =
Expand All @@ -94,6 +88,14 @@ public static FilterFunction build(
indexState.docLookup);
return new ScriptFilterFunction(
filterQuery, weight, filterFunctionGrpc.getScript(), scriptSource);
case DECAYFUNCTION:
MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction();
if (decayFunction.hasGeoPoint()) {
return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction, indexState);
} else {
throw new IllegalArgumentException(
"Decay Function should contain a geoPoint for Origin field");
}
case FUNCTION_NOT_SET:
return new WeightFilterFunction(filterQuery, weight);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.yelp.nrtsearch.server.luceneserver.IndexState;
import com.yelp.nrtsearch.server.luceneserver.doc.LoadedDocValues;
import com.yelp.nrtsearch.server.luceneserver.doc.SegmentDocLookup;
import com.yelp.nrtsearch.server.luceneserver.field.FieldDef;
import com.yelp.nrtsearch.server.luceneserver.field.LatLonFieldDef;
import com.yelp.nrtsearch.server.luceneserver.geo.GeoPoint;
import com.yelp.nrtsearch.server.luceneserver.geo.GeoUtils;
Expand All @@ -31,6 +32,7 @@
import org.apache.lucene.search.Query;

public class GeoPointDecayFilterFunction extends DecayFilterFunction {
private static final String LAT_LON = "LAT_LON";

private final MultiFunctionScoreQuery.DecayFunction decayFunction;
private final String fieldName;
Expand All @@ -47,7 +49,7 @@ public class GeoPointDecayFilterFunction extends DecayFilterFunction {
* @param filterQuery filter to use when applying this function, or null if none
* @param weight weight multiple to scale the function score
* @param decayFunction to score a document with a function that decays depending on the distance
* between an origin point and a numeric doc field value
* between an origin point and a geoPoint doc field value
* @param indexState indexState for validation and doc value lookup
*/
public GeoPointDecayFilterFunction(
Expand All @@ -59,24 +61,33 @@ public GeoPointDecayFilterFunction(
this.decayFunction = decayFunction;
this.fieldName = decayFunction.getFieldName();
this.decayType = getDecayType(decayFunction.getDecayType());
this.origin = decayFunction.getGeoPoint();
this.decay = decayFunction.getDecay();
double userGivenScale = GeoUtils.getDistance(decayFunction.getScale());
this.scale = decayType.computeScale(userGivenScale, decay);
this.offset =
!decayFunction.getOffset().isEmpty()
? GeoUtils.getDistance(decayFunction.getOffset())
: 0.0;
this.decay = decayFunction.getDecay();
double userGivenScale = GeoUtils.getDistance(decayFunction.getScale());
this.scale = decayType.computeScale(userGivenScale, decay);
this.indexState = indexState;
if (!decayFunction.hasGeoPoint()) {
throw new IllegalArgumentException("Decay Function should have a geoPoint for Origin field");
} else {
this.origin = decayFunction.getGeoPoint();
validateLatLonField(indexState.getField(fieldName));
}

public void validateLatLonField(FieldDef fieldDef) {
if (!LAT_LON.equals(fieldDef.getType())) {
throw new IllegalArgumentException(
fieldName
+ " should be a LAT_LON to apply geoPoint decay function but it is: "
+ fieldDef.getType());
}
LatLonFieldDef latLonFieldDef = (LatLonFieldDef) indexState.getField(fieldName);
LatLonFieldDef latLonFieldDef = (LatLonFieldDef) fieldDef;
// TODO: Add support for multi-value fields
if (latLonFieldDef.isMultiValue()) {
throw new IllegalArgumentException(
"Multivalue fields are not supported for decay functions yet");
"Multivalued fields are not supported for decay functions yet");
}
if (!latLonFieldDef.hasDocValues()) {
throw new IllegalStateException("No doc values present for LAT_LON field: " + fieldName);
}
}

Expand All @@ -88,7 +99,6 @@ public LeafFunction getLeafFunction(LeafReaderContext leafContext) {
public final class GeoPointDecayLeafFunction implements LeafFunction {

SegmentDocLookup segmentDocLookup;
LatLng latLng;

public GeoPointDecayLeafFunction(LeafReaderContext context) {
segmentDocLookup = indexState.docLookup.getSegmentLookup(context);
Expand All @@ -97,36 +107,53 @@ public GeoPointDecayLeafFunction(LeafReaderContext context) {
@Override
public double score(int docId, float innerQueryScore) throws IOException {
segmentDocLookup.setDocId(docId);
LoadedDocValues<GeoPoint> latLonValues =
(LoadedDocValues<GeoPoint>) segmentDocLookup.get(fieldName);
this.latLng = latLonValues.toFieldValue(0).getLatLngValue();
double distance =
GeoUtils.arcDistance(
origin.getLatitude(),
origin.getLongitude(),
latLng.getLatitude(),
latLng.getLongitude());
double score = decayType.computeScore(distance, offset, scale);
return score * getWeight();
if (!validateDocValuesPresent()) {
return 0.0;
} else {
LoadedDocValues<GeoPoint> geoPointLoadedDocValues =
(LoadedDocValues<GeoPoint>) segmentDocLookup.get(fieldName);
GeoPoint latLng = geoPointLoadedDocValues.get(0);
double distance =
GeoUtils.arcDistance(
origin.getLatitude(), origin.getLongitude(), latLng.getLat(), latLng.getLon());
double score = decayType.computeScore(distance, offset, scale);
return score * getWeight();
}
}

/**
 * Reports whether the current document has at least one value for the decay field.
 *
 * @return {@code true} when the segment doc lookup contains the field and the values it returns
 *     for that field are not empty; {@code false} otherwise
 */
public boolean validateDocValuesPresent() {
  // Guard first: only fetch the field's values when the lookup knows the field.
  if (!segmentDocLookup.containsKey(fieldName)) {
    return false;
  }
  boolean noValues = segmentDocLookup.get(fieldName).isEmpty();
  return !noValues;
}

@Override
public Explanation explainScore(int docId, Explanation innerQueryScore) {
double distance =
GeoUtils.arcDistance(
origin.getLatitude(),
origin.getLongitude(),
latLng.getLatitude(),
latLng.getLongitude());
double score = decayType.computeScore(distance, offset, scale);
Explanation distanceExp =
Explanation.match(distance, "arc distance calculated between two geoPoints");
return Explanation.match(
score,
"final score with the provided decay function calculated by score * weight with "
+ getWeight()
+ " weight value",
List.of(distanceExp, decayType.explainComputeScore(distance, offset, scale)));
double score;
if (validateDocValuesPresent()) {
LoadedDocValues<GeoPoint> geoPointLoadedDocValues =
(LoadedDocValues<GeoPoint>) segmentDocLookup.get(fieldName);
GeoPoint latLng = geoPointLoadedDocValues.get(0);
double distance =
GeoUtils.arcDistance(
origin.getLatitude(), origin.getLongitude(), latLng.getLat(), latLng.getLon());

Explanation distanceExp =
Explanation.match(distance, "arc distance calculated between two geoPoints");

score = decayType.computeScore(distance, offset, scale);
double finalScore = score * getWeight();
return Explanation.match(
finalScore,
"final score with the provided decay function calculated by score * weight with "
+ getWeight()
+ " weight value and "
+ score
+ "score",
List.of(distanceExp, decayType.explainComputeScore(distance, offset, scale)));
} else {
score = 0.0;
return Explanation.match(
score, "score is 0.0 since no doc values were present for " + fieldName);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import org.apache.lucene.search.Explanation;

public class GuassDecayFunction implements DecayFunction {
public class GuassianDecayFunction implements DecayFunction {
@Override
public double computeScore(double distance, double offset, double scale) {
return Math.exp((-1.0 * Math.pow(Math.max(0.0, distance - offset), 2.0)) / 2.0 * scale);
Expand Down
Loading

0 comments on commit 33fd26b

Please sign in to comment.