Skip to content

Commit

Permalink
Add cosine similarity support for faiss engine
Browse files Browse the repository at this point in the history
FAISS engine doesn't support cosine similarity natively.
However we can use inner product to achieve the same, because,
when vectors are normalized then inner product will be same
as cosine similarity. Hence, before ingestion and perform search,
normalize the input vector and add it to faiss index with type
as inner product.

Since we will be storing normalized vector in segments, to get
actual vectors, source can be used. By saving as normalized vector,
we don't have to normalize whenever segments are merged. This will
keep force merge time and search at competitive, provided we will
face additional latency during indexing (one time where we normalize).

We also support radial search for cosine similarity.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Jan 10, 2025
1 parent a875eb8 commit 4658bee
Show file tree
Hide file tree
Showing 22 changed files with 542 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public float scoreTranslation(float rawScore) {
return Math.max((2.0F - rawScore) / 2.0F, 0.0F);
}

@Override
public float scoreToDistanceTranslation(float score) {
return score;
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.COSINE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.SpaceVectorValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorTransformerFactory;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.ArrayList;
Expand Down Expand Up @@ -106,6 +108,10 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) {
return VectorTransformerFactory.getVectorTransformer(knnMethodContext);
}

@Override
public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
KNNMethodContext knnMethodContext,
Expand All @@ -116,19 +122,35 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
knnMethodConfigContext
);
Map<String, Object> parameterMap = knnLibraryIndexingContext.getLibraryParameters();
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
parameterMap.put(KNNConstants.SPACE_TYPE, getCompatibleSpaceType(knnMethodContext.getSpaceType()).getValue());
parameterMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, knnMethodConfigContext.getVectorDataType().getValue());
return KNNLibraryIndexingContextImpl.builder()
.quantizationConfig(knnLibraryIndexingContext.getQuantizationConfig())
.parameters(parameterMap)
.vectorValidator(doGetVectorValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext))
.build();
}

@Override
public KNNLibrarySearchContext getKNNLibrarySearchContext() {
return knnLibrarySearchContext;
}

/**
* Gets the compatible space type for the given space type parameter.
* The subclass can override this method and returns the appropriate space type that
* is compatible with the library.
*
* @param spaceType The space type to check for compatibility
* @return The compatible space type for the given input, returns the same
* space type if it's already compatible
* @see SpaceType
*/

protected SpaceType getCompatibleSpaceType(SpaceType spaceType) {
return spaceType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
Expand Down Expand Up @@ -47,4 +48,6 @@ public interface KNNLibraryIndexingContext {
* @return Get the per dimension processor
*/
PerDimensionProcessor getPerDimensionProcessor();

VectorTransformer getVectorTransformer();
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Collections;
Expand All @@ -23,6 +24,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private VectorValidator vectorValidator;
private PerDimensionValidator perDimensionValidator;
private PerDimensionProcessor perDimensionProcessor;
private VectorTransformer vectorTransformer;
@Builder.Default
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
Expand All @@ -43,6 +45,11 @@ public VectorValidator getVectorValidator() {
return vectorValidator;
}

@Override
public VectorTransformer getVectorTransformer() {
return vectorTransformer;
}

@Override
public PerDimensionValidator getPerDimensionValidator() {
return perDimensionValidator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.PerDimensionProcessor;
import org.opensearch.knn.index.mapper.PerDimensionValidator;
import org.opensearch.knn.index.mapper.VectorTransformer;

import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -89,6 +90,11 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
throw new IllegalStateException("Unsupported vector data type " + vectorDataType);
}

@Override
protected VectorTransformer getVectorTransformer(KNNMethodContext knnMethodContext) {
return super.getVectorTransformer(knnMethodContext);
}

static KNNLibraryIndexingContext adjustIndexDescription(
MethodAsMapBuilder methodAsMapBuilder,
MethodComponentContext methodComponentContext,
Expand Down Expand Up @@ -132,4 +138,15 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
}
return (MethodComponentContext) object;
}

@Override
protected SpaceType getCompatibleSpaceType(SpaceType spaceType) {
// While FAISS doesn't directly support cosine similarity, we can leverage the mathematical
// relationship between cosine similarity and inner product for normalized vectors to add support.
// When ||a|| = ||b|| = 1, cos(θ) = a · b
if (spaceType == SpaceType.COSINESIMIL) {
return SpaceType.INNER_PRODUCT;
}
return super.getCompatibleSpaceType(spaceType);
}
}
22 changes: 17 additions & 5 deletions src/main/java/org/opensearch/knn/index/engine/faiss/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
*/
public class Faiss extends NativeLibrary {
public static final String FAISS_BINARY_INDEX_DESCRIPTION_PREFIX = "B";
Map<SpaceType, Function<Float, Float>> distanceTransform;
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
Expand All @@ -36,7 +37,10 @@ public class Faiss extends NativeLibrary {
// Map that overrides OpenSearch score translation by space type of scores returned by faiss
private final static Map<SpaceType, Function<Float, Float>> SCORE_TRANSLATIONS = ImmutableMap.of(
SpaceType.INNER_PRODUCT,
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore),
// COSINESIMIL expects the raw score in 1 - cosine(x,y)
SpaceType.COSINESIMIL,
rawScore -> SpaceType.COSINESIMIL.scoreTranslation(1 - rawScore)
);

// Map that overrides radial search score threshold to faiss required distance, check more details in knn documentation:
Expand All @@ -45,6 +49,10 @@ public class Faiss extends NativeLibrary {
SpaceType,
Function<Float, Float>>builder().put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1).build();

private final static Map<SpaceType, Function<Float, Float>> DISTANCE_TRANSLATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder().put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2).build();

// Package private so that the method resolving logic can access the methods
final static Map<String, KNNMethod> METHODS = ImmutableMap.of(METHOD_HNSW, new FaissHNSWMethod(), METHOD_IVF, new FaissIVFMethod());

Expand All @@ -53,7 +61,8 @@ public class Faiss extends NativeLibrary {
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
SCORE_TO_DISTANCE_TRANSFORMATIONS,
DISTANCE_TRANSLATIONS
);

private final MethodResolver methodResolver;
Expand All @@ -71,22 +80,25 @@ private Faiss(
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
Map<SpaceType, Function<Float, Float>> scoreTransform,
Map<SpaceType, Function<Float, Float>> distanceTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
this.distanceTransform = distanceTransform;
this.methodResolver = new FaissMethodResolver();
}

@Override
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
// Faiss engine uses distance as is and does not need transformation
if (this.distanceTransform.containsKey(spaceType)) {
return this.distanceTransform.get(spaceType).apply(distance);
}
return distance;
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// Faiss engine uses distance as is and need transformation
if (this.scoreTransform.containsKey(spaceType)) {
return this.scoreTransform.get(spaceType).apply(score);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public class FaissHNSWMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.HAMMING,
SpaceType.L2,
SpaceType.INNER_PRODUCT
SpaceType.INNER_PRODUCT,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ public class FaissIVFMethod extends AbstractFaissMethod {
SpaceType.UNDEFINED,
SpaceType.L2,
SpaceType.INNER_PRODUCT,
SpaceType.HAMMING
SpaceType.HAMMING,
SpaceType.COSINESIMIL
);

private final static MethodComponentContext DEFAULT_ENCODER_CONTEXT = new MethodComponentContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,9 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return PerDimensionProcessor.NOOP_PROCESSOR;
}

@Override
protected VectorTransformer getVectorTransformer() {
return VectorTransformer.NOOP_VECTOR_TRANSFORMER;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,8 @@ protected void validatePreparse() {
*/
protected abstract PerDimensionProcessor getPerDimensionProcessor();

protected abstract VectorTransformer getVectorTransformer();

protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException {
validatePreparse();

Expand All @@ -691,7 +693,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final byte[] array = bytesArrayOptional.get();
getVectorValidator().validateVector(array);
context.doc().addAll(getFieldsForByteVector(array));
final byte[] transformedArray = getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForByteVector(transformedArray));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

Expand All @@ -700,7 +703,8 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT
}
final float[] array = floatsArrayOptional.get();
getVectorValidator().validateVector(array);
context.doc().addAll(getFieldsForFloatVector(array));
final float[] transformedArray = getVectorTransformer().transform(array);
context.doc().addAll(getFieldsForFloatVector(transformedArray));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
private final PerDimensionProcessor perDimensionProcessor;
private final PerDimensionValidator perDimensionValidator;
private final VectorValidator vectorValidator;
private final VectorTransformer vectorTransformer;

static LuceneFieldMapper createFieldMapper(
String fullname,
Expand Down Expand Up @@ -122,6 +123,7 @@ private LuceneFieldMapper(
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
this.vectorValidator = knnLibraryIndexingContext.getVectorValidator();
this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
}

@Override
Expand Down Expand Up @@ -169,6 +171,11 @@ protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
protected VectorTransformer getVectorTransformer() {
return vectorTransformer;
}

@Override
void updateEngineStats() {
KNNEngine.LUCENE.setInitialized(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MethodFieldMapper extends KNNVectorFieldMapper {
private final PerDimensionProcessor perDimensionProcessor;
private final PerDimensionValidator perDimensionValidator;
private final VectorValidator vectorValidator;
private final VectorTransformer vectorTransformer;

public static MethodFieldMapper createFieldMapper(
String fullname,
Expand Down Expand Up @@ -180,6 +181,7 @@ private MethodFieldMapper(
this.perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
this.perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
this.vectorValidator = knnLibraryIndexingContext.getVectorValidator();
this.vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
}

@Override
Expand All @@ -196,4 +198,9 @@ protected PerDimensionValidator getPerDimensionValidator() {
protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
protected VectorTransformer getVectorTransformer() {
return vectorTransformer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class ModelFieldMapper extends KNNVectorFieldMapper {
private PerDimensionProcessor perDimensionProcessor;
private PerDimensionValidator perDimensionValidator;
private VectorValidator vectorValidator;
private VectorTransformer vectorTransformer;

private final String modelId;

Expand Down Expand Up @@ -192,6 +193,37 @@ protected PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
protected VectorTransformer getVectorTransformer() {
initVectorTransformer();
return vectorTransformer;
}

/**
* Initializes the vector transformer for the model field if not already initialized.
* This method handles the vector transformation configuration based on the model metadata
* and KNN method context.
* @throws IllegalStateException if model metadata cannot be retrieved
*/
private void initVectorTransformer() {
if (vectorTransformer != null) {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);

KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
// Need to handle BWC case
if (knnMethodContext == null || knnMethodConfigContext == null) {
vectorTransformer = VectorTransformerFactory.getVectorTransformer(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType());
return;
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
vectorTransformer = knnLibraryIndexingContext.getVectorTransformer();
}

private void initVectorValidator() {
if (vectorValidator != null) {
return;
Expand Down
Loading

0 comments on commit 4658bee

Please sign in to comment.