From f3b2bd08e32dc12a08fb6917c018c6bbbfa4bd3f Mon Sep 17 00:00:00 2001 From: Doo Yong Kim <0ctopus13prime@gmail.com> Date: Thu, 3 Oct 2024 09:14:24 -0700 Subject: [PATCH] Introducing a loading layer in FAISS native engine. (#2139) * Introducing a loading layer in FAISS native engine. Signed-off-by: Dooyong Kim * Update change log. Signed-off-by: Dooyong Kim * Added unit tests for Faiss stream support. Signed-off-by: Dooyong Kim * Fix a bug where a size in KB was passed as a parameter that expects a size in bytes. Signed-off-by: Dooyong Kim * Fix a casting bug when loading an index file larger than 4GB. Signed-off-by: Dooyong Kim * Added unit tests for new methods in JNIService. Signed-off-by: Dooyong Kim * Fix formatting and removed nmslib_stream_support. Signed-off-by: Dooyong Kim * Removing redundant exception message in JNIService.loadIndex. Signed-off-by: Dooyong Kim * Fix a flaky test - testIndexAllocation_closeBlocking Signed-off-by: Dooyong Kim --------- Signed-off-by: Dooyong Kim Signed-off-by: Doo Yong Kim <0ctopus13prime@gmail.com> Co-authored-by: Dooyong Kim --- CHANGELOG.md | 1 + jni/CMakeLists.txt | 1 + jni/include/faiss_stream_support.h | 136 +++++++++++++ jni/include/faiss_wrapper.h | 10 + jni/include/jni_util.h | 102 ++++++---- .../org_opensearch_knn_jni_FaissService.h | 16 ++ jni/src/faiss_wrapper.cpp | 28 +++ jni/src/jni_util.cpp | 28 +++ .../org_opensearch_knn_jni_FaissService.cpp | 43 ++++ jni/tests/commons_test.cpp | 2 +- jni/tests/faiss_stream_support_test.cpp | 132 ++++++++++++ jni/tests/test_util.h | 8 + .../opensearch/knn/index/KNNIndexShard.java | 3 + .../index/memory/NativeMemoryAllocation.java | 34 ++-- .../memory/NativeMemoryEntryContext.java | 48 ++--- .../memory/NativeMemoryLoadStrategy.java | 62 +++++- .../opensearch/knn/index/query/KNNWeight.java | 1 + .../knn/index/store/IndexInputWithBuffer.java | 46 +++++ .../org/opensearch/knn/jni/FaissService.java | 19 ++ .../org/opensearch/knn/jni/JNIService.java | 23 +++ 
.../memory/NativeMemoryAllocationTests.java | 24 ++- .../memory/NativeMemoryEntryContextTests.java | 5 + .../memory/NativeMemoryLoadStrategyTests.java | 6 + .../opensearch/knn/jni/JNIServiceTests.java | 189 ++++++++++++++++++ 24 files changed, 864 insertions(+), 103 deletions(-) create mode 100644 jni/include/faiss_stream_support.h create mode 100644 jni/tests/faiss_stream_support_test.cpp create mode 100644 src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 6871074f0..267f6248d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD) ### Features ### Enhancements +* Introducing a loading layer in FAISS [#2033](https://github.com/opensearch-project/k-NN/issues/2033) ### Bug Fixes * Add DocValuesProducers for releasing memory when close index [#1946](https://github.com/opensearch-project/k-NN/pull/1946) ### Infrastructure diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 3c071fc1f..4caa907b3 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -154,6 +154,7 @@ if ("${WIN32}" STREQUAL "") tests/nmslib_wrapper_unit_test.cpp tests/test_util.cpp tests/commons_test.cpp + tests/faiss_stream_support_test.cpp tests/faiss_index_service_test.cpp ) diff --git a/jni/include/faiss_stream_support.h b/jni/include/faiss_stream_support.h new file mode 100644 index 000000000..65f1631d4 --- /dev/null +++ b/jni/include/faiss_stream_support.h @@ -0,0 +1,136 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H +#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H + +#include "faiss/impl/io.h" +#include "jni_util.h" + +#include +#include +#include +#include + +namespace knn_jni { +namespace stream { + +/** + * This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer. + */ + +class NativeEngineIndexInputMediator { + public: + // Expect IndexInputWithBuffer is given as `_indexInput`. + NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface, + JNIEnv *_env, + jobject _indexInput) + : jni_interface(_jni_interface), + env(_env), + indexInput(_indexInput), + bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env, + _indexInput, + getBufferFieldId(_jni_interface, _env)))), + copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) { + } + + void copyBytes(int64_t nbytes, uint8_t *destination) { + while (nbytes > 0) { + // Call `copyBytes` to read bytes as many as possible. + const auto readBytes = + jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes); + + // === Critical Section Start === + + // Get primitive array pointer, no copy is happening in OpenJDK. + auto primitiveArray = + (jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr); + + // Copy Java bytes to C++ destination address. + std::memcpy(destination, primitiveArray, readBytes); + + // Release the acquired primitive array pointer. + // JNI_ABORT tells JVM to directly free memory without copying back to Java byte[]. + // Since we're merely copying data, we don't need to copying back. 
+ jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT); + + // === Critical Section End === + + destination += readBytes; + nbytes -= readBytes; + } // End while + } + + private: + static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jclass INDEX_INPUT_WITH_BUFFER_CLASS = + jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer"); + return INDEX_INPUT_WITH_BUFFER_CLASS; + } + + static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jmethodID COPY_METHOD_ID = + jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I"); + return COPY_METHOD_ID; + } + + static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jfieldID BUFFER_FIELD_ID = + jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B"); + return BUFFER_FIELD_ID; + } + + JNIUtilInterface *jni_interface; + JNIEnv *env; + + // `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading. + jobject indexInput; + jbyteArray bufferArray; + jmethodID copyBytesMethod; +}; // class NativeEngineIndexInputMediator + + + +/** + * A glue component inheriting IOReader to be passed down to Faiss library. + * This will then indirectly call the mediator component and eventually read required bytes from Lucene's IndexInput. + */ +class FaissOpenSearchIOReader final : public faiss::IOReader { + public: + explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator) + : faiss::IOReader(), + mediator(_mediator) { + name = "FaissOpenSearchIOReader"; + } + + size_t operator()(void *ptr, size_t size, size_t nitems) final { + const auto readBytes = size * nitems; + if (readBytes > 0) { + // Mediator calls IndexInput, then copy read bytes to `ptr`. 
+ mediator->copyBytes(readBytes, (uint8_t *) ptr); + } + return nitems; + } + + int filedescriptor() final { + throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOReader."); + } + + private: + NativeEngineIndexInputMediator *mediator; +}; // class FaissOpenSearchIOReader + + + +} +} + +#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index d6375653d..8ffce4ad1 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -47,11 +47,21 @@ namespace knn_jni { // Return a pointer to the loaded index jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ); + // Loads an index with a reader implemented IOReader + // + // Returns a pointer of the loaded index + jlong LoadIndexWithStream(faiss::IOReader* ioReader); + // Load a binary index from indexPathJ into memory. // // Return a pointer to the loaded index jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ); + // Loads a binary index with a reader implemented IOReader + // + // Returns a pointer of the loaded index + jlong LoadBinaryIndexWithStream(faiss::IOReader* ioReader); + // Check if a loaded index requires shared state bool IsSharedIndexStateRequired(jlong indexPointerJ); diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index 825471a3c..6b1b926e7 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -22,8 +22,7 @@ namespace knn_jni { // Interface for making calls to JNI - class JNIUtilInterface { - public: + struct JNIUtilInterface { // -------------------------- EXCEPTION HANDLING ---------------------------- // Takes the name of a Java exception type and a message and throws the corresponding exception // to the JVM @@ -127,56 +126,77 @@ namespace knn_jni { virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0; + virtual jobject 
GetObjectField(JNIEnv * env, jobject obj, jfieldID fieldID) = 0; + + virtual jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) = 0; + + virtual jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) = 0; + + virtual jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) = 0; + + virtual void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) = 0; + + virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0; + + virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0; + // -------------------------------------------------------------------------- }; jobject GetJObjectFromMapOrThrow(std::unordered_map map, std::string key); // Class that implements JNIUtilInterface methods - class JNIUtil: public JNIUtilInterface { + class JNIUtil final : public JNIUtilInterface { public: // Initialize and Uninitialize methods are used for caching/cleaning up Java classes and methods void Initialize(JNIEnv* env); void Uninitialize(JNIEnv* env); - void ThrowJavaException(JNIEnv* env, const char* type = "", const char* message = ""); - void HasExceptionInStack(JNIEnv* env); - void HasExceptionInStack(JNIEnv* env, const std::string& message); - void CatchCppExceptionAndThrowJava(JNIEnv* env); - jclass FindClass(JNIEnv * env, const std::string& className); - jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName); - std::string ConvertJavaStringToCppString(JNIEnv * env, jstring javaString); - std::unordered_map ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ); - std::string ConvertJavaObjectToCppString(JNIEnv *env, jobject objectJ); - int ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ); - std::vector Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim); - std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray 
arrayJ); - int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ); - int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ); - int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ); - int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ); - int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ); - int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ); - int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ); - - void DeleteLocalRef(JNIEnv *env, jobject obj); - jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy); - jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy); - jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy); - jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy); - jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index); - jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance); - jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init); - jbyteArray NewByteArray(JNIEnv *env, jsize len); - void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode); - void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode); - void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode); - void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode); - void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val); - void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf); - void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); - void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); - void 
Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect); + void ThrowJavaException(JNIEnv* env, const char* type = "", const char* message = "") final; + void HasExceptionInStack(JNIEnv* env) final; + void HasExceptionInStack(JNIEnv* env, const std::string& message) final; + void CatchCppExceptionAndThrowJava(JNIEnv* env) final; + jclass FindClass(JNIEnv * env, const std::string& className) final; + jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName) final; + std::string ConvertJavaStringToCppString(JNIEnv * env, jstring javaString) final; + std::unordered_map ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ) final; + std::string ConvertJavaObjectToCppString(JNIEnv *env, jobject objectJ) final; + int ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ) final; + std::vector Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim) final; + std::vector ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) final; + int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) final; + int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) final; + int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) final; + int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) final; + int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) final; + int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) final; + int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) final; + + void DeleteLocalRef(JNIEnv *env, jobject obj) final; + jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy) final; + jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy) final; + jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) final; + jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) final; + 
jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) final; + jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance) final; + jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init) final; + jbyteArray NewByteArray(JNIEnv *env, jsize len) final; + void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode) final; + void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode) final; + void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) final; + void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) final; + void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) final; + void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) final; + void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect) final; + void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect) final; + void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector *vect) final; + jobject GetObjectField(JNIEnv * env, jobject obj, jfieldID fieldID) final; + jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final; + jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final; + jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final; + jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final; + void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final; + void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final; private: std::unordered_map cachedClasses; diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h 
b/jni/include/org_opensearch_knn_jni_FaissService.h index d42ce197c..2969df3ae 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -128,6 +128,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex (JNIEnv *, jclass, jstring); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: loadIndexWithStream + * Signature: (Lorg/opensearch/knn/index/store/IndexInputWithBuffer;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream + (JNIEnv *, jclass, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadBinaryIndex @@ -136,6 +144,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex (JNIEnv *, jclass, jstring); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: loadBinaryIndexWithStream + * Signature: (Lorg/opensearch/knn/index/store/IndexInputWithBuffer;)J + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream + (JNIEnv *, jclass, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: isSharedIndexStateRequired diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 45548e0f7..d1c7648dc 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -423,6 +423,20 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI return (jlong) indexReader; } +jlong knn_jni::faiss_wrapper::LoadIndexWithStream(faiss::IOReader* ioReader) { + if (ioReader == nullptr) [[unlikely]] { + throw std::runtime_error("IOReader cannot be null"); + } + + faiss::Index* indexReader = + faiss::read_index(ioReader, + faiss::IO_FLAG_READ_ONLY + | faiss::IO_FLAG_PQ_SKIP_SDC_TABLE + | faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE); + + return 
(jlong) indexReader; +} + jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { if (indexPathJ == nullptr) { throw std::runtime_error("Index path cannot be null"); @@ -436,6 +450,20 @@ jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUti return (jlong) indexReader; } +jlong knn_jni::faiss_wrapper::LoadBinaryIndexWithStream(faiss::IOReader* ioReader) { + if (ioReader == nullptr) [[unlikely]] { + throw std::runtime_error("IOReader cannot be null"); + } + + faiss::IndexBinary* indexReader = + faiss::read_index_binary(ioReader, + faiss::IO_FLAG_READ_ONLY + | faiss::IO_FLAG_PQ_SKIP_SDC_TABLE + | faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE); + + return (jlong) indexReader; +} + bool knn_jni::faiss_wrapper::IsSharedIndexStateRequired(jlong indexPointerJ) { auto * index = reinterpret_cast(indexPointerJ); return isIndexIVFPQL2(index); diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index 82900b5ce..3eaf3b0a1 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -547,6 +547,34 @@ void knn_jni::JNIUtil::SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize s this->HasExceptionInStack(env, "Unable to set byte array region"); } +jobject knn_jni::JNIUtil::GetObjectField(JNIEnv * env, jobject obj, jfieldID fieldID) { + return env->GetObjectField(obj, fieldID); +} + +jclass knn_jni::JNIUtil::FindClassFromJNIEnv(JNIEnv * env, const char *name) { + return env->FindClass(name); +} + +jmethodID knn_jni::JNIUtil::GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) { + return env->GetMethodID(clazz, name, sig); +} + +jfieldID knn_jni::JNIUtil::GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) { + return env->GetFieldID(clazz, name, sig); +} + +jint knn_jni::JNIUtil::CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) { + return env->CallIntMethod(obj, methodID, longArg); +} + +void * 
knn_jni::JNIUtil::GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) { + return env->GetPrimitiveArrayCritical(array, isCopy); +} + +void knn_jni::JNIUtil::ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) { + return env->ReleasePrimitiveArrayCritical(array, carray, mode); +} + jobject knn_jni::GetJObjectFromMapOrThrow(std::unordered_map map, std::string key) { if(map.find(key) == map.end()) { throw std::runtime_error(key + " not found"); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 70c986b7d..7326c7ba0 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -17,6 +17,7 @@ #include "faiss_wrapper.h" #include "jni_util.h" +#include "faiss_stream_support.h" static knn_jni::JNIUtil jniUtil; static const jint KNN_FAISS_JNI_VERSION = JNI_VERSION_1_1; @@ -217,6 +218,27 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEn return NULL; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream + (JNIEnv * env, jclass cls, jobject readStream) +{ + try { + // Create a mediator locally. + // Note that `indexInput` is `IndexInputWithBuffer` type. + knn_jni::stream::NativeEngineIndexInputMediator mediator {&jniUtil, env, readStream}; + + // Wrap the mediator with a glue code inheriting IOReader. + knn_jni::stream::FaissOpenSearchIOReader faissOpenSearchIOReader {&mediator}; + + // Pass IOReader to Faiss for loading vector index. + return knn_jni::faiss_wrapper::LoadIndexWithStream( + &faissOpenSearchIOReader); + } catch (...) 
{ + jniUtil.CatchCppExceptionAndThrowJava(env); + } + + return NULL; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { try { @@ -227,6 +249,27 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex return NULL; } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream + (JNIEnv * env, jclass cls, jobject readStream) +{ + try { + // Create a mediator locally. + // Note that `indexInput` is `IndexInputWithBuffer` type. + knn_jni::stream::NativeEngineIndexInputMediator mediator {&jniUtil, env, readStream}; + + // Wrap the mediator with a glue code inheriting IOReader. + knn_jni::stream::FaissOpenSearchIOReader faissOpenSearchIOReader {&mediator}; + + // Pass IOReader to Faiss for loading vector index. + return knn_jni::faiss_wrapper::LoadBinaryIndexWithStream( + &faissOpenSearchIOReader); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + + return NULL; +} + JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired (JNIEnv * env, jclass cls, jlong indexPointerJ) { diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp index d469fe268..39d7f3c99 100644 --- a/jni/tests/commons_test.cpp +++ b/jni/tests/commons_test.cpp @@ -179,7 +179,7 @@ TEST(StoreByteVectorTest, BasicAssertions) { } // Check that freeing vector data works - knn_jni::commons::freeVectorData(memoryAddress); + knn_jni::commons::freeBinaryVectorData(memoryAddress); } TEST(CommonTests, GetIntegerMethodParam) { diff --git a/jni/tests/faiss_stream_support_test.cpp b/jni/tests/faiss_stream_support_test.cpp new file mode 100644 index 000000000..4045985bb --- /dev/null +++ b/jni/tests/faiss_stream_support_test.cpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// The OpenSearch Contributors require contributions made to +// this file be licensed under the Apache-2.0 license or a 
+// compatible open source license. +// +// Modifications Copyright OpenSearch Contributors. See +// GitHub history for details. + +#include "faiss_stream_support.h" +#include +#include "test_util.h" +#include +#include +#include +#include + +using ::testing::Return; +using knn_jni::stream::FaissOpenSearchIOReader; +using knn_jni::stream::NativeEngineIndexInputMediator; +using test_util::MockJNIUtil; + +// Mocking IndexInputWithBuffer. +struct JavaIndexInputMock { + JavaIndexInputMock(std::string _readTargetBytes, int32_t _bufSize) + : readTargetBytes(std::move(_readTargetBytes)), + nextReadIdx(), + buffer(_bufSize) { + } + + // This method is simulating `copyBytes` in IndexInputWithBuffer. + int32_t simulateCopyReads(int64_t readBytes) { + readBytes = std::min(readBytes, (int64_t) buffer.size()); + readBytes = std::min(readBytes, (int64_t) (readTargetBytes.size() - nextReadIdx)); + std::memcpy(buffer.data(), readTargetBytes.data() + nextReadIdx, readBytes); + nextReadIdx += readBytes; + return (int32_t) readBytes; + } + + static std::string makeRandomBytes(int32_t bytesSize) { + // Define the list of possible characters + static const string CHARACTERS + = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuv" + "wxyz0123456789"; + + // Create a random number generator + std::random_device rd; + std::mt19937 generator(rd()); + + // Create a distribution to uniformly select from all characters + std::uniform_int_distribution<> distribution( + 0, CHARACTERS.size() - 1); + + // Pre-allocate the string with the desired length + std::string randomString(bytesSize, '\0'); + + // Use generate_n with a back_inserter iterator + std::generate_n(randomString.begin(), bytesSize, [&]() { + return CHARACTERS[distribution(generator)]; + }); + + return randomString; + } + + std::string readTargetBytes; + int64_t nextReadIdx; + std::vector buffer; +}; // struct JavaIndexInputMock + +void setUpMockJNIUtil(JavaIndexInputMock &javaIndexInputMock, MockJNIUtil &mockJni) { + // Set up 
mocking values + mocking behavior in a method. + ON_CALL(mockJni, FindClassFromJNIEnv).WillByDefault(Return((jclass) 1)); + ON_CALL(mockJni, GetMethodID).WillByDefault(Return((jmethodID) 1)); + ON_CALL(mockJni, GetFieldID).WillByDefault(Return((jfieldID) 1)); + ON_CALL(mockJni, GetObjectField).WillByDefault(Return((jobject) 1)); + ON_CALL(mockJni, CallIntMethodLong).WillByDefault([&javaIndexInputMock](JNIEnv *env, + jobject obj, + jmethodID methodID, + int64_t longArg) { + return javaIndexInputMock.simulateCopyReads(longArg); + }); + ON_CALL(mockJni, GetPrimitiveArrayCritical).WillByDefault([&javaIndexInputMock](JNIEnv *env, + jarray array, + jboolean *isCopy) { + return (jbyte *) javaIndexInputMock.buffer.data(); + }); + ON_CALL(mockJni, ReleasePrimitiveArrayCritical).WillByDefault(Return()); +} + +TEST(FaissStreamSupportTest, NativeEngineIndexInputMediatorCopyWhenEmpty) { + for (auto contentSize : std::vector{0, 2222, 7777, 1024, 77, 1}) { + // Set up mockings + MockJNIUtil mockJni; + JavaIndexInputMock javaIndexInputMock{ + JavaIndexInputMock::makeRandomBytes(contentSize), 1024}; + setUpMockJNIUtil(javaIndexInputMock, mockJni); + + // Prepare copying + NativeEngineIndexInputMediator mediator{&mockJni, nullptr, nullptr}; + std::string readBuffer(javaIndexInputMock.readTargetBytes.size(), '\0'); + + // Call copyBytes + mediator.copyBytes((int32_t) javaIndexInputMock.readTargetBytes.size(), (uint8_t *) readBuffer.data()); + + // Expected that we acquired the same contents as readTargetBytes + ASSERT_EQ(javaIndexInputMock.readTargetBytes, readBuffer); + } // End for +} + +TEST(FaissStreamSupportTest, FaissOpenSearchIOReaderCopy) { + for (auto contentSize : std::vector{0, 2222, 7777, 1024, 77, 1}) { + // Set up mockings + MockJNIUtil mockJni; + JavaIndexInputMock javaIndexInputMock{ + JavaIndexInputMock::makeRandomBytes(contentSize), 1024}; + setUpMockJNIUtil(javaIndexInputMock, mockJni); + + // Prepare copying + NativeEngineIndexInputMediator mediator{&mockJni, 
nullptr, nullptr}; + std::string readBuffer; + readBuffer.resize(javaIndexInputMock.readTargetBytes.size()); + FaissOpenSearchIOReader ioReader{&mediator}; + + // Read bytes + const auto readBytes = + ioReader((void *) readBuffer.data(), 1, javaIndexInputMock.readTargetBytes.size()); + + // Expected that we acquired the same contents as readTargetBytes + ASSERT_EQ(javaIndexInputMock.readTargetBytes.size(), readBytes); + ASSERT_EQ(javaIndexInputMock.readTargetBytes, readBuffer); + } // End for +} diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index ea02da6f2..286000c08 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -106,6 +106,14 @@ namespace test_util { (JNIEnv * env, jobjectArray array, jsize index, jobject val)); MOCK_METHOD(void, ThrowJavaException, (JNIEnv * env, const char* type, const char* message)); + MOCK_METHOD(jobject, GetObjectField, + (JNIEnv * env, jobject obj, jfieldID fieldID)); + MOCK_METHOD(jclass, FindClassFromJNIEnv, (JNIEnv * env, const char *name)); + MOCK_METHOD(jmethodID, GetMethodID, (JNIEnv * env, jclass clazz, const char *name, const char *sig)); + MOCK_METHOD(jfieldID, GetFieldID, (JNIEnv * env, jclass clazz, const char *name, const char *sig)); + MOCK_METHOD(jint, CallIntMethodLong, (JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg)); + MOCK_METHOD(void *, GetPrimitiveArrayCritical, (JNIEnv * env, jarray array, jboolean *isCopy)); + MOCK_METHOD(void, ReleasePrimitiveArrayCritical, (JNIEnv * env, jarray array, void *carray, jint mode)); }; // For our unit tests, we want to ensure that each test tests one function in diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index 5c096e4f7..4f339cb4e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.IndexReader; import 
org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; import org.opensearch.common.lucene.Lucene; @@ -89,11 +90,13 @@ public String getIndexName() { */ public void warmup() throws IOException { log.info("[KNN] Warming up index: [{}]", getIndexName()); + final Directory directory = indexShard.store().directory(); try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) { getAllEngineFileContexts(searcher.getIndexReader()).forEach((engineFileContext) -> { try { nativeMemoryCacheManager.get( new NativeMemoryEntryContext.IndexEntryContext( + directory, engineFileContext.getIndexPath(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), getParametersAtLoading( diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 635bc3883..8adf35447 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -110,7 +110,7 @@ class IndexAllocation implements NativeMemoryAllocation { private final ExecutorService executor; private final long memoryAddress; - private final int size; + private final int sizeKb; private volatile boolean closed; @Getter private final KNNEngine knnEngine; @@ -130,7 +130,7 @@ class IndexAllocation implements NativeMemoryAllocation { * * @param executorService Executor service used to close the allocation * @param memoryAddress Pointer in memory to the index - * @param size Size this index consumes in kilobytes + * @param sizeKb Size this index consumes in kilobytes * @param knnEngine KNNEngine associated with the index allocation * @param indexPath File path to index * @param openSearchIndexName Name of OpenSearch index this index is associated with 
@@ -139,13 +139,13 @@ class IndexAllocation implements NativeMemoryAllocation { IndexAllocation( ExecutorService executorService, long memoryAddress, - int size, + int sizeKb, KNNEngine knnEngine, String indexPath, String openSearchIndexName, WatcherHandle watcherHandle ) { - this(executorService, memoryAddress, size, knnEngine, indexPath, openSearchIndexName, watcherHandle, null, false); + this(executorService, memoryAddress, sizeKb, knnEngine, indexPath, openSearchIndexName, watcherHandle, null, false); } /** @@ -153,7 +153,7 @@ class IndexAllocation implements NativeMemoryAllocation { * * @param executorService Executor service used to close the allocation * @param memoryAddress Pointer in memory to the index - * @param size Size this index consumes in kilobytes + * @param sizeKb Size this index consumes in kilobytes * @param knnEngine KNNEngine associated with the index allocation * @param indexPath File path to index * @param openSearchIndexName Name of OpenSearch index this index is associated with @@ -163,7 +163,7 @@ class IndexAllocation implements NativeMemoryAllocation { IndexAllocation( ExecutorService executorService, long memoryAddress, - int size, + int sizeKb, KNNEngine knnEngine, String indexPath, String openSearchIndexName, @@ -178,7 +178,7 @@ class IndexAllocation implements NativeMemoryAllocation { this.openSearchIndexName = openSearchIndexName; this.memoryAddress = memoryAddress; this.readWriteLock = new ReentrantReadWriteLock(); - this.size = size; + this.sizeKb = sizeKb; this.watcherHandle = watcherHandle; this.sharedIndexState = sharedIndexState; this.isBinaryIndex = isBinaryIndex; @@ -187,9 +187,12 @@ class IndexAllocation implements NativeMemoryAllocation { protected void closeInternal() { Runnable onClose = () -> { - writeLock(); - cleanup(); - writeUnlock(); + try { + writeLock(); + cleanup(); + } finally { + writeUnlock(); + } }; // The close operation needs to be blocking to prevent overflow @@ -269,7 +272,7 @@ public void writeUnlock() 
{ @Override public int getSizeInKB() { - return size; + return sizeKb; } @Override @@ -325,9 +328,12 @@ public TrainingDataAllocation(ExecutorService executor, long memoryAddress, int @Override public void close() { executor.execute(() -> { - writeLock(); - cleanup(); - writeUnlock(); + try { + writeLock(); + cleanup(); + } finally { + writeUnlock(); + } }); } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index dd219593d..00bf023f9 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index.memory; import lombok.Getter; +import org.apache.lucene.store.Directory; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; @@ -64,32 +65,40 @@ public String getKey() { public static class IndexEntryContext extends NativeMemoryEntryContext { + @Getter + private final Directory directory; private final NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy; + @Getter private final String openSearchIndexName; + @Getter private final Map parameters; @Nullable + @Getter private final String modelId; /** * Constructor * - * @param indexPath path to index file. Also used as key in cache. - * @param indexLoadStrategy strategy to load index into memory - * @param parameters load time parameters - * @param openSearchIndexName opensearch index associated with index + * @param directory Lucene directory to create required IndexInput/IndexOutput to access files. + * @param indexPath Path to index file. Also used as key in cache. 
+ * @param indexLoadStrategy Strategy to load index into memory + * @param parameters Load time parameters + * @param openSearchIndexName Opensearch index associated with index */ public IndexEntryContext( + Directory directory, String indexPath, NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy, Map parameters, String openSearchIndexName ) { - this(indexPath, indexLoadStrategy, parameters, openSearchIndexName, null); + this(directory, indexPath, indexLoadStrategy, parameters, openSearchIndexName, null); } /** * Constructor * + * @param directory Lucene directory to create required IndexInput/IndexOutput to access files. * @param indexPath path to index file. Also used as key in cache. * @param indexLoadStrategy strategy to load index into memory * @param parameters load time parameters @@ -97,6 +106,7 @@ public IndexEntryContext( * @param modelId model to be loaded. If none available, pass null */ public IndexEntryContext( + Directory directory, String indexPath, NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy, Map parameters, @@ -104,6 +114,7 @@ public IndexEntryContext( String modelId ) { super(indexPath); + this.directory = directory; this.indexLoadStrategy = indexLoadStrategy; this.openSearchIndexName = openSearchIndexName; this.parameters = parameters; @@ -120,33 +131,6 @@ public NativeMemoryAllocation.IndexAllocation load() throws IOException { return indexLoadStrategy.load(this); } - /** - * Getter for OpenSearch index name. - * - * @return OpenSearch index name - */ - public String getOpenSearchIndexName() { - return openSearchIndexName; - } - - /** - * Getter for parameters. - * - * @return parameters - */ - public Map getParameters() { - return parameters; - } - - /** - * Getter - * - * @return return model ID for the index. 
null if no model is in use - */ - public String getModelId() { - return modelId; - } - private static class IndexSizeCalculator implements Function { static IndexSizeCalculator INSTANCE = new IndexSizeCalculator(); diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java index 960c4f5f0..51158d00c 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java @@ -12,8 +12,12 @@ package org.opensearch.knn.index.memory; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; import org.opensearch.core.action.ActionListener; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.store.IndexInputWithBuffer; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.engine.KNNEngine; @@ -87,9 +91,9 @@ public void onFileDeleted(Path indexFilePath) { }; } - @Override - public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.IndexEntryContext indexEntryContext) - throws IOException { + private NativeMemoryAllocation.IndexAllocation loadWithAbsoluteIndexPath( + NativeMemoryEntryContext.IndexEntryContext indexEntryContext + ) throws IOException { Path indexPath = Paths.get(indexEntryContext.getKey()); FileWatcher fileWatcher = new FileWatcher(indexPath); fileWatcher.addListener(indexFileOnDeleteListener); @@ -97,6 +101,54 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(indexPath.toString()); long indexAddress = JNIService.loadIndex(indexPath.toString(), indexEntryContext.getParameters(), knnEngine); + return createIndexAllocation( + 
indexEntryContext, + knnEngine, + indexAddress, + fileWatcher, + indexEntryContext.calculateSizeInKB(), + indexPath + ); + } + + @Override + public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.IndexEntryContext indexEntryContext) + throws IOException { + final Path absoluteIndexPath = Paths.get(indexEntryContext.getKey()); + final KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(absoluteIndexPath.toString()); + if (knnEngine != KNNEngine.FAISS) { + // We will support other non-FAISS native engines (ex: NMSLIB) soon. + return loadWithAbsoluteIndexPath(indexEntryContext); + } + + final FileWatcher fileWatcher = new FileWatcher(absoluteIndexPath); + fileWatcher.addListener(indexFileOnDeleteListener); + fileWatcher.init(); + + final Directory directory = indexEntryContext.getDirectory(); + + // Ex: Input -> /a/b/c/_0_NativeEngines990KnnVectorsFormat_0.vec + // Output -> _0_NativeEngines990KnnVectorsFormat_0.vec + final String logicalIndexPath = absoluteIndexPath.getFileName().toString(); + + final int indexSizeKb = Math.toIntExact(directory.fileLength(logicalIndexPath) / 1024); + + try (IndexInput readStream = directory.openInput(logicalIndexPath, IOContext.READONCE)) { + IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(readStream); + long indexAddress = JNIService.loadIndex(indexInputWithBuffer, indexEntryContext.getParameters(), knnEngine); + + return createIndexAllocation(indexEntryContext, knnEngine, indexAddress, fileWatcher, indexSizeKb, absoluteIndexPath); + } + } + + private NativeMemoryAllocation.IndexAllocation createIndexAllocation( + final NativeMemoryEntryContext.IndexEntryContext indexEntryContext, + final KNNEngine knnEngine, + final long indexAddress, + final FileWatcher fileWatcher, + final int indexSizeKb, + final Path absoluteIndexPath + ) throws IOException { SharedIndexState sharedIndexState = null; String modelId = indexEntryContext.getModelId(); if (IndexUtil.isSharedIndexStateRequired(knnEngine, 
modelId, indexAddress)) { @@ -109,9 +161,9 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde return new NativeMemoryAllocation.IndexAllocation( executor, indexAddress, - indexEntryContext.calculateSizeInKB(), + indexSizeKb, knnEngine, - indexPath.toString(), + absoluteIndexPath.toString(), indexEntryContext.getOpenSearchIndexName(), watcherHandle, sharedIndexState, diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 1c31ed725..0fd2fddf7 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -286,6 +286,7 @@ private Map doANNSearch( try { indexAllocation = nativeMemoryCacheManager.get( new NativeMemoryEntryContext.IndexEntryContext( + reader.directory(), indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), getParametersAtLoading( diff --git a/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java b/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java new file mode 100644 index 000000000..273a4deac --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/store/IndexInputWithBuffer.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.store; + +import lombok.NonNull; +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +/** + * This class contains a Lucene IndexInput with a read buffer. + * A Java reference of this class will be passed to native engines, then 'copyBytes' method will be + * called by native engine via JNI API. + * Therefore, this class serves as a read layer in native engines to read the bytes it wants. + */ +public class IndexInputWithBuffer { + private IndexInput indexInput; + // 4K buffer.
+ private byte[] buffer = new byte[4 * 1024]; + + public IndexInputWithBuffer(@NonNull IndexInput indexInput) { + this.indexInput = indexInput; + } + + /** + * This method will be invoked in native engines via JNI API. + * Then it will call IndexInput to read required bytes then copy them into a read buffer. + * + * @param nbytes Desired number of bytes to be read. + * @return The number of bytes read into the buffer. + * @throws IOException if reading from the underlying IndexInput fails. + */ + private int copyBytes(long nbytes) throws IOException { + final int readBytes = (int) Math.min(nbytes, buffer.length); + indexInput.readBytes(buffer, 0, readBytes); + return readBytes; + } + + @Override + public String toString() { + return "{indexInput=" + indexInput + ", len(buffer)=" + buffer.length + "}"; + } +} diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 4bceed015..c56726c66 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -14,6 +14,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.store.IndexInputWithBuffer; import java.security.AccessController; import java.security.PrivilegedAction; @@ -217,6 +218,15 @@ public static native void createByteIndexFromTemplate( */ public static native long loadIndex(String indexPath); + /** + * Load an index into memory via a wrapper holding Lucene's IndexInput. + * Instead of directly accessing an index path, this will make Faiss delegate to IndexInput to load bytes. + * + * @param readStream IndexInput wrapper holding a Lucene IndexInput reference.
+ * @return pointer to location in memory the index resides in + */ + public static native long loadIndexWithStream(IndexInputWithBuffer readStream); + /** * Load a binary index into memory * @@ -225,6 +235,15 @@ public static native void createByteIndexFromTemplate( */ public static native long loadBinaryIndex(String indexPath); + /** + * Load a binary index into memory via a wrapper holding Lucene's IndexInput. + * Instead of directly accessing an index path, this will make Faiss delegate to IndexInput to load bytes. + * + * @param readStream IndexInput wrapper holding a Lucene IndexInput reference. + * @return pointer to location in memory the index resides in + */ + public static native long loadBinaryIndexWithStream(IndexInputWithBuffer readStream); + /** * Determine if index contains shared state. * diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 94c1ec48e..448241f9c 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -16,6 +16,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.KNNQueryResult; +import org.opensearch.knn.index.store.IndexInputWithBuffer; import org.opensearch.knn.index.util.IndexUtil; import java.util.Locale; @@ -211,6 +212,28 @@ public static long loadIndex(String indexPath, Map parameters, K ); } + /** + * Load an index via Lucene's IndexInput. + * + * @param readStream A wrapper holding Lucene's IndexInput to load bytes from a file.
+ * @param parameters Parameters to be used when loading index + * @param knnEngine Engine to load index + * @return Pointer to location in memory the index resides in + */ + public static long loadIndex(IndexInputWithBuffer readStream, Map parameters, KNNEngine knnEngine) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + return FaissService.loadBinaryIndexWithStream(readStream); + } else { + return FaissService.loadIndexWithStream(readStream); + } + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "LoadIndex not supported for provided engine : %s", knnEngine.getName()) + ); + } + /** * Determine if index contains shared state. Currently, we cannot do this in the plugin because we do not store the * model definition anywhere. Only faiss supports indices that have shared state. So for all other engines it will diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index cb5fbaeba..906ff4cb7 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -37,6 +37,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import static org.mockito.Mockito.doNothing; @@ -259,11 +260,12 @@ public void testIndexAllocation_closeDefault() { } public void testIndexAllocation_closeBlocking() throws InterruptedException, ExecutionException { + // Prepare mocking and a thread pool. 
WatcherHandle watcherHandle = (WatcherHandle) mock(WatcherHandle.class); - ExecutorService executorService = Executors.newFixedThreadPool(2); - AtomicReference expectedException = new AtomicReference<>(); + ExecutorService executorService = Executors.newSingleThreadExecutor(); - // Blocking close + // Enable `KNN_FORCE_EVICT_CACHE_ENABLED_SETTING` to force it to block other threads. + // Having it false will make `IndexAllocation` run its close logic in a different thread. when(clusterSettings.get(KNN_FORCE_EVICT_CACHE_ENABLED_SETTING)).thenReturn(true); NativeMemoryAllocation.IndexAllocation blockingIndexAllocation = new NativeMemoryAllocation.IndexAllocation( mock(ExecutorService.class), @@ -275,19 +277,21 @@ public void testIndexAllocation_closeBlocking() throws InterruptedException, Exe watcherHandle ); - executorService.submit(blockingIndexAllocation::readLock); + // Acquire a read lock + blockingIndexAllocation.readLock(); + + // This should be blocked as a read lock is still being held. Future closingThread = executorService.submit(blockingIndexAllocation::close); // Check if thread is currently blocked try { closingThread.get(5, TimeUnit.SECONDS); - } catch (Exception e) { - expectedException.set(e); - } - - assertNotNull(expectedException.get()); + fail("Closing should be blocked. We are still holding a read lock."); + } catch (TimeoutException ignored) {} - executorService.submit(blockingIndexAllocation::readUnlock); + // Now, we release the read lock. + blockingIndexAllocation.readUnlock(); + // As we no longer hold any lock, the closing thread can now acquire the write lock.
closingThread.get(); // Waits until close diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java index 1720da1ed..72cab9a1b 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryEntryContextTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index.memory; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.store.Directory; import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; @@ -44,6 +45,7 @@ public void testAbstract_getKey() { public void testIndexEntryContext_load() throws IOException { NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy = mock(NativeMemoryLoadStrategy.IndexLoadStrategy.class); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, "test", indexLoadStrategy, null, @@ -82,6 +84,7 @@ public void testIndexEntryContext_calculateSize() throws IOException { // Check that the indexEntryContext will return the same thing NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, tmpFile.toAbsolutePath().toString(), null, null, @@ -94,6 +97,7 @@ public void testIndexEntryContext_calculateSize() throws IOException { public void testIndexEntryContext_getOpenSearchIndexName() { String openSearchIndexName = "test-index"; NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, "test", null, null, @@ -106,6 +110,7 @@ public void testIndexEntryContext_getOpenSearchIndexName() { public void testIndexEntryContext_getParameters() { Map parameters = ImmutableMap.of("test-1", 10); 
NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + (Directory) null, "test", null, parameters, diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index bdd8d7e45..373afddc7 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -12,6 +12,8 @@ package org.opensearch.knn.index.memory; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.MMapDirectory; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.KNNTestCase; @@ -47,6 +49,7 @@ public class NativeMemoryLoadStrategyTests extends KNNTestCase { public void testIndexLoadStrategy_load() throws IOException { // Create basic nmslib HNSW index Path dir = createTempDir(); + Directory luceneDirectory = new MMapDirectory(dir); KNNEngine knnEngine = KNNEngine.NMSLIB; String indexName = "test1" + knnEngine.getExtension(); String path = dir.resolve(indexName).toAbsolutePath().toString(); @@ -68,6 +71,7 @@ public void testIndexLoadStrategy_load() throws IOException { NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + luceneDirectory, path, NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), parameters, @@ -87,6 +91,7 @@ public void testIndexLoadStrategy_load() throws IOException { public void testLoad_whenFaissBinary_thenSuccess() throws IOException { Path dir = createTempDir(); + Directory luceneDirectory = new MMapDirectory(dir); KNNEngine knnEngine = KNNEngine.FAISS; String indexName = "test1" + knnEngine.getExtension(); 
String path = dir.resolve(indexName).toAbsolutePath().toString(); @@ -116,6 +121,7 @@ public void testLoad_whenFaissBinary_thenSuccess() throws IOException { NativeMemoryLoadStrategy.IndexLoadStrategy.initialize(resourceWatcherService); NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + luceneDirectory, path, NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), parameters, diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index c78478f4d..8566b0223 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -14,6 +14,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MMapDirectory; import org.junit.BeforeClass; import org.opensearch.Version; import org.opensearch.common.xcontent.XContentFactory; @@ -29,6 +33,7 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.store.IndexInputWithBuffer; import java.io.IOException; import java.net.URL; @@ -871,6 +876,29 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); } + public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOException { + Path tmpFile = createTempFile(); + + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + 
ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ); + assertTrue(tmpFile.toFile().length() > 0); + + try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { + try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + long pointer = JNIService.loadIndex(new IndexInputWithBuffer(indexInput), Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, pointer); + + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); + } + } + } + public void testQueryIndex_faiss_valid() throws IOException { int k = 10; @@ -930,6 +958,68 @@ public void testQueryIndex_faiss_valid() throws IOException { } } + public void testQueryIndex_faiss_streaming_valid() throws IOException { + int k = 10; + int efSearch = 100; + + List methods = ImmutableList.of(faissMethod); + List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + for (String method : methods) { + for (SpaceType spaceType : spaces) { + Path tmpFile = createTempFile(); + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertTrue(tmpFile.toFile().length() > 0); + + try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { + try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + long pointer = JNIService.loadIndex( + new IndexInputWithBuffer(indexInput), + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + + for (float[] query : testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + 
Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + null + ); + assertEquals(k, results.length); + } + + // Filter will result in no ids + for (float[] query : testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + new long[] { 0 }, + 0, + null + ); + assertEquals(0, results.length); + } // End for + } // End try + } // End try + } // End for + } // End for + } + public void testQueryIndex_faiss_parentIds() throws IOException { int k = 100; @@ -978,6 +1068,58 @@ public void testQueryIndex_faiss_parentIds() throws IOException { } } + public void testQueryIndex_faiss_streaming_parentIds() throws IOException { + + int k = 100; + int efSearch = 100; + + List methods = ImmutableList.of(faissMethod); + List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + int[] parentIds = toParentIdArray(testDataNested.indexData.docs); + Map idToParentIdMap = toIdToParentIdMap(testDataNested.indexData.docs); + for (String method : methods) { + for (SpaceType spaceType : spaces) { + Path tmpFile = createTempFile(); + TestUtils.createIndex( + testDataNested.indexData.docs, + testData.loadDataToMemoryAddress(), + testDataNested.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertTrue(tmpFile.toFile().length() > 0); + + try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { + try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + long pointer = JNIService.loadIndex( + new IndexInputWithBuffer(indexInput), + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + + for (float[] query : testDataNested.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, 
+ Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + parentIds + ); + // Verify there is no more than one result from same parent + Set parentIdSet = toParentIdSet(results, idToParentIdMap); + assertEquals(results.length, parentIdSet.size()); + } // End for + } // End try + } // End try + } // End for + } // End for + } + @SneakyThrows public void testQueryBinaryIndex_faiss_valid() { int k = 10; @@ -1016,6 +1158,53 @@ public void testQueryBinaryIndex_faiss_valid() { } } + @SneakyThrows + public void testQueryBinaryIndex_faiss_streaming_valid() { + int k = 10; + List methods = ImmutableList.of(faissBinaryMethod); + for (String method : methods) { + Path tmpFile = createTempFile(); + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + TestUtils.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + tmpFile.toAbsolutePath().toString(), + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + method, + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING.getValue(), + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ), + KNNEngine.FAISS + ); + assertTrue(tmpFile.toFile().length() > 0); + + try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { + try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + long pointer = JNIService.loadIndex( + new IndexInputWithBuffer(indexInput), + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + method, + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + + for (byte[] query : testData.binaryQueries) { + KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, null, KNNEngine.FAISS, null, 0, null); + assertEquals(k, results.length); + } // End for + } // End try + } // End try + } // End for + } + private Set toParentIdSet(KNNQueryResult[] results, Map idToParentIdMap) { return 
Arrays.stream(results).map(result -> idToParentIdMap.get(result.getId())).collect(Collectors.toSet()); }