Skip to content

Commit

Permalink
Add decay function support for MultiFunctionScoreQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
swethakann committed Mar 28, 2024
1 parent 915da1a commit 74f238c
Show file tree
Hide file tree
Showing 11 changed files with 431 additions and 1 deletion.
14 changes: 14 additions & 0 deletions clientlib/src/main/proto/yelp/nrtsearch/search.proto
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,22 @@ message MultiFunctionScoreQuery {
// Produce score with score script definition
Script script = 3;
}
DecayFunction decayFunction = 4;
}

// Parameters of a decay function: the score decays as the distance between a document field
// value and the given origin grows.
message DecayFunction {
// Name of the document field whose value is compared against the origin
string fieldName = 1;
// Shape of the decay curve; the server accepts "guass", "exp" or "linear"
string decayType = 2;
// Point from which the distance is measured
oneof Origin {
// Numeric origin. NOTE(review): only the geoPoint origin is dispatched in the server code
// visible in this change - confirm whether numeric origins are handled elsewhere.
int32 numeric = 3;
// Geo point origin, used with geo point fields
google.type.LatLng geoPoint = 4;
}
// Distance from the origin used to derive the decay rate. For geo point origins the server
// parses this string as a distance (presumably a number with a unit - confirm accepted format).
string scale = 5;
// Offset distance, parsed by the server the same way as scale for geo point origins
string offset = 6;
// Decay value; the server rejects values outside the open interval (0, 1)
float decay = 7;
}


// How to combine multiple function scores to produce a final function score
enum FunctionScoreMode {
// Multiply weighted function scores together
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package com.yelp.nrtsearch.server.luceneserver.geo;

import org.apache.lucene.util.SloppyMath;

public class GeoUtils {

private static final double KM_TO_M = 1000.0;
Expand Down Expand Up @@ -64,4 +66,12 @@ public static double convertDistanceToADifferentUnit(double distanceNumber, Stri
throw new IllegalArgumentException("Invalid unit " + unit);
}
}

/**
 * Compute the haversine distance between two lat,lon geo points by delegating to lucene's
 * {@link SloppyMath} implementation.
 *
 * @param lat1 latitude of the first point
 * @param lon1 longitude of the first point
 * @param lat2 latitude of the second point
 * @param lon2 longitude of the second point
 * @return distance between the two points, in meters
 */
public static double arcDistance(double lat1, double lon1, double lat2, double lon2) {
  final double distanceInMeters = SloppyMath.haversinMeters(lat1, lon1, lat2, lon2);
  return distanceInMeters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import com.yelp.nrtsearch.server.grpc.MultiFunctionScoreQuery;
import org.apache.lucene.search.Query;

/**
 * Base class for filter functions that score a document with a function that decays depending on
 * the distance between an origin and a doc field value. Validates the decay rate and resolves the
 * decay type identifier to a {@link DecayFunction} implementation.
 */
public abstract class DecayFilterFunction extends FilterFunction {
  private static final String EXP = "exp";
  // Correct spelling, accepted as an alias of the historical identifier below.
  private static final String GAUSS = "gauss";
  // Historical (misspelled) identifier; kept so existing requests continue to work.
  private static final String GUASS = "guass";
  private static final String LINEAR = "linear";

  /**
   * Constructor.
   *
   * @param filterQuery filter to use when applying this function, or null if none
   * @param weight weight multiple to scale the function score
   * @param decayFunction to score a document with a function that decays depending on the distance
   *     between an origin point and a numeric doc field value
   * @throws IllegalArgumentException if the decay rate is not strictly between 0 and 1
   */
  public DecayFilterFunction(
      Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) {
    super(filterQuery, weight);
    float decayRate = decayFunction.getDecay();
    // Negated in-range check: unlike "<= 0 || >= 1", this also rejects NaN.
    if (!(decayRate > 0 && decayRate < 1)) {
      throw new IllegalArgumentException("decay rate should be between (0, 1)");
    }
  }

  /**
   * Resolve a decay type identifier to its decay function implementation.
   *
   * @param decayType one of "gauss" (or the legacy spelling "guass"), "exp" or "linear"
   * @return new decay function instance for the given type
   * @throws IllegalArgumentException if the decay type is not supported
   */
  protected DecayFunction getDecayFunction(String decayType) {
    switch (decayType) {
      case GAUSS:
      case GUASS:
        return new GuassDecayFunction();
      case EXP:
        return new ExponentialDecayFunction();
      case LINEAR:
        return new LinearDecayFunction();
      default:
        throw new IllegalArgumentException(
            decayType
                + " decay function type is not supported. Needs to be gauss (or guass), exp or linear");
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import org.apache.lucene.search.Explanation;

/**
 * Strategy for a decay curve. The geo point filter function in this package first converts the
 * user supplied scale with {@link #computeScale(double, double)} and then passes the result as the
 * scale argument of {@link #computeScore(double, double)}.
 */
public interface DecayFunction {
/**
 * Compute the decay score for a distance.
 *
 * @param distance distance between the origin and the document value
 * @param scale scale value, as produced by {@link #computeScale(double, double)}
 * @return decay score
 */
double computeScore(double distance, double scale);

/**
 * Convert the user supplied scale and decay into the scale value used by {@link
 * #computeScore(double, double)}.
 *
 * @param scale user supplied scale
 * @param decay user supplied decay
 * @return scale value to pass to computeScore
 */
double computeScale(double scale, double decay);

/**
 * Build an explanation of the score computation. In the implementations visible alongside this
 * interface, the explanation value is the result of computeScore and the description is the
 * formula text with distanceString in place of the distance.
 *
 * @param distanceString text used to represent the distance in the description
 * @param distance distance between the origin and the document value
 * @param scale scale value, as produced by {@link #computeScale(double, double)}
 * @return explanation of the computed score
 */
Explanation explainComputeScore(String distanceString, double distance, double scale);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import org.apache.lucene.search.Explanation;

/**
 * Exponential decay: {@code score = exp(distance * scale)}, where scale is the value produced by
 * {@link #computeScale(double, double)}, i.e. {@code ln(decay) / userScale}.
 *
 * <p>The class holds no state, so all instances are interchangeable. Value based {@code equals},
 * {@code hashCode} and {@code toString} are provided so that objects holding a decay function can
 * be compared and printed meaningfully.
 */
public class ExponentialDecayFunction implements DecayFunction {
  /**
   * @param distance distance between the origin and the document value
   * @param scale value produced by {@link #computeScale(double, double)}
   * @return {@code exp(distance * scale)}
   */
  @Override
  public double computeScore(double distance, double scale) {
    return Math.exp(distance * scale);
  }

  /**
   * @param scale user supplied scale
   * @param decay user supplied decay
   * @return {@code ln(decay) / scale}
   */
  @Override
  public double computeScale(double scale, double decay) {
    return Math.log(decay) / scale;
  }

  /**
   * @param distanceString text used to represent the distance in the description
   * @param distance distance between the origin and the document value
   * @param scale value produced by {@link #computeScale(double, double)}
   * @return explanation holding the computed score and the formula text
   */
  @Override
  public Explanation explainComputeScore(String distanceString, double distance, double scale) {
    return Explanation.match(
        (float) computeScore(distance, scale), "exp(" + distanceString + " * " + scale + ")");
  }

  /** Stateless function: any two instances of the same class are equal. */
  @Override
  public boolean equals(Object obj) {
    return obj != null && obj.getClass() == getClass();
  }

  @Override
  public int hashCode() {
    return getClass().getName().hashCode();
  }

  @Override
  public String toString() {
    return "exp";
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ public static FilterFunction build(
? QueryNodeMapper.getInstance().getQuery(filterFunctionGrpc.getFilter(), indexState)
: null;
float weight = filterFunctionGrpc.getWeight() != 0.0f ? filterFunctionGrpc.getWeight() : 1.0f;
if (filterFunctionGrpc.hasDecayFunction()) {
MultiFunctionScoreQuery.DecayFunction decayFunction = filterFunctionGrpc.getDecayFunction();
if (decayFunction.hasGeoPoint()) {
return new GeoPointDecayFilterFunction(filterQuery, weight, decayFunction);
}
}
switch (filterFunctionGrpc.getFunctionCase()) {
case SCRIPT:
ScoreScript.Factory factory =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import com.google.type.LatLng;
import com.yelp.nrtsearch.server.grpc.MultiFunctionScoreQuery;
import com.yelp.nrtsearch.server.luceneserver.doc.LoadedDocValues;
import com.yelp.nrtsearch.server.luceneserver.geo.GeoPoint;
import com.yelp.nrtsearch.server.luceneserver.geo.GeoUtils;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.index.*;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Query;

/**
 * Decay filter function for geo point fields. The score of a document decays with the arc
 * distance (in meters) between the origin geo point and the geo point stored in the document
 * field. The offset is subtracted from the distance (clamped at 0) before the decay function is
 * evaluated.
 */
public class GeoPointDecayFilterFunction extends DecayFilterFunction {

  private final MultiFunctionScoreQuery.DecayFunction decayFunction;
  private final String fieldName;
  private final DecayFunction decayFunc;
  private final double scale;
  private final double offset;
  private final double decay;
  private final LatLng origin;

  /**
   * Constructor.
   *
   * @param filterQuery filter to use when applying this function, or null if none
   * @param weight weight multiple to scale the function score
   * @param decayFunction to score a document with a function that decays depending on the distance
   *     between an origin point and a geo point doc field value
   * @throws IllegalArgumentException if the origin is not a geo point, the decay type is not
   *     supported, or the scale is not greater than 0
   */
  public GeoPointDecayFilterFunction(
      Query filterQuery, float weight, MultiFunctionScoreQuery.DecayFunction decayFunction) {
    super(filterQuery, weight, decayFunction);
    // Fail fast on a missing origin before parsing any of the other parameters.
    if (!decayFunction.hasGeoPoint()) {
      throw new IllegalArgumentException("Decay Function should have a geoPoint for Origin field");
    }
    this.decayFunction = decayFunction;
    this.fieldName = decayFunction.getFieldName();
    this.decayFunc = getDecayFunction(decayFunction.getDecayType());
    this.origin = decayFunction.getGeoPoint();
    this.offset = GeoUtils.getDistance(decayFunction.getOffset());
    this.decay = decayFunction.getDecay();
    double userGivenScale = GeoUtils.getDistance(decayFunction.getScale());
    // A zero, negative or NaN scale would yield an infinite or NaN internal scale and NaN scores.
    if (!(userGivenScale > 0.0)) {
      throw new IllegalArgumentException(
          "scale must be greater than 0, but was: " + decayFunction.getScale());
    }
    this.scale = decayFunc.computeScale(userGivenScale, decay);
  }

  @Override
  public LeafFunction getLeafFunction(LeafReaderContext leafContext) throws IOException {
    return new GeoPointDecayLeafFunction(leafContext);
  }

  /** Leaf function that scores documents of a single segment. */
  public final class GeoPointDecayLeafFunction implements LeafFunction {

    private final LoadedDocValues.SingleLocation geoPointValue;

    public GeoPointDecayLeafFunction(LeafReaderContext context) throws IOException {
      this.geoPointValue =
          new LoadedDocValues.SingleLocation(context.reader().getSortedNumericDocValues(fieldName));
    }

    /**
     * Compute the decay score for a document.
     *
     * @param docId segment document id
     * @param innerQueryScore score of the inner query, not used by this function
     * @return decay score, never negative
     */
    @Override
    public double score(int docId, float innerQueryScore) throws IOException {
      return decayScore(applyOffset(distanceFromOrigin(docId)));
    }

    /** Load the geo point field value of the given document. */
    public GeoPoint getGeoPoint(int docId) throws IOException {
      geoPointValue.setDocId(docId);
      return geoPointValue.getValue();
    }

    @Override
    public Explanation explainScore(int docId, Explanation innerQueryScore) throws IOException {
      double distance = distanceFromOrigin(docId);
      double adjustedDistance = applyOffset(distance);
      double score = decayScore(adjustedDistance);
      Explanation distanceExp =
          Explanation.match(distance, "arc distance calculated between two geoPoints");
      String adjustedDistanceString = "max(0.0, " + distance + " - " + offset + ")";
      return Explanation.match(
          score,
          "final score with the provided decay function calculated by "
              + "max(0.0, decayFunction(max(0.0, distance - offset))) with "
              + offset
              + " offset value",
          List.of(
              distanceExp,
              decayFunc.explainComputeScore(adjustedDistanceString, adjustedDistance, scale)));
    }

    /** Arc distance in meters between the origin and the document's geo point. */
    private double distanceFromOrigin(int docId) throws IOException {
      GeoPoint geoPoint = getGeoPoint(docId);
      return GeoUtils.arcDistance(
          origin.getLatitude(), origin.getLongitude(), geoPoint.getLat(), geoPoint.getLon());
    }

    /** Documents within the offset distance are treated as being at distance 0. */
    private double applyOffset(double distance) {
      return Math.max(0.0, distance - offset);
    }

    /** Evaluate the decay function, keeping the result non-negative. */
    private double decayScore(double adjustedDistance) {
      return Math.max(0.0, decayFunc.computeScore(adjustedDistance, scale));
    }
  }

  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    sb.append(super.toString()).append(", decayFunction:");
    sb.append("fieldName: ").append(fieldName);
    sb.append(", decayFunc: ").append(decayFunc.getClass().getSimpleName());
    sb.append(", origin: ").append(origin);
    sb.append(", scale: ").append(scale);
    sb.append(", offset: ").append(offset);
    sb.append(", decay: ").append(decay);
    return sb.toString();
  }

  @Override
  protected FilterFunction doRewrite(
      IndexReader reader, boolean filterQueryRewritten, Query rewrittenFilterQuery) {
    if (filterQueryRewritten) {
      return new GeoPointDecayFilterFunction(rewrittenFilterQuery, getWeight(), decayFunction);
    } else {
      return this;
    }
  }

  @Override
  protected boolean doEquals(FilterFunction other) {
    if (other == null) {
      return false;
    }
    if (other.getClass() != this.getClass()) {
      return false;
    }
    GeoPointDecayFilterFunction that = (GeoPointDecayFilterFunction) other;
    // Decay functions are stateless, so compare them by type rather than by instance; this keeps
    // two functions built from identical requests equal.
    return Objects.equals(origin, that.origin)
        && decayFunc.getClass() == that.decayFunc.getClass()
        && Objects.equals(fieldName, that.fieldName)
        && Double.compare(scale, that.scale) == 0
        && Double.compare(offset, that.offset) == 0
        && Double.compare(decay, that.decay) == 0;
  }

  @Override
  protected int doHashCode() {
    // Hash the decay function by type name to stay consistent with doEquals.
    return Objects.hash(fieldName, decayFunc.getClass().getName(), scale, offset, decay, origin);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright 2024 Yelp Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.yelp.nrtsearch.server.luceneserver.search.query.multifunction;

import org.apache.lucene.search.Explanation;

/**
 * Gaussian decay: {@code score = exp(0.5 * distance^2 / scale)}, where scale is the value
 * produced by {@link #computeScale(double, double)}, i.e. {@code 0.5 * userScale^2 / ln(decay)}.
 *
 * <p>The class holds no state, so all instances are interchangeable. Value based {@code equals},
 * {@code hashCode} and {@code toString} are provided so that objects holding a decay function can
 * be compared and printed meaningfully.
 */
public class GuassDecayFunction implements DecayFunction {
  /**
   * @param distance distance between the origin and the document value
   * @param scale value produced by {@link #computeScale(double, double)}
   * @return {@code exp(0.5 * distance^2 / scale)}
   */
  @Override
  public double computeScore(double distance, double scale) {
    return Math.exp(0.5 * Math.pow(distance, 2.0) / scale);
  }

  /**
   * @param scale user supplied scale
   * @param decay user supplied decay
   * @return {@code 0.5 * scale^2 / ln(decay)}
   */
  @Override
  public double computeScale(double scale, double decay) {
    return 0.5 * Math.pow(scale, 2.0) / Math.log(decay);
  }

  /**
   * @param distanceString text used to represent the distance in the description
   * @param distance distance between the origin and the document value
   * @param scale value produced by {@link #computeScale(double, double)}
   * @return explanation holding the computed score and the formula text
   */
  @Override
  public Explanation explainComputeScore(String distanceString, double distance, double scale) {
    return Explanation.match(
        (float) computeScore(distance, scale),
        "exp(0.5 * pow(" + distanceString + ", 2.0) / " + scale + ")");
  }

  /** Stateless function: any two instances of the same class are equal. */
  @Override
  public boolean equals(Object obj) {
    return obj != null && obj.getClass() == getClass();
  }

  @Override
  public int hashCode() {
    return getClass().getName().hashCode();
  }

  @Override
  public String toString() {
    return "gauss";
  }
}
Loading

0 comments on commit 74f238c

Please sign in to comment.