Skip to content

Commit

Permalink
Adds parsing of numerical string values
Browse files Browse the repository at this point in the history
Signed-off-by: Brian Flores <[email protected]>
  • Loading branch information
brianf-aws committed Dec 27, 2024
1 parent 1bcd7d1 commit 06a25a5
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getScoreFromSourceMap;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.isNumeric;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.removeTargetFieldFromSource;
import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria;
Expand Down Expand Up @@ -113,7 +114,6 @@ public void rescoreSearchResponse(
final ActionListener<List<Float>> listener
) {
SearchHit[] searchHits = response.getHits().getHits();

SearchHitValidator searchHitValidator = this::byFieldSearchHitValidator;

if (!validateRerankCriteria(searchHits, searchHitValidator, listener)) {
Expand Down Expand Up @@ -177,9 +177,9 @@ public void byFieldSearchHitValidator(final SearchHit hit) {

Optional<Object> val = getValueFromSource(sourceMap, targetField);

if (!(val.get() instanceof Number)) {
if (!(isNumeric(val.get()))) {
// Strictly get the type of value removing the prefix of getClass() having a value is guaranteed so no NPE check
String typeOfMapping = val.get().getClass().toString().replace("class ", "");
String typeOfMapping = val.get().getClass().getSimpleName();
log.error(
String.format(
Locale.ROOT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public interface SearchHitValidator {
* for each SearchHit follows the correct form as specified by the validator.
* When just one of the conditions fail (as specified by the validator) the exception will be thrown to the listener.
* @param searchHits from the SearchResponse
* @param validator The given validator used to check every search hit being correct
* @param listener returns an error to the listener in case on of the conditions fail
* @return The status indicating that the SearchHits are in correct form to perform the Rerank
*/
Expand Down Expand Up @@ -77,6 +78,9 @@ public static boolean validateRerankCriteria(
*/
public static float getScoreFromSourceMap(final Map<String, Object> sourceAsMap, final String targetField) {
Object val = getValueFromSource(sourceAsMap, targetField).get();
if (val instanceof String) {
return Float.parseFloat((String) val);
}
return ((Number) val).floatValue();
}

Expand Down Expand Up @@ -180,4 +184,29 @@ public static boolean mappingExistsInSource(final Map<String, Object> sourceAsMa
return getValueFromSource(sourceAsMap, pathToValue).isPresent();
}

/**
* @param value Any value to be determined to be numerical
* @return whether the value can be turned into a number
*/
public static boolean isNumeric(Object value) {
if (value == null) {
return false;
}

if (value instanceof Number) {
return true;
}

if (value instanceof String) {
String string = (String) value;
try {
Double.parseDouble(string);
return true;
} catch (NumberFormatException e) {
return false;
}
}

return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,90 @@ public void testRerank_keepsTargetFieldAndHasNoPreviousScore_WhenByFieldHasDefau
}
}

public void testRerank_reranksHits_WhenTargetFieldIsNumericalString() throws IOException {
String targetField = "ml.info.score";
setUpValidSearchResultsWithNestedTargetValueWithNumericalString();
List<Map.Entry<Integer, Float>> sortedScoresDescending = sampleIndexMLScorePairs.stream()
.sorted(Map.Entry.<Integer, Float>comparingByValue().reversed())
.toList();

Map<String, Object> config = new HashMap<>(
Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField)))
);
processor = (ByFieldRerankProcessor) factory.create(
Map.of(),
"rerank processor",
"processor for 2nd level reranking based on provided field, This will check a nested field and numerical string",
false,
config,
pipelineContext
);
ActionListener<SearchResponse> listener = mock(ActionListener.class);
processor.rerank(response, Map.of(), listener);
ArgumentCaptor<SearchResponse> argCaptor = ArgumentCaptor.forClass(SearchResponse.class);

verify(listener, times(1)).onResponse(argCaptor.capture());
SearchResponse searchResponse = argCaptor.getValue();

assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length);
assertEquals(sortedScoresDescending.getFirst().getValue(), searchResponse.getHits().getMaxScore(), 0.0001);

for (int i = 0; i < sortedScoresDescending.size(); i++) {
int docId = sortedScoresDescending.get(i).getKey();
float ml_score = sortedScoresDescending.get(i).getValue();
assertEquals(docId, searchResponse.getHits().getAt(i).docId());
assertEquals(ml_score, searchResponse.getHits().getAt(i).getScore(), 0.001);

// Test that the path to targetField is valid
Map<String, Object> currentMap = searchResponse.getHits().getAt(i).getSourceAsMap();
String[] keys = targetField.split("\\.");
String lastKey = keys[keys.length - 1];
for (int keyIndex = 0; keyIndex < keys.length - 1; keyIndex++) {
String key = keys[keyIndex];
assertTrue("The key:" + key + "does not exist in" + currentMap, currentMap.containsKey(key));
currentMap = (Map<String, Object>) currentMap.get(key);
}
assertTrue("The key:" + lastKey + "does not exist in" + currentMap, currentMap.containsKey(lastKey));

}
}

/**
* Setups a search response that has a target field with a numerical string for example "3.2"
* Which can be used by the processor to rerank documents.
*/
private void setUpValidSearchResultsWithNestedTargetValueWithNumericalString() {
SearchHit[] hits = new SearchHit[sampleIndexMLScorePairs.size()];

String templateString = """
{
"my_field" : "%s",
"ml": {
"info" : {
"score": "%s"
}
}
}
""".replace("\n", "");

for (int i = 0; i < sampleIndexMLScorePairs.size(); i++) {
int docId = sampleIndexMLScorePairs.get(i).getKey();
String mlScore = sampleIndexMLScorePairs.get(i).getValue() + "";

String sourceMap = templateString.formatted(i, mlScore);

hits[i] = new SearchHit(docId, docId + "", Collections.emptyMap(), Collections.emptyMap());
hits[i].sourceRef(new BytesArray(sourceMap));
hits[i].score(1);
}

TotalHits totalHits = new TotalHits(sampleIndexMLScorePairs.size(), TotalHits.Relation.EQUAL_TO);

SearchHits searchHits = new SearchHits(hits, totalHits, 1.0f);
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new SearchResponse.Clusters(1, 1, 0), null);
}

/**
* Creates a searchResponse where the value to reRank by is Nested.
* The location where the target is within a map of size 1 meaning after
Expand Down Expand Up @@ -986,7 +1070,7 @@ public void testRerank_throwsExceptionOnHavingNonNumericValue_WhenTargetFieldHas
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals(
"The field mapping to rerank by [hello world] is not Numerical, instead of type [java.lang.String]",
"The field mapping to rerank by [hello world] is not Numerical, instead of type [String]",
argumentCaptor.getValue().getMessage()
);
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);
Expand Down

0 comments on commit 06a25a5

Please sign in to comment.