diff --git a/buildSrc/src/main/kotlin/Config.kt b/buildSrc/src/main/kotlin/Config.kt index be311376b0..05cf3a0767 100644 --- a/buildSrc/src/main/kotlin/Config.kt +++ b/buildSrc/src/main/kotlin/Config.kt @@ -62,7 +62,7 @@ val HOST_OS: OperatingSystem = findHostOs() object Realm { val ciBuild = (System.getenv("CI") != null) - const val version = "1.15.0" + const val version = "1.15.0-KNN" const val group = "io.realm.kotlin" const val projectUrl = "https://realm.io" const val pluginPortalId = "io.realm.kotlin" diff --git a/packages/cinterop/src/commonMain/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt b/packages/cinterop/src/commonMain/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt index 1a42b6faac..3f8281fac9 100644 --- a/packages/cinterop/src/commonMain/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt +++ b/packages/cinterop/src/commonMain/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt @@ -407,6 +407,7 @@ expect object RealmInterop { ): RealmQueryPointer fun realm_query_find_first(query: RealmQueryPointer): Link? fun realm_query_find_all(query: RealmQueryPointer): RealmResultsPointer + fun realm_query_find_knn(query: RealmQueryPointer, property: String, queryVector: Array, numberOfNeighbours: Int): RealmResultsPointer fun realm_query_count(query: RealmQueryPointer): Long fun realm_query_append_query( query: RealmQueryPointer, diff --git a/packages/cinterop/src/jvm/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt b/packages/cinterop/src/jvm/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt index 7797adcf14..ed68272384 100644 --- a/packages/cinterop/src/jvm/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt +++ b/packages/cinterop/src/jvm/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt @@ -1777,6 +1777,10 @@ actual object RealmInterop { return LongPointerWrapper(realmc.realm_query_find_all(query.cptr())) } + actual fun realm_query_find_knn(query: RealmQueryPointer, property: String, queryVector: Array, numberOfNeighbours: Int): RealmResultsPointer { + return LongPointerWrapper(realmc.realm_knnsearch(query.cptr(), property, queryVector.toFloatArray(), numberOfNeighbours)) + } + actual fun realm_query_count(query: RealmQueryPointer): Long { val count = LongArray(1) realmc.realm_query_count(query.cptr(), count) diff --git a/packages/cinterop/src/nativeDarwin/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt b/packages/cinterop/src/nativeDarwin/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt index 10b979dafa..d7e15603b2 100644 --- a/packages/cinterop/src/nativeDarwin/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt +++ b/packages/cinterop/src/nativeDarwin/kotlin/io/realm/kotlin/internal/interop/RealmInterop.kt @@ -1499,6 +1499,10 @@ actual object RealmInterop { return CPointerWrapper(realm_wrapper.realm_query_find_all(query.cptr())) } + actual fun realm_query_find_knn(query: RealmQueryPointer, property: String, queryVector: Array, numberOfNeighbours: Int): RealmResultsPointer { + TODO("realm_query_find_knn not yet implemented for darwin") + } + actual fun realm_query_count(query: RealmQueryPointer): Long { memScoped { val count = alloc() diff --git a/packages/external/core b/packages/external/core index 316889b967..5d5bb6c5d0 160000 --- a/packages/external/core +++ b/packages/external/core @@ -1 +1 @@ -Subproject commit 316889b967f845fbc10b4422f96c7eadd47136f2 +Subproject commit 5d5bb6c5d0698dac3324dfb72d65fda12a4c3bb6 diff --git a/packages/jni-swig-stub/src/main/jni/realm_api_helpers.cpp b/packages/jni-swig-stub/src/main/jni/realm_api_helpers.cpp index df18021fc7..a3009b106c 100644 --- a/packages/jni-swig-stub/src/main/jni/realm_api_helpers.cpp +++ b/packages/jni-swig-stub/src/main/jni/realm_api_helpers.cpp @@ -937,6 +937,25 @@ void realm_sync_websocket_closed(int64_t observer_ptr, bool was_clean, int error realm_sync_socket_websocket_closed(reinterpret_cast(observer_ptr), was_clean, static_cast(error_code), reason); } +realm_results_t* realm_knnsearch(const realm_query_t* existing_query, const char* property, jfloatArray floats, int numberOfNeighbours) { + auto table = existing_query->query.get_table(); + auto jenv = get_env(false); + + jsize len = jenv->GetArrayLength( floats); + jfloat *vec = jenv->GetFloatArrayElements(floats, 0); + std::vector query_vector; + query_vector.reserve(len); + for (size_t i = 0; i < len; ++i) { + query_vector.push_back(vec[i]); + } + jenv->ReleaseFloatArrayElements(floats, vec, 0); + + auto col_lst = table->get_column_key(property); + + Results results(existing_query->weak_realm.lock(), table->where()); + return new realm_results{results.knn_search(col_lst, query_vector, numberOfNeighbours)}; +} + realm_sync_socket_t* realm_sync_websocket_new(int64_t sync_client_config_ptr, jobject websocket_transport) { auto jenv = get_env(false); // Always called from JVM realm_sync_socket_t* socket_provider = realm_sync_socket_new(jenv->NewGlobalRef(websocket_transport), /*userdata*/ diff --git a/packages/jni-swig-stub/src/main/jni/realm_api_helpers.h b/packages/jni-swig-stub/src/main/jni/realm_api_helpers.h index 2fae610c7b..3fc9b918a5 100644 --- a/packages/jni-swig-stub/src/main/jni/realm_api_helpers.h +++ b/packages/jni-swig-stub/src/main/jni/realm_api_helpers.h @@ -161,4 +161,5 @@ bool realm_sync_websocket_message(int64_t observer_ptr, jbyteArray data, size_t void realm_sync_websocket_closed(int64_t observer_ptr, bool was_clean, int error_code, const char* reason); +realm_results_t* realm_knnsearch(const realm_query_t* existing_query, const char* property, jfloatArray floats, int numberOfNeighbours); #endif //TEST_REALM_API_HELPERS_H diff --git a/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/internal/query/ObjectQuery.kt b/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/internal/query/ObjectQuery.kt index 1e66439e55..7dedbd1a1d 100644 --- a/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/internal/query/ObjectQuery.kt +++ b/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/internal/query/ObjectQuery.kt @@ -115,6 +115,10 @@ internal class ObjectQuery constructor( return query(stringBuilder.toString()) } + override fun knn(property: String, queryVector: Array, numberOfNeighbours: Int): RealmResults { + return RealmResultsImpl(realmReference, RealmInterop.realm_query_find_knn(queryPointer, property, queryVector, numberOfNeighbours), classKey, clazz, mediator) + } + override fun distinct(property: String, vararg extraProperties: String): RealmQuery { val stringBuilder = StringBuilder().append("TRUEPREDICATE DISTINCT($property") extraProperties.forEach { extraProperty -> diff --git a/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/query/RealmQuery.kt b/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/query/RealmQuery.kt index 69ba166b34..4320caed26 100644 --- a/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/query/RealmQuery.kt +++ b/packages/library-base/src/commonMain/kotlin/io/realm/kotlin/query/RealmQuery.kt @@ -67,6 +67,8 @@ public interface RealmQuery : RealmElementQuery { */ public fun sort(property: String, sortOrder: Sort = Sort.ASCENDING): RealmQuery + public fun knn(property: String, queryVector: Array, numberOfNeighbours: Int): RealmResults + /** * Sorts the query result by the specific property name according to [Pair]s of properties and * sorting order. diff --git a/packages/test-base/src/commonTest/kotlin/io/realm/kotlin/test/common/QueryTests.kt b/packages/test-base/src/commonTest/kotlin/io/realm/kotlin/test/common/QueryTests.kt index 0b6eb5dc3c..85ec69ee88 100644 --- a/packages/test-base/src/commonTest/kotlin/io/realm/kotlin/test/common/QueryTests.kt +++ b/packages/test-base/src/commonTest/kotlin/io/realm/kotlin/test/common/QueryTests.kt @@ -3079,6 +3079,66 @@ class QueryTests { channel.close() } + // ---------------- + // KNN search + // ---------------- + @Test + fun knn_search() { + val realm = Realm.open(RealmConfiguration.create(setOf(VectorEmbeddingSample::class))) + realm.writeBlocking { + copyToRealm(VectorEmbeddingSample().apply { + stringField = "vector 1" + embedding.addAll( + listOf( + 0.003f, 0.004f, 0.005f, 0.100f, 0.010f + ) + ) + }) + + copyToRealm(VectorEmbeddingSample().apply { + stringField = "vector 2" + embedding.addAll( + listOf( + 0.001f, 0.004f, 0.005f, 0.100f, 0.010f + ) + ) + }) + + copyToRealm(VectorEmbeddingSample().apply { + stringField = "vector 3" + embedding.addAll( + listOf( + 0.001f, 0.004f, 0.005f, 0.100f, 0.010f + ) + ) + }) + + copyToRealm(VectorEmbeddingSample().apply { + stringField = "vector 4" + embedding.addAll( + listOf( + 0.004f, 0.005f, 0.010f, 0.025f, 0.100f + ) + ) + }) + + copyToRealm(VectorEmbeddingSample().apply { + stringField = "vector 5" + embedding.addAll( + listOf( + 0.003f, 0.007f, 0.008f, 0.020f, 0.100f + ) + ) + }) + } + val knn: RealmResults = realm.query() + .knn("embedding", arrayOf(0.003f, 0.005f, 0.010f, 0.020f, 0.100f), 2) + assertEquals(2, knn.size) + assertEquals("vector 4", knn[0].stringField) + assertEquals("vector 5", knn[1].stringField) + } + + // ---------------- // Coercion helpers // ---------------- @@ -3297,6 +3357,10 @@ private data class PropertyDescriptor constructor( val values: List ) +class VectorEmbeddingSample : RealmObject { + var stringField: String = "" + var embedding: RealmList = realmListOf() +} /** * Use this and not [io.realm.kotlin.entities.Sample] as that class has default initializers that make * aggregating operations harder to assert.