Skip to content

Commit

Permalink
Introducing a loading layer in FAISS native engine. (#2139)
Browse files Browse the repository at this point in the history
* Introducing a loading layer in FAISS native engine.

Signed-off-by: Dooyong Kim <[email protected]>

* Update change log.

Signed-off-by: Dooyong Kim <[email protected]>

* Added unit tests for Faiss stream support.

Signed-off-by: Dooyong Kim <[email protected]>

* Fix a bug to pass a KB size integer value as a byte size integer parameter.

Signed-off-by: Dooyong Kim <[email protected]>

* Fix a casting bugs when it tries to laod more than 4G sized index file.

Signed-off-by: Dooyong Kim <[email protected]>

* Added unit tests for new methods in JNIService.

Signed-off-by: Dooyong Kim <[email protected]>

* Fix formatting and removed nmslib_stream_support.

Signed-off-by: Dooyong Kim <[email protected]>

* Removing redundant exception message in JNIService.loadIndex.

Signed-off-by: Dooyong Kim <[email protected]>

* Fix a flaky testing - testIndexAllocation_closeBlocking

Signed-off-by: Dooyong Kim <[email protected]>

---------

Signed-off-by: Dooyong Kim <[email protected]>
Signed-off-by: Doo Yong Kim <[email protected]>
Co-authored-by: Dooyong Kim <[email protected]>
  • Loading branch information
0ctopus13prime and Dooyong Kim authored Oct 3, 2024
1 parent 07f4df2 commit f3b2bd0
Show file tree
Hide file tree
Showing 24 changed files with 864 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.0](https://github.com/opensearch-project/k-NN/compare/2.x...HEAD)
### Features
### Enhancements
* Introducing a loading layer in FAISS [#2033](https://github.com/opensearch-project/k-NN/issues/2033)
### Bug Fixes
* Add DocValuesProducers for releasing memory when close index [#1946](https://github.com/opensearch-project/k-NN/pull/1946)
### Infrastructure
Expand Down
1 change: 1 addition & 0 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ if ("${WIN32}" STREQUAL "")
tests/nmslib_wrapper_unit_test.cpp
tests/test_util.cpp
tests/commons_test.cpp
tests/faiss_stream_support_test.cpp
tests/faiss_index_service_test.cpp
)

Expand Down
136 changes: 136 additions & 0 deletions jni/include/faiss_stream_support.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

#ifndef OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
#define OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H

#include "faiss/impl/io.h"
#include "jni_util.h"

#include <jni.h>
#include <stdexcept>
#include <iostream>
#include <cstring>

namespace knn_jni {
namespace stream {

/**
* This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer.
*/

class NativeEngineIndexInputMediator {
public:
// Expect IndexInputWithBuffer is given as `_indexInput`.
NativeEngineIndexInputMediator(JNIUtilInterface *_jni_interface,
JNIEnv *_env,
jobject _indexInput)
: jni_interface(_jni_interface),
env(_env),
indexInput(_indexInput),
bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env,
_indexInput,
getBufferFieldId(_jni_interface, _env)))),
copyBytesMethod(getCopyBytesMethod(_jni_interface, _env)) {
}

void copyBytes(int64_t nbytes, uint8_t *destination) {
while (nbytes > 0) {
// Call `copyBytes` to read bytes as many as possible.
const auto readBytes =
jni_interface->CallIntMethodLong(env, indexInput, copyBytesMethod, nbytes);

// === Critical Section Start ===

// Get primitive array pointer, no copy is happening in OpenJDK.
auto primitiveArray =
(jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr);

// Copy Java bytes to C++ destination address.
std::memcpy(destination, primitiveArray, readBytes);

// Release the acquired primitive array pointer.
// JNI_ABORT tells JVM to directly free memory without copying back to Java byte[].
// Since we're merely copying data, we don't need to copying back.
jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_ABORT);

// === Critical Section End ===

destination += readBytes;
nbytes -= readBytes;
} // End while
}

private:
static jclass getIndexInputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jclass INDEX_INPUT_WITH_BUFFER_CLASS =
jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexInputWithBuffer");
return INDEX_INPUT_WITH_BUFFER_CLASS;
}

static jmethodID getCopyBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jmethodID COPY_METHOD_ID =
jni_interface->GetMethodID(env, getIndexInputWithBufferClass(jni_interface, env), "copyBytes", "(J)I");
return COPY_METHOD_ID;
}

static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) {
static jfieldID BUFFER_FIELD_ID =
jni_interface->GetFieldID(env, getIndexInputWithBufferClass(jni_interface, env), "buffer", "[B");
return BUFFER_FIELD_ID;
}

JNIUtilInterface *jni_interface;
JNIEnv *env;

// `IndexInputWithBuffer` instance having `IndexInput` instance obtained from `Directory` for reading.
jobject indexInput;
jbyteArray bufferArray;
jmethodID copyBytesMethod;
}; // class NativeEngineIndexInputMediator



/**
* A glue component inheriting IOReader to be passed down to Faiss library.
* This will then indirectly call the mediator component and eventually read required bytes from Lucene's IndexInput.
*/
class FaissOpenSearchIOReader final : public faiss::IOReader {
public:
explicit FaissOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator)
: faiss::IOReader(),
mediator(_mediator) {
name = "FaissOpenSearchIOReader";
}

size_t operator()(void *ptr, size_t size, size_t nitems) final {
const auto readBytes = size * nitems;
if (readBytes > 0) {
// Mediator calls IndexInput, then copy read bytes to `ptr`.
mediator->copyBytes(readBytes, (uint8_t *) ptr);
}
return nitems;
}

int filedescriptor() final {
throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOReader.");
}

private:
NativeEngineIndexInputMediator *mediator;
}; // class FaissOpenSearchIOReader



}
}

#endif //OPENSEARCH_KNN_JNI_STREAM_SUPPORT_H
10 changes: 10 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,21 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Loads an index with a reader implemented IOReader
//
// Returns a pointer of the loaded index
jlong LoadIndexWithStream(faiss::IOReader* ioReader);

// Load a binary index from indexPathJ into memory.
//
// Return a pointer to the loaded index
jlong LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Loads a binary index with a reader implemented IOReader
//
// Returns a pointer of the loaded index
jlong LoadBinaryIndexWithStream(faiss::IOReader* ioReader);

// Check if a loaded index requires shared state
bool IsSharedIndexStateRequired(jlong indexPointerJ);

Expand Down
102 changes: 61 additions & 41 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
namespace knn_jni {

// Interface for making calls to JNI
class JNIUtilInterface {
public:
struct JNIUtilInterface {
// -------------------------- EXCEPTION HANDLING ----------------------------
// Takes the name of a Java exception type and a message and throws the corresponding exception
// to the JVM
Expand Down Expand Up @@ -127,56 +126,77 @@ namespace knn_jni {

virtual void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) = 0;

virtual jobject GetObjectField(JNIEnv * env, jobject obj, jfieldID fieldID) = 0;

virtual jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) = 0;

virtual jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) = 0;

virtual jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) = 0;

virtual void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) = 0;

virtual void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) = 0;

virtual jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) = 0;

// --------------------------------------------------------------------------
};

jobject GetJObjectFromMapOrThrow(std::unordered_map<std::string, jobject> map, std::string key);

// Class that implements JNIUtilInterface methods
class JNIUtil: public JNIUtilInterface {
class JNIUtil final : public JNIUtilInterface {
public:
// Initialize and Uninitialize methods are used for caching/cleaning up Java classes and methods
void Initialize(JNIEnv* env);
void Uninitialize(JNIEnv* env);

void ThrowJavaException(JNIEnv* env, const char* type = "", const char* message = "");
void HasExceptionInStack(JNIEnv* env);
void HasExceptionInStack(JNIEnv* env, const std::string& message);
void CatchCppExceptionAndThrowJava(JNIEnv* env);
jclass FindClass(JNIEnv * env, const std::string& className);
jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName);
std::string ConvertJavaStringToCppString(JNIEnv * env, jstring javaString);
std::unordered_map<std::string, jobject> ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ);
std::string ConvertJavaObjectToCppString(JNIEnv *env, jobject objectJ);
int ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ);
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim);
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ);
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ);
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ);
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ);
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ);
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ);
int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ);
int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ);

void DeleteLocalRef(JNIEnv *env, jobject obj);
jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy);
jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy);
jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy);
jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy);
jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index);
jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance);
jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init);
jbyteArray NewByteArray(JNIEnv *env, jsize len);
void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode);
void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode);
void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode);
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode);
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val);
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf);
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect);
void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect);
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<int8_t> *vect);
void ThrowJavaException(JNIEnv* env, const char* type = "", const char* message = "") final;
void HasExceptionInStack(JNIEnv* env) final;
void HasExceptionInStack(JNIEnv* env, const std::string& message) final;
void CatchCppExceptionAndThrowJava(JNIEnv* env) final;
jclass FindClass(JNIEnv * env, const std::string& className) final;
jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName) final;
std::string ConvertJavaStringToCppString(JNIEnv * env, jstring javaString) final;
std::unordered_map<std::string, jobject> ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ) final;
std::string ConvertJavaObjectToCppString(JNIEnv *env, jobject objectJ) final;
int ConvertJavaObjectToCppInteger(JNIEnv *env, jobject objectJ) final;
std::vector<float> Convert2dJavaObjectArrayToCppFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim) final;
std::vector<int64_t> ConvertJavaIntArrayToCppIntVector(JNIEnv *env, jintArray arrayJ) final;
int GetInnerDimensionOf2dJavaFloatArray(JNIEnv *env, jobjectArray array2dJ) final;
int GetInnerDimensionOf2dJavaByteArray(JNIEnv *env, jobjectArray array2dJ) final;
int GetJavaObjectArrayLength(JNIEnv *env, jobjectArray arrayJ) final;
int GetJavaIntArrayLength(JNIEnv *env, jintArray arrayJ) final;
int GetJavaLongArrayLength(JNIEnv *env, jlongArray arrayJ) final;
int GetJavaBytesArrayLength(JNIEnv *env, jbyteArray arrayJ) final;
int GetJavaFloatArrayLength(JNIEnv *env, jfloatArray arrayJ) final;

void DeleteLocalRef(JNIEnv *env, jobject obj) final;
jbyte * GetByteArrayElements(JNIEnv *env, jbyteArray array, jboolean * isCopy) final;
jfloat * GetFloatArrayElements(JNIEnv *env, jfloatArray array, jboolean * isCopy) final;
jint * GetIntArrayElements(JNIEnv *env, jintArray array, jboolean * isCopy) final;
jlong * GetLongArrayElements(JNIEnv *env, jlongArray array, jboolean * isCopy) final;
jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index) final;
jobject NewObject(JNIEnv *env, jclass clazz, jmethodID methodId, int id, float distance) final;
jobjectArray NewObjectArray(JNIEnv *env, jsize len, jclass clazz, jobject init) final;
jbyteArray NewByteArray(JNIEnv *env, jsize len) final;
void ReleaseByteArrayElements(JNIEnv *env, jbyteArray array, jbyte *elems, int mode) final;
void ReleaseFloatArrayElements(JNIEnv *env, jfloatArray array, jfloat *elems, int mode) final;
void ReleaseIntArrayElements(JNIEnv *env, jintArray array, jint *elems, jint mode) final;
void ReleaseLongArrayElements(JNIEnv *env, jlongArray array, jlong *elems, jint mode) final;
void SetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index, jobject val) final;
void SetByteArrayRegion(JNIEnv *env, jbyteArray array, jsize start, jsize len, const jbyte * buf) final;
void Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<float> *vect) final;
void Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<uint8_t> *vect) final;
void Convert2dJavaObjectArrayAndStoreToByteVector(JNIEnv *env, jobjectArray array2dJ, int dim, std::vector<int8_t> *vect) final;
jobject GetObjectField(JNIEnv * env, jobject obj, jfieldID fieldID) final;
jclass FindClassFromJNIEnv(JNIEnv * env, const char *name) final;
jmethodID GetMethodID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final;
jint CallIntMethodLong(JNIEnv * env, jobject obj, jmethodID methodID, int64_t longArg) final;
void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final;
void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final;

private:
std::unordered_map<std::string, jclass> cachedClasses;
Expand Down
16 changes: 16 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadIndexWithStream
* Signature: (Lorg/opensearch/knn/index/util/IndexInputWithBuffer;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream
(JNIEnv *, jclass, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndex
Expand All @@ -136,6 +144,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: loadBinaryIndexWithStream
* Signature: (Lorg/opensearch/knn/index/util/IndexInputWithBuffer;)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream
(JNIEnv *, jclass, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: isSharedIndexStateRequired
Expand Down
28 changes: 28 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,20 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI
return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadIndexWithStream(faiss::IOReader* ioReader) {
if (ioReader == nullptr) [[unlikely]] {
throw std::runtime_error("IOReader cannot be null");
}

faiss::Index* indexReader =
faiss::read_index(ioReader,
faiss::IO_FLAG_READ_ONLY
| faiss::IO_FLAG_PQ_SKIP_SDC_TABLE
| faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE);

return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) {
if (indexPathJ == nullptr) {
throw std::runtime_error("Index path cannot be null");
Expand All @@ -436,6 +450,20 @@ jlong knn_jni::faiss_wrapper::LoadBinaryIndex(knn_jni::JNIUtilInterface * jniUti
return (jlong) indexReader;
}

jlong knn_jni::faiss_wrapper::LoadBinaryIndexWithStream(faiss::IOReader* ioReader) {
if (ioReader == nullptr) [[unlikely]] {
throw std::runtime_error("IOReader cannot be null");
}

faiss::IndexBinary* indexReader =
faiss::read_index_binary(ioReader,
faiss::IO_FLAG_READ_ONLY
| faiss::IO_FLAG_PQ_SKIP_SDC_TABLE
| faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE);

return (jlong) indexReader;
}

bool knn_jni::faiss_wrapper::IsSharedIndexStateRequired(jlong indexPointerJ) {
auto * index = reinterpret_cast<faiss::Index*>(indexPointerJ);
return isIndexIVFPQL2(index);
Expand Down
Loading

0 comments on commit f3b2bd0

Please sign in to comment.