From cc23f28cba911ce0f03f43911e543856fc01b309 Mon Sep 17 00:00:00 2001 From: Ralph Gasser Date: Mon, 6 May 2024 16:38:34 +0200 Subject: [PATCH] Minor optimisations. --- .../graph/AbstractDynamicExplorationGraph.kt | 143 ++++++++++-------- .../graph/InMemoryDynamicExplorationGraph.kt | 6 +- .../utilities/graph/memory/InMemoryGraph.kt | 6 +- 3 files changed, 85 insertions(+), 70 deletions(-) diff --git a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/AbstractDynamicExplorationGraph.kt b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/AbstractDynamicExplorationGraph.kt index b9aaadd3a..fe159252d 100644 --- a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/AbstractDynamicExplorationGraph.kt +++ b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/AbstractDynamicExplorationGraph.kt @@ -1,13 +1,11 @@ package org.vitrivr.cottontail.dbms.index.diskann.graph import it.unimi.dsi.fastutil.objects.Object2FloatLinkedOpenHashMap -import org.apache.lucene.search.Weight +import it.unimi.dsi.fastutil.objects.ObjectArraySet import org.vitrivr.cottontail.core.database.TupleId import org.vitrivr.cottontail.core.types.VectorValue import org.vitrivr.cottontail.utilities.graph.Graph import java.lang.Math.floorDiv -import java.util.* -import kotlin.collections.HashMap import kotlin.collections.HashSet import kotlin.math.max @@ -20,7 +18,7 @@ import kotlin.math.max * @author Ralph Gasser * @version 1.0.0 */ -abstract class AbstractDynamicExplorationGraph(private val degree: Int, val graph: Graph.Node>) { +abstract class AbstractDynamicExplorationGraph,V>(private val degree: Int, val graph: Graph.Node>, private val epsilonExt: Float = 0.3f, private val kExt: Int = 60) { init { @@ -33,60 +31,45 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val * @param identifier The identifier [I] of the entry to index. * @param vector The vector [V] of the entry to index. */ - fun index(identifier: I, vector: V, epsilon: Double) { + fun index(identifier: I, vector: V) { val count = this.size() /* Create new (empty) node and store vector. */ val newNode = Node(identifier) this.storeVector(identifier, vector) + this.graph.addVertex(newNode) - if (count <= this.degree) - { /* Case 1: Graph does not satisfy regularity condition since it is too small: Create new node and make all existing nodes connect to */ - this.graph.addVertex(newNode) + if (count <= this.degree) { /* Case 1: Graph does not satisfy regularity condition since it is too small: Create new node and make all existing nodes connect to */ for (node in this.graph) { - val distance = this.distance(vector, node.vector).toFloat() - if (node != newNode) { - this.graph.addEdge(newNode, node, distance) - this.graph.addEdge(node, newNode, distance) - } + if (node == newNode) continue + val distance = this.distance(vector, node.vector) + this.graph.addEdge(newNode, node, distance) + this.graph.addEdge(node, newNode, distance) } - } else { /* Case 2: Graph is not regular. */ - val search = this.search(vector, this.degree, epsilon) - val connect = HashMap() + } else { /* Case 2: Graph is regular. */ + val results = this.search(vector, this.kExt, this.epsilonExt) var skipRng = false /* Start insert procedure. */ - while (connect.size < this.degree) { - val nodesToExplore = search.entries.filter { !connect.contains(it.key) }.associate { it.key to it.value }.toMutableMap() - while (connect.size < this.degree && nodesToExplore.isNotEmpty()) { - var closestNode = nodesToExplore.keys.first() - var closestDistance = Double.MAX_VALUE - for ((node, _) in nodesToExplore.entries) { - val distance = this.distance(vector, node.vector) - if (distance < closestDistance) { - closestDistance = distance - closestNode = node - } - } - nodesToExplore.remove(closestNode) - - /* Identify the best vertex to connect to existing vertex. */ - if (skipRng || checkMrng(newNode, connect, closestNode)) { - val farthestNodeFromClosest = this.graph.edges(closestNode).filter { !connect.contains(it.key) }.maxBy { it.value }.key - connect[closestNode] = this.distance(closestNode.vector, newNode.vector).toFloat() - connect[farthestNodeFromClosest] = this.distance(farthestNodeFromClosest.vector, newNode.vector).toFloat() - - /* Update receiving node. */ - this.graph.removeEdge(farthestNodeFromClosest, closestNode) - } + var newNeighbours = this.graph.edges(newNode) + while (newNeighbours.size < this.degree) { + for ((candidateNode, candidateWeight) in results) { + if (newNeighbours.size >= this.degree) break + if (newNeighbours.contains(candidateNode)) continue + if (!(skipRng || checkMrng(newNode, candidateNode, candidateWeight))) continue + + /* Find candidate neighbour. */ + val (candidateNeighbour,candidateNeighbourWeight) = this.graph.edges(candidateNode).filter { !newNeighbours.contains(it.key) }.maxBy { it.value } + + /* Remove edge from candidate node to candidate neighbour. */ + this.graph.removeEdge(candidateNode, candidateNeighbour) + + /* Add edges to new nodes. */ + this.graph.addEdge(newNode, candidateNode, candidateWeight) + this.graph.addEdge(newNode, candidateNeighbour, candidateNeighbourWeight) } skipRng = true - } - - /* */ - this.graph.addVertex(newNode) - for ((node, weight) in connect) { - this.graph.addEdge(newNode, node, weight) + newNeighbours = this.graph.edges(newNode) /* Fetch new neighbours. */ } } } @@ -99,7 +82,7 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val * @param epsilon The epsilon value for the search. * @return [List] of [Triple]s containing the [TupleId], distance and [VectorValue] of the approximate nearest neighbours. */ - fun search(query: V, k: Int, epsilon: Double): Map { + fun search(query: V, k: Int, epsilon: Float): List { val seed = this.getSeedNodes(this.degree) val checked = HashSet() var r = Float.MAX_VALUE @@ -111,7 +94,7 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val while (seed.isNotEmpty()) { /* Find seed node closest to query. */ var closestNode: Node = seed.first() - var closestDistance = Double.MAX_VALUE + var closestDistance = Float.MAX_VALUE for (node in seed) { val distance = this.distance(query, node.vector) if (distance < closestDistance) { @@ -133,7 +116,7 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val if (distance < r * (1 + epsilon)) { seed.add(node) if (distance <= r) { - results[node] = distance.toFloat() + results[node] = distance if (results.size > k) { val largest = results.maxBy { it.value } results.removeFloat(largest.key) @@ -148,7 +131,7 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val } } - return results + return results.map { Distance(it.key, it.value) }.sorted() } /** @@ -173,24 +156,26 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val * * @param a The first vector [V]s. * @param b The first vector [V]s. - * @return [Double] distance between the two vectors. + * @return [Float] distance between the two vectors. */ - protected abstract fun distance(a: V, b: V): Double + protected abstract fun distance(a: V, b: V): Float /** * Obtains random seed [Node]s for range search. * * @param size The number of seed [Node]s to obtain. - * @return [MutableMap of [AbstractDynamicExplorationGraph.Node]s keyed by [NodeId] + * @return [MutableSet] of [AbstractDynamicExplorationGraph.Node]s */ private fun getSeedNodes(size: Int): MutableSet { - require(size <= this.size()) { "Negative size of $size" } - val set = HashSet() + val graphSize = this.graph.size() + val sampleSize = size.toLong() + require(sampleSize <= graphSize) { "The sample size $sampleSize exceeds graph size of graph (s = $sampleSize, g = $graphSize)" } + val set = ObjectArraySet(size) for ((i, node) in this.graph.withIndex()) { - if (i % floorDiv(this.graph.size(), size.toLong()) == 0L) { + if (i % floorDiv(graphSize, sampleSize) == 0L) { set.add(node) + if (set.size >= size) break } - if (set.size >= size) break } return set } @@ -202,12 +187,11 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val * @param v2 The second [Node]. * @return True if MRNG condition is satisfied, false otherwise. */ - private fun checkMrng(v1: Node, v1N: Map, v2: Node): Boolean { + private fun checkMrng(v1: Node, v2: Node, targetWeight: Float): Boolean { + val v1N = this.graph.edges(v1) val v2N = this.graph.edges(v2) - val neighbours = v1N.keys intersect v2N.keys - val distance = this.distance(v1.vector, v2.vector) - for (node in neighbours) { - if (distance > max(v2N[node] ?: 0.0f, v1N[node] ?: 0.0f)) { + for (node in (v1N.keys intersect v2N.keys)) { + if (targetWeight > max(v2N[node]!!, v1N[node]!!)) { return false } } @@ -220,10 +204,45 @@ abstract class AbstractDynamicExplorationGraph(private val degree: Int, val * @author Ralph Gasser * @version 1.0.0 */ - inner class Node(val identifier: I) { + inner class Node(val identifier: I): Comparable { /** The [VectorValue]; this value is loaded lazily. */ val vector: V by lazy { loadVector(this.identifier) } + override fun compareTo(other: Node): Int = this.identifier.compareTo(other.identifier) override fun equals(other: Any?): Boolean = other is AbstractDynamicExplorationGraph<*,*>.Node && other.identifier == this.identifier override fun hashCode(): Int = this.identifier.hashCode() } + + /** + * A [Distance] element produced by this [AbstractDynamicExplorationGraph]. + * + * @author Ralph Gasser + * @version 1.0.0 + */ + inner class Distance(val identifier: Node, val distance: Float): Comparable { + override fun compareTo(other: Distance): Int { + val result = this.distance.compareTo(other.distance) + return if (result == 0) { + this.identifier.compareTo(other.identifier) + } else { + result + } + } + + operator fun component1(): Node = this.identifier + operator fun component2(): Float = this.distance + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is AbstractDynamicExplorationGraph<*,*>.Distance) return false + if (this.distance != other.distance) return false + if (this.identifier != other.identifier) return false + return true + } + + override fun hashCode(): Int { + var result = identifier.hashCode() + result = 31 * result + distance.hashCode() + return result + } + } } \ No newline at end of file diff --git a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/InMemoryDynamicExplorationGraph.kt b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/InMemoryDynamicExplorationGraph.kt index 9e23d8d0c..628a2f236 100644 --- a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/InMemoryDynamicExplorationGraph.kt +++ b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/dbms/index/diskann/graph/InMemoryDynamicExplorationGraph.kt @@ -1,17 +1,15 @@ package org.vitrivr.cottontail.dbms.index.diskann.graph -import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap -import org.vitrivr.cottontail.utilities.graph.Graph import org.vitrivr.cottontail.utilities.graph.memory.InMemoryGraph /** * */ -class InMemoryDynamicExplorationGraph(degree: Int, private val df: (V, V) -> Double): AbstractDynamicExplorationGraph(degree, InMemoryGraph(degree)) { +class InMemoryDynamicExplorationGraph,V>(degree: Int, private val df: (V, V) -> Float): AbstractDynamicExplorationGraph(degree, InMemoryGraph(degree)) { private val vectors = Object2ObjectOpenHashMap() override fun size(): Long = this.graph.size() - override fun distance(a: V, b: V): Double = this.df(a, b) + override fun distance(a: V, b: V): Float = this.df(a, b) override fun loadVector(identifier: I): V = this.vectors[identifier] ?: throw NoSuchElementException("Could not find vector for identifier $identifier") override fun storeVector(identifier: I, vector: V) { this.vectors[identifier] = vector diff --git a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/utilities/graph/memory/InMemoryGraph.kt b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/utilities/graph/memory/InMemoryGraph.kt index 5a65ce674..a418d4f14 100644 --- a/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/utilities/graph/memory/InMemoryGraph.kt +++ b/cottontaildb-dbms/src/main/kotlin/org/vitrivr/cottontail/utilities/graph/memory/InMemoryGraph.kt @@ -3,8 +3,6 @@ package org.vitrivr.cottontail.utilities.graph.memory import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap import it.unimi.dsi.fastutil.objects.Object2ObjectLinkedOpenHashMap import org.vitrivr.cottontail.utilities.graph.Graph -import kotlin.math.max -import kotlin.math.sign /** * An in memory implementation of the [Graph] interface. @@ -52,8 +50,8 @@ class InMemoryGraph(val maxDegree: Int = Int.MAX_VALUE): Graph { val e1 = this.map[from] ?: throw NoSuchElementException("The vertex $from does not exist in the graph." ) val e2 = this.map[to] ?: throw NoSuchElementException("The vertex $to does not exist in the graph." ) if (!e1.containsKey(to) && !e2.containsKey(from)) { - check(e1.size <= this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." } - check(e2.size <= this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." } + check(e1.size < this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." } + check(e2.size < this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." } e1[to] = weight e2[from] = weight return true