Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP demo for using KNN in Kotlin based on https://github.com/realm/realm-core/pull/6759 #1752

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion buildSrc/src/main/kotlin/Config.kt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ val HOST_OS: OperatingSystem = findHostOs()

object Realm {
val ciBuild = (System.getenv("CI") != null)
const val version = "1.15.0"
const val version = "1.15.0-KNN"
const val group = "io.realm.kotlin"
const val projectUrl = "https://realm.io"
const val pluginPortalId = "io.realm.kotlin"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ expect object RealmInterop {
): RealmQueryPointer
fun realm_query_find_first(query: RealmQueryPointer): Link?
fun realm_query_find_all(query: RealmQueryPointer): RealmResultsPointer
fun realm_query_find_knn(query: RealmQueryPointer, property: String, queryVector: Array<Float>, numberOfNeighbours: Int): RealmResultsPointer
fun realm_query_count(query: RealmQueryPointer): Long
fun realm_query_append_query(
query: RealmQueryPointer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1777,6 +1777,10 @@ actual object RealmInterop {
return LongPointerWrapper(realmc.realm_query_find_all(query.cptr()))
}

actual fun realm_query_find_knn(query: RealmQueryPointer, property: String, queryVector: Array<Float>, numberOfNeighbours: Int): RealmResultsPointer {
return LongPointerWrapper(realmc.realm_knnsearch(query.cptr(), property, queryVector.toFloatArray(), numberOfNeighbours))
}

actual fun realm_query_count(query: RealmQueryPointer): Long {
val count = LongArray(1)
realmc.realm_query_count(query.cptr(), count)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,10 @@ actual object RealmInterop {
return CPointerWrapper(realm_wrapper.realm_query_find_all(query.cptr()))
}

actual fun realm_query_find_knn(query: RealmQueryPointer, property: String, queryVector: Array<Float>, numberOfNeighbours: Int): RealmResultsPointer {
TODO("realm_query_find_knn not yet implemented for darwin")
}

actual fun realm_query_count(query: RealmQueryPointer): Long {
memScoped {
val count = alloc<ULongVar>()
Expand Down
19 changes: 19 additions & 0 deletions packages/jni-swig-stub/src/main/jni/realm_api_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,25 @@ void realm_sync_websocket_closed(int64_t observer_ptr, bool was_clean, int error
realm_sync_socket_websocket_closed(reinterpret_cast<realm_websocket_observer_t*>(observer_ptr), was_clean, static_cast<realm_web_socket_errno_e>(error_code), reason);
}

realm_results_t* realm_knnsearch(const realm_query_t* existing_query, const char* property, jfloatArray floats, int numberOfNeighbours) {
auto table = existing_query->query.get_table();
auto jenv = get_env(false);

jsize len = jenv->GetArrayLength( floats);
jfloat *vec = jenv->GetFloatArrayElements(floats, 0);
std::vector<float> query_vector;
query_vector.reserve(len);
for (size_t i = 0; i < len; ++i) {
query_vector.push_back(vec[i]);
}
jenv->ReleaseFloatArrayElements(floats, vec, 0);

auto col_lst = table->get_column_key(property);

Results results(existing_query->weak_realm.lock(), table->where());
return new realm_results{results.knn_search(col_lst, query_vector, numberOfNeighbours)};
}

realm_sync_socket_t* realm_sync_websocket_new(int64_t sync_client_config_ptr, jobject websocket_transport) {
auto jenv = get_env(false); // Always called from JVM
realm_sync_socket_t* socket_provider = realm_sync_socket_new(jenv->NewGlobalRef(websocket_transport), /*userdata*/
Expand Down
1 change: 1 addition & 0 deletions packages/jni-swig-stub/src/main/jni/realm_api_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,5 @@ bool realm_sync_websocket_message(int64_t observer_ptr, jbyteArray data, size_t

void realm_sync_websocket_closed(int64_t observer_ptr, bool was_clean, int error_code, const char* reason);

realm_results_t* realm_knnsearch(const realm_query_t* existing_query, const char* property, jfloatArray floats, int numberOfNeighbours);
#endif //TEST_REALM_API_HELPERS_H
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ internal class ObjectQuery<E : BaseRealmObject> constructor(
return query(stringBuilder.toString())
}

override fun knn(property: String, queryVector: Array<Float>, numberOfNeighbours: Int): RealmResults<E> {
return RealmResultsImpl(realmReference, RealmInterop.realm_query_find_knn(queryPointer, property, queryVector, numberOfNeighbours), classKey, clazz, mediator)
}

override fun distinct(property: String, vararg extraProperties: String): RealmQuery<E> {
val stringBuilder = StringBuilder().append("TRUEPREDICATE DISTINCT($property")
extraProperties.forEach { extraProperty ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public interface RealmQuery<T : BaseRealmObject> : RealmElementQuery<T> {
*/
public fun sort(property: String, sortOrder: Sort = Sort.ASCENDING): RealmQuery<T>

public fun knn(property: String, queryVector: Array<Float>, numberOfNeighbours: Int): RealmResults<T>

/**
* Sorts the query result by the specific property name according to [Pair]s of properties and
* sorting order.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3079,6 +3079,66 @@ class QueryTests {
channel.close()
}

// ----------------
// KNN search
// ----------------
@Test
fun knn_search() {
val realm = Realm.open(RealmConfiguration.create(setOf(VectorEmbeddingSample::class)))
realm.writeBlocking {
copyToRealm(VectorEmbeddingSample().apply {
stringField = "vector 1"
embedding.addAll(
listOf(
0.003f, 0.004f, 0.005f, 0.100f, 0.010f
)
)
})

copyToRealm(VectorEmbeddingSample().apply {
stringField = "vector 2"
embedding.addAll(
listOf(
0.001f, 0.004f, 0.005f, 0.100f, 0.010f
)
)
})

copyToRealm(VectorEmbeddingSample().apply {
stringField = "vector 3"
embedding.addAll(
listOf(
0.001f, 0.004f, 0.005f, 0.100f, 0.010f
)
)
})

copyToRealm(VectorEmbeddingSample().apply {
stringField = "vector 4"
embedding.addAll(
listOf(
0.004f, 0.005f, 0.010f, 0.025f, 0.100f
)
)
})

copyToRealm(VectorEmbeddingSample().apply {
stringField = "vector 5"
embedding.addAll(
listOf(
0.003f, 0.007f, 0.008f, 0.020f, 0.100f
)
)
})
}
val knn: RealmResults<VectorEmbeddingSample> = realm.query<VectorEmbeddingSample>()
.knn("embedding", arrayOf(0.003f, 0.005f, 0.010f, 0.020f, 0.100f), 2)
assertEquals(2, knn.size)
assertEquals("vector 4", knn[0].stringField)
assertEquals("vector 5", knn[1].stringField)
}


// ----------------
// Coercion helpers
// ----------------
Expand Down Expand Up @@ -3297,6 +3357,10 @@ private data class PropertyDescriptor constructor(
val values: List<Any?>
)

class VectorEmbeddingSample : RealmObject {
var stringField: String = ""
var embedding: RealmList<Float> = realmListOf()
}
/**
* Use this and not [io.realm.kotlin.entities.Sample] as that class has default initializers that make
* aggregating operations harder to assert.
Expand Down