Skip to content

Commit

Permalink
Minor optimisations.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralph Gasser committed May 6, 2024
1 parent ec780fe commit cc23f28
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
package org.vitrivr.cottontail.dbms.index.diskann.graph

import it.unimi.dsi.fastutil.objects.Object2FloatLinkedOpenHashMap
import org.apache.lucene.search.Weight
import it.unimi.dsi.fastutil.objects.ObjectArraySet
import org.vitrivr.cottontail.core.database.TupleId
import org.vitrivr.cottontail.core.types.VectorValue
import org.vitrivr.cottontail.utilities.graph.Graph
import java.lang.Math.floorDiv
import java.util.*
import kotlin.collections.HashMap
import kotlin.collections.HashSet
import kotlin.math.max

Expand All @@ -20,7 +18,7 @@ import kotlin.math.max
* @author Ralph Gasser
* @version 1.0.0
*/
abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val graph: Graph<AbstractDynamicExplorationGraph<I,V>.Node>) {
abstract class AbstractDynamicExplorationGraph<I:Comparable<I>,V>(private val degree: Int, val graph: Graph<AbstractDynamicExplorationGraph<I,V>.Node>, private val epsilonExt: Float = 0.3f, private val kExt: Int = 60) {


init {
Expand All @@ -33,60 +31,45 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
* @param identifier The identifier [I] of the entry to index.
* @param vector The vector [V] of the entry to index.
*/
fun index(identifier: I, vector: V, epsilon: Double) {
fun index(identifier: I, vector: V) {
val count = this.size()

/* Create new (empty) node and store vector. */
val newNode = Node(identifier)
this.storeVector(identifier, vector)
this.graph.addVertex(newNode)

if (count <= this.degree)
{ /* Case 1: Graph does not satisfy regularity condition since it is too small: Create new node and make all existing nodes connect to */
this.graph.addVertex(newNode)
if (count <= this.degree) { /* Case 1: Graph does not satisfy regularity condition since it is too small: Create new node and make all existing nodes connect to */
for (node in this.graph) {
val distance = this.distance(vector, node.vector).toFloat()
if (node != newNode) {
this.graph.addEdge(newNode, node, distance)
this.graph.addEdge(node, newNode, distance)
}
if (node == newNode) continue
val distance = this.distance(vector, node.vector)
this.graph.addEdge(newNode, node, distance)
this.graph.addEdge(node, newNode, distance)
}
} else { /* Case 2: Graph is not regular. */
val search = this.search(vector, this.degree, epsilon)
val connect = HashMap<Node, Float>()
} else { /* Case 2: Graph is regular. */
val results = this.search(vector, this.kExt, this.epsilonExt)
var skipRng = false

/* Start insert procedure. */
while (connect.size < this.degree) {
val nodesToExplore = search.entries.filter { !connect.contains(it.key) }.associate { it.key to it.value }.toMutableMap()
while (connect.size < this.degree && nodesToExplore.isNotEmpty()) {
var closestNode = nodesToExplore.keys.first()
var closestDistance = Double.MAX_VALUE
for ((node, _) in nodesToExplore.entries) {
val distance = this.distance(vector, node.vector)
if (distance < closestDistance) {
closestDistance = distance
closestNode = node
}
}
nodesToExplore.remove(closestNode)

/* Identify the best vertex to connect to existing vertex. */
if (skipRng || checkMrng(newNode, connect, closestNode)) {
val farthestNodeFromClosest = this.graph.edges(closestNode).filter { !connect.contains(it.key) }.maxBy { it.value }.key
connect[closestNode] = this.distance(closestNode.vector, newNode.vector).toFloat()
connect[farthestNodeFromClosest] = this.distance(farthestNodeFromClosest.vector, newNode.vector).toFloat()

/* Update receiving node. */
this.graph.removeEdge(farthestNodeFromClosest, closestNode)
}
var newNeighbours = this.graph.edges(newNode)
while (newNeighbours.size < this.degree) {
for ((candidateNode, candidateWeight) in results) {
if (newNeighbours.size >= this.degree) break
if (newNeighbours.contains(candidateNode)) continue
if (!(skipRng || checkMrng(newNode, candidateNode, candidateWeight))) continue

/* Find candidate neighbour. */
val (candidateNeighbour,candidateNeighbourWeight) = this.graph.edges(candidateNode).filter { !newNeighbours.contains(it.key) }.maxBy { it.value }

/* Remove edge from candidate node to candidate neighbour. */
this.graph.removeEdge(candidateNode, candidateNeighbour)

/* Add edges to new nodes. */
this.graph.addEdge(newNode, candidateNode, candidateWeight)
this.graph.addEdge(newNode, candidateNeighbour, candidateNeighbourWeight)
}
skipRng = true
}

/* */
this.graph.addVertex(newNode)
for ((node, weight) in connect) {
this.graph.addEdge(newNode, node, weight)
newNeighbours = this.graph.edges(newNode) /* Fetch new neighbours. */
}
}
}
Expand All @@ -99,7 +82,7 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
* @param epsilon The epsilon value for the search.
* @return [List] of [Triple]s containing the [TupleId], distance and [VectorValue] of the approximate nearest neighbours.
*/
fun search(query: V, k: Int, epsilon: Double): Map<Node,Float> {
fun search(query: V, k: Int, epsilon: Float): List<Distance> {
val seed = this.getSeedNodes(this.degree)
val checked = HashSet<Node>()
var r = Float.MAX_VALUE
Expand All @@ -111,7 +94,7 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
while (seed.isNotEmpty()) {
/* Find seed node closest to query. */
var closestNode: Node = seed.first()
var closestDistance = Double.MAX_VALUE
var closestDistance = Float.MAX_VALUE
for (node in seed) {
val distance = this.distance(query, node.vector)
if (distance < closestDistance) {
Expand All @@ -133,7 +116,7 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
if (distance < r * (1 + epsilon)) {
seed.add(node)
if (distance <= r) {
results[node] = distance.toFloat()
results[node] = distance
if (results.size > k) {
val largest = results.maxBy { it.value }
results.removeFloat(largest.key)
Expand All @@ -148,7 +131,7 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
}
}

return results
return results.map { Distance(it.key, it.value) }.sorted()
}

/**
Expand All @@ -173,24 +156,26 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
*
* @param a The first vector [V]s.
* @param b The first vector [V]s.
* @return [Double] distance between the two vectors.
* @return [Float] distance between the two vectors.
*/
protected abstract fun distance(a: V, b: V): Double
protected abstract fun distance(a: V, b: V): Float

/**
* Obtains random seed [Node]s for range search.
*
* @param size The number of seed [Node]s to obtain.
* @return [MutableMap of [AbstractDynamicExplorationGraph.Node]s keyed by [NodeId]
* @return [MutableSet] of [AbstractDynamicExplorationGraph.Node]s
*/
private fun getSeedNodes(size: Int): MutableSet<Node> {
require(size <= this.size()) { "Negative size of $size" }
val set = HashSet<Node>()
val graphSize = this.graph.size()
val sampleSize = size.toLong()
require(sampleSize <= graphSize) { "The sample size $sampleSize exceeds graph size of graph (s = $sampleSize, g = $graphSize)" }
val set = ObjectArraySet<Node>(size)
for ((i, node) in this.graph.withIndex()) {
if (i % floorDiv(this.graph.size(), size.toLong()) == 0L) {
if (i % floorDiv(graphSize, sampleSize) == 0L) {
set.add(node)
if (set.size >= size) break
}
if (set.size >= size) break
}
return set
}
Expand All @@ -202,12 +187,11 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
* @param v2 The second [Node].
* @return True if MRNG condition is satisfied, false otherwise.
*/
private fun checkMrng(v1: Node, v1N: Map<Node,Float>, v2: Node): Boolean {
private fun checkMrng(v1: Node, v2: Node, targetWeight: Float): Boolean {
val v1N = this.graph.edges(v1)
val v2N = this.graph.edges(v2)
val neighbours = v1N.keys intersect v2N.keys
val distance = this.distance(v1.vector, v2.vector)
for (node in neighbours) {
if (distance > max(v2N[node] ?: 0.0f, v1N[node] ?: 0.0f)) {
for (node in (v1N.keys intersect v2N.keys)) {
if (targetWeight > max(v2N[node]!!, v1N[node]!!)) {
return false
}
}
Expand All @@ -220,10 +204,45 @@ abstract class AbstractDynamicExplorationGraph<I,V>(private val degree: Int, val
* @author Ralph Gasser
* @version 1.0.0
*/
inner class Node(val identifier: I) {
inner class Node(val identifier: I): Comparable<Node> {
/** The [VectorValue]; this value is loaded lazily. */
val vector: V by lazy { loadVector(this.identifier) }
override fun compareTo(other: Node): Int = this.identifier.compareTo(other.identifier)
override fun equals(other: Any?): Boolean = other is AbstractDynamicExplorationGraph<*,*>.Node && other.identifier == this.identifier
override fun hashCode(): Int = this.identifier.hashCode()
}

/**
* A [Distance] element produced by this [AbstractDynamicExplorationGraph].
*
* @author Ralph Gasser
* @version 1.0.0
*/
inner class Distance(val identifier: Node, val distance: Float): Comparable<Distance> {
override fun compareTo(other: Distance): Int {
val result = this.distance.compareTo(other.distance)
return if (result == 0) {
this.identifier.compareTo(other.identifier)
} else {
result
}
}

operator fun component1(): Node = this.identifier
operator fun component2(): Float = this.distance

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is AbstractDynamicExplorationGraph<*,*>.Distance) return false
if (this.distance != other.distance) return false
if (this.identifier != other.identifier) return false
return true
}

override fun hashCode(): Int {
var result = identifier.hashCode()
result = 31 * result + distance.hashCode()
return result
}
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package org.vitrivr.cottontail.dbms.index.diskann.graph

import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap
import org.vitrivr.cottontail.utilities.graph.Graph
import org.vitrivr.cottontail.utilities.graph.memory.InMemoryGraph

/**
*
*/
class InMemoryDynamicExplorationGraph<I,V>(degree: Int, private val df: (V, V) -> Double): AbstractDynamicExplorationGraph<I,V>(degree, InMemoryGraph(degree)) {
class InMemoryDynamicExplorationGraph<I: Comparable<I>,V>(degree: Int, private val df: (V, V) -> Float): AbstractDynamicExplorationGraph<I,V>(degree, InMemoryGraph(degree)) {
private val vectors = Object2ObjectOpenHashMap<I,V>()
override fun size(): Long = this.graph.size()
override fun distance(a: V, b: V): Double = this.df(a, b)
override fun distance(a: V, b: V): Float = this.df(a, b)
override fun loadVector(identifier: I): V = this.vectors[identifier] ?: throw NoSuchElementException("Could not find vector for identifier $identifier")
override fun storeVector(identifier: I, vector: V) {
this.vectors[identifier] = vector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package org.vitrivr.cottontail.utilities.graph.memory
import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2ObjectLinkedOpenHashMap
import org.vitrivr.cottontail.utilities.graph.Graph
import kotlin.math.max
import kotlin.math.sign

/**
* An in memory implementation of the [Graph] interface.
Expand Down Expand Up @@ -52,8 +50,8 @@ class InMemoryGraph<V>(val maxDegree: Int = Int.MAX_VALUE): Graph<V> {
val e1 = this.map[from] ?: throw NoSuchElementException("The vertex $from does not exist in the graph." )
val e2 = this.map[to] ?: throw NoSuchElementException("The vertex $to does not exist in the graph." )
if (!e1.containsKey(to) && !e2.containsKey(from)) {
check(e1.size <= this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." }
check(e2.size <= this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." }
check(e1.size < this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." }
check(e2.size < this.maxDegree) { "The vertex $from already has too many edges (maxDegree = ${this.maxDegree})." }
e1[to] = weight
e2[from] = weight
return true
Expand Down

0 comments on commit cc23f28

Please sign in to comment.