Skip to content

Commit

Permalink
Fix code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 14, 2025
1 parent f35f222 commit ba8e9ba
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 64 deletions.
5 changes: 0 additions & 5 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ public float scoreTranslation(float rawScore) {
return Math.max((2.0F - rawScore) / 2.0F, 0.0F);
}

/**
* Returns the given score as-is; this override applies no score-to-distance conversion.
*
* @param score the score to translate
* @return {@code score}, unchanged
*/
@Override
public float scoreToDistanceTranslation(float score) {
return score;
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ public class Faiss extends NativeLibrary {
// https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();
Function<Float, Float>>builder()
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : (1 / score) - 1)
.put(SpaceType.COSINESIMIL, score -> 2 - 2 * score)
.build();

private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build();
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> 1 - distance).build();

// Package private so that the method resolving logic can access the methods
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ protected void validatePreparse() {
protected abstract VectorValidator getVectorValidator();

/**
* Getter for per dimension validator during vector parsing
* Getter for per dimension validator during vector parsing, and before any transformation
*
* @return PerDimensionValidator
*/
Expand All @@ -688,6 +688,11 @@ protected void validatePreparse() {
*/
protected abstract PerDimensionProcessor getPerDimensionProcessor();

/**
* Getter for vector transformer after vector parsing and validation
*
* @return VectorTransformer
*/
protected abstract VectorTransformer getVectorTransformer();

protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
Expand All @@ -700,8 +705,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
final byte[] transformedArray = getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(transformedArray));
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(array));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

Expand All @@ -710,8 +715,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
final float[] transformedArray = getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForFloatVector(transformedArray));
getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForFloatVector(array));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
Expand Down Expand Up @@ -99,4 +101,32 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext)
Mode mode = knnMappingConfig.getMode();
return compressionLevel.getDefaultRescoreContext(mode, dimension);
}

/**
 * Applies the vector transformation configured for this field to a query vector.
 *
 * <p>Only fields whose vector data type is FLOAT are transformed; for any other data type
 * this method returns immediately and leaves {@code vector} untouched. For FLOAT fields the
 * KNN method context is read from the mapping configuration, the transformer for that
 * context is obtained from {@code VectorTransformerFactory}, and it is applied to
 * {@code vector}. Nothing is returned, so any change is made to the array passed in.
 *
 * @param vector the query vector to transform in place
 * @throws IllegalStateException if the field is of FLOAT type and the mapping configuration
 *                               has no KNN method context
 */
public void transformQueryVector(float[] vector) {
    // Non-FLOAT fields are passed through without any transformation.
    if (VectorDataType.FLOAT != vectorDataType) {
        return;
    }
    final KNNMethodContext methodContext = knnMappingConfig.getKnnMethodContext()
        .orElseThrow(() -> new IllegalStateException("KNN method context is not set"));
    VectorTransformerFactory.getVectorTransformer(methodContext).transform(vector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,25 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

import org.apache.lucene.util.VectorUtil;

/**
* Normalizes vectors using L2 (Euclidean) normalization. This transformation ensures
* that the vector's magnitude becomes 1 while preserving its directional properties.
* Normalizes vectors using L2 (Euclidean) normalization, ensuring the vector's
* magnitude becomes 1 while preserving its directional properties.
*/
public class NormalizeVectorTransformer implements VectorTransformer {

/**
* Transforms the input vector into unit vector by applying L2 normalization.
*
* @param vector The input vector to be normalized. Must not be null.
* @return A new float array containing the L2-normalized version of the input vector.
* Each component is divided by the Euclidean norm of the vector.
* @throws IllegalArgumentException if the input vector is null, empty, or a zero vector
*/
@Override
public float[] transform(float[] vector) {
public void transform(float[] vector) {
validateVector(vector);
VectorUtil.l2normalize(vector);
}

private void validateVector(float[] vector) {
if (vector == null || vector.length == 0) {
throw new IllegalArgumentException("Vector cannot be null or empty");
}
return VectorUtil.l2normalize(vector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

import java.util.Arrays;

/**
* Defines operations for transforming vectors in the k-NN search context.
* Implementations can modify vectors while preserving their dimensional properties
Expand All @@ -15,50 +12,31 @@
public interface VectorTransformer {

/**
* Transforms a float vector into a new vector of the same type.
*
* Example:
* <pre>{@code
* float[] input = {1.0f, 2.0f, 3.0f};
* float[] transformed = transformer.transform(input);
* }</pre>
* Transforms a float vector in place.
*
* @param vector The input vector to transform (must not be null)
* @return The transformed vector
* @throws IllegalArgumentException if the input vector is null
*/
default float[] transform(final float[] vector) {
default void transform(final float[] vector) {
if (vector == null) {
throw new IllegalArgumentException("Input vector cannot be null");
}
return Arrays.copyOf(vector, vector.length);
}

/**
* Transforms a byte vector into a new vector of the same type.
*
* Example:
* <pre>{@code
* byte[] input = {1, 2, 3};
* byte[] transformed = transformer.transform(input);
* }</pre>
* Transforms a byte vector in place.
*
* @param vector The input vector to transform (must not be null)
* @return The transformed vector
* @throws IllegalArgumentException if the input vector is null
*/
default byte[] transform(final byte[] vector) {
default void transform(final byte[] vector) {
if (vector == null) {
throw new IllegalArgumentException("Input vector cannot be null");
}
// return copy of vector to avoid side effects
return Arrays.copyOf(vector, vector.length);

}

/**
* A no-operation transformer that leaves vector values unchanged.
* This constant can be used when no transformation is needed.
*/
VectorTransformer NOOP_VECTOR_TRANSFORMER = new VectorTransformer() {
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.VectorUtil;
import org.opensearch.common.ValidationException;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.Strings;
Expand Down Expand Up @@ -429,6 +428,7 @@ protected Query doToQuery(QueryShardContext context) {
SpaceType spaceType = queryConfigFromMapping.get().getSpaceType();
VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType();
RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext);
knnVectorFieldType.transformQueryVector(vector);

VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore);
updateQueryStats(vectorQueryType);
Expand Down Expand Up @@ -542,7 +542,7 @@ protected Query doToQuery(QueryShardContext context) {
.knnEngine(knnEngine)
.indexName(indexName)
.fieldName(this.fieldName)
.vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, spaceType))
.vector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine))
.byteVector(getVectorForCreatingQueryRequest(vectorDataType, knnEngine, byteVector))
.vectorDataType(vectorDataType)
.k(this.k)
Expand Down Expand Up @@ -612,13 +612,7 @@ private void updateQueryStats(VectorQueryType vectorQueryType) {
}
}

private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine, SpaceType spaceType) {

// Cosine similarity is supported as Inner product by FAISS by normalizing input vector, hence, we have to normalize
// query vector before applying search
if (knnEngine == KNNEngine.FAISS && spaceType == SpaceType.COSINESIMIL && VectorDataType.FLOAT == vectorDataType) {
return VectorUtil.l2normalize(this.vector);
}
private float[] getVectorForCreatingQueryRequest(VectorDataType vectorDataType, KNNEngine knnEngine) {
if ((VectorDataType.FLOAT == vectorDataType) || (VectorDataType.BYTE == vectorDataType && KNNEngine.FAISS == knnEngine)) {
return this.vector;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ public void testNormalizeTransformer_withEmptyVector_thenThrowsException() {

public void testNormalizeTransformer_withValidVector_thenSuccess() {
float[] input = { -3.0f, 4.0f };
float[] normalized = transformer.transform(input);
transformer.transform(input);

assertEquals(-0.6f, normalized[0], DELTA);
assertEquals(0.8f, normalized[1], DELTA);
assertEquals(-0.6f, input[0], DELTA);
assertEquals(0.8f, input[1], DELTA);

// Verify the magnitude is 1
assertEquals(1.0f, calculateMagnitude(normalized), DELTA);
assertEquals(1.0f, calculateMagnitude(input), DELTA);
}

private float calculateMagnitude(float[] vector) {
Expand Down

0 comments on commit ba8e9ba

Please sign in to comment.