from __future__ import annotations
__all__ = ('VecDB', 'VecDBElement', 'Distance', 'EuclideanDistanceSquared')
from genlayer.py.storage import DynArray, TreeMap
from genlayer.py.types import u32
from genlayer.py.storage.annotations import allow_storage
import typing
import numpy as np
import math
class Distance(typing.Protocol):
    """
    Structural type of the distance functions accepted by :py:class:`VecDB`

    Any object that can be called with two operands and returns their
    distance satisfies this protocol.
    """

    def __call__(self, l, r) -> typing.Any:
        """
        Return the distance between ``l`` and ``r``
        """
        ...
@allow_storage
class EuclideanDistanceSquared(Distance):
    """
    Squared Euclidean distance: the sum of squared differences, without the square root
    """

    def __call__(self, l, r):
        """
        Return the sum of squared element-wise differences between ``l`` and ``r``
        """
        return np.sum((l - r) ** 2)

    def batch(self, l, r):
        """
        Return squared distances reduced along axis 1

        The difference ``l - r`` is summed over ``axis=1`` only, so it must
        have at least two dimensions; one result is produced per index of
        the remaining axes.
        """
        return ((l - r) ** 2).sum(axis=1)
# Distinct static type for element ids; ``typing.NewType`` keeps it a plain ``int`` at runtime
Id = typing.NewType('Id', int)
# Module-level alias, so that ``VecDB`` can declare its own ``Id`` type alias pointing here
_Id = Id
# Sentinel index (all 32 bits set) that stands for "no parent node"
NO_PARENT = u32(0xFFFF_FFFF)
@allow_storage
class CoverTreeNode:
    """A node in the cover tree structure"""

    element_id: u32  # Index of the db element this node stands for
    level: u32  # Level of the node inside the tree
    children: DynArray[u32]  # Indices of child nodes
    parent: u32  # Index of parent node, NO_PARENT if root

    def __init__(self, element_id: u32, level: u32):
        """
        Create a node for element ``element_id`` placed at ``level``

        The node starts detached; whoever links it into the tree is
        expected to set ``parent`` and register it in the parent's
        ``children``.
        """
        self.element_id = element_id
        self.level = level
        # Make the "no parent" state explicit, so that a node that was never
        # linked does not appear to be a child of whatever index the field
        # would otherwise default to
        self.parent = NO_PARENT
class VecDBElement[T: np.number, S: int, V, Dist]:
distance: Dist
"""
Distance from search point to this element, if any
"""
__slots__ = ('_idx', '_db', 'distance')
def __init__(self, db: VecDB[T, S, V], idx: u32, distance: Dist):
self._idx = idx
self._db = db
self.distance = distance
def __repr__(self) -> str:
return f'VecDB.Element(id={self.id!r}, key={self.key!r}, value={self.value!r}, distance={self.distance})'
@property
def key(self) -> np.ndarray[tuple[S], np.dtype[T]]:
"""
Key (vector) of this element
"""
return self._db._keys[self._idx]
@property
def id(self) -> Id:
"""
Id (unique key) of this element
"""
return Id(self._idx)
@property
def value(self) -> V:
"""
Value of this element
"""
return self._db._values[self._idx]
@value.setter
def value(self, v: V):
self._db._values[self._idx] = v
def remove(self) -> None:
"""
Removes current element from the db
"""
self._db._remove_from_tree(self._idx)
self._db._free_idx[self._idx] = None
[docs]
@allow_storage
class VecDB[T: np.number, S: int, V, D: Distance]:
"""
Data structure that supports storing and querying vector data using Cover Trees
Cover trees provide logarithmic time nearest neighbor search with theoretical guarantees.
There are two entities that can act as a key:
#. vector (can have duplicates)
#. id (int alias, can't have duplicates)
.. warning::
import :py:mod:`numpy` before ``from genlayer import *`` if you wish to use :py:class:`VecDB`!
"""
type Id = _Id
"""
:py:class:`int` alias to prevent confusion
"""
type Element = VecDBElement
"""
Shorthand to prevent global namespace pollution
"""
_keys: DynArray[np.ndarray[tuple[S], np.dtype[T]]]
_values: DynArray[V]
_free_idx: TreeMap[u32, None]
_nodes: DynArray[CoverTreeNode]
_free_nodes: TreeMap[u32, None]
_root_idx: u32
_base: float # Base for cover tree levels (typically 1.3)
_max_level: u32
_dist_func: D
_initialized: bool = False
[docs]
def __init__(self):
self._do_init()
def _do_init(self):
if self._initialized:
return
self._initialized = True
self._root_idx = NO_PARENT
self._base = 1.3
self._max_level = u32(0)
[docs]
def __len__(self) -> int:
return len(self._keys) - len(self._free_idx)
[docs]
def get_by_id(self, id: Id) -> VecDBElement[T, S, V, None]:
res = self.get_by_id_or_none(id)
if res is None:
raise KeyError(f'no element with id {id}')
return res
[docs]
def get_by_id_or_none(self, id: Id) -> VecDBElement[T, S, V, None] | None:
if u32(id) in self._free_idx:
return None
return VecDBElement(self, u32(id), None)
def _distance(self, idx1: u32, idx2: u32) -> T:
"""Compute distance between two elements by their indices"""
return self._dist_func(self._keys[idx1], self._keys[idx2])
def _distance_to_point(self, idx: u32, point: np.ndarray[tuple[S], np.dtype[T]]) -> T:
"""Compute distance from element to query point"""
return self._dist_func(self._keys[idx], point)
def _allocate_node(self, element_id: u32, level: u32) -> u32:
"""Allocate a new node and return its index"""
if len(self._free_nodes) > 0:
node_idx = self._free_nodes.popitem()[0]
self._nodes[node_idx] = CoverTreeNode(element_id, level)
return node_idx
else:
node = CoverTreeNode(element_id, level)
self._nodes.append(node)
return u32(len(self._nodes) - 1)
def _free_node(self, node_idx: u32) -> None:
"""Mark a node as free"""
self._free_nodes[node_idx] = None
[docs]
def insert(self, key: np.ndarray[tuple[S], np.dtype[T]], val: V) -> Id:
self._do_init()
# Add to storage arrays
if len(self._free_idx) > 0:
idx = self._free_idx.popitem()[0]
self._keys[idx] = key
self._values[idx] = val
else:
self._keys.append(key)
self._values.append(val)
idx = u32(len(self._keys) - 1)
# Insert into cover tree
self._insert_into_tree(idx)
return Id(idx)
def _insert_into_tree(self, new_idx: u32) -> None:
"""Insert element into cover tree structure"""
if self._root_idx == NO_PARENT:
# First element becomes root
self._root_idx = self._allocate_node(new_idx, u32(0))
self._nodes[self._root_idx].parent = NO_PARENT
self._max_level = u32(0)
return
# Find insertion level based on distance to nearest neighbor
nearest_dist = float('inf')
for i in range(len(self._keys)):
if u32(i) in self._free_idx or i == new_idx:
continue
dist = self._distance(new_idx, u32(i))
if dist < nearest_dist:
nearest_dist = dist
# Determine level for new node
if nearest_dist == 0:
level = 0
else:
print(nearest_dist, self._base)
level = min(self._max_level, int(math.log(nearest_dist) / math.log(self._base)))
level = max(level, 0)
# Create new node
new_node_idx = self._allocate_node(new_idx, u32(level))
# Insert at appropriate level
self._insert_node_at_level(new_node_idx, u32(level))
def _insert_node_at_level(self, new_node_idx: u32, level: u32) -> None:
"""Insert node at specified level in the tree"""
if level > self._max_level:
# Need to create new root
old_root_idx = self._root_idx
self._nodes[new_node_idx].level = u32(level + 1)
self._nodes[new_node_idx].parent = NO_PARENT
self._root_idx = new_node_idx
if old_root_idx != NO_PARENT:
self._nodes[new_node_idx].children.append(old_root_idx)
self._nodes[old_root_idx].parent = new_node_idx
self._max_level = u32(level + 1)
return
# Find best parent at level + 1
parent_candidates: list[u32] = self._find_nodes_at_level(level + 1)
if len(parent_candidates) == 0 and self._root_idx != NO_PARENT:
parent_candidates.append(self._root_idx)
best_parent_idx = NO_PARENT
best_distance = float('inf')
for candidate_idx in parent_candidates:
dist = self._distance(
self._nodes[candidate_idx].element_id, self._nodes[new_node_idx].element_id
)
if dist < best_distance:
best_distance = dist
best_parent_idx = candidate_idx
if best_parent_idx != NO_PARENT:
self._nodes[new_node_idx].parent = best_parent_idx
self._nodes[best_parent_idx].children.append(new_node_idx)
def _find_nodes_at_level(self, target_level: int) -> list[u32]:
"""Find all node indices at specified level"""
nodes: list[u32] = []
if self._root_idx == NO_PARENT:
return nodes
stack: list[u32] = [self._root_idx]
while len(stack) > 0:
node_idx = stack.pop()
node = self._nodes[node_idx]
if node.level == target_level:
nodes.append(node_idx)
elif node.level > target_level:
for i in range(len(node.children)):
stack.append(node.children[i])
return nodes
def _remove_from_tree(self, idx: u32) -> None:
"""Remove element from cover tree structure"""
# Find and remove the node
node_idx = self._find_node_by_id(idx)
if node_idx == NO_PARENT:
return
node = self._nodes[node_idx]
parent_idx = node.parent
children_idxs: list[u32] = [node.children[i] for i in range(len(node.children))]
if parent_idx != NO_PARENT:
# Remove from parent's children list
parent = self._nodes[parent_idx]
for i in range(len(parent.children)):
if parent.children[i] == node_idx:
parent.children[i : i + 1] = []
break
# Reattach children to parent
for child_idx in children_idxs:
self._nodes[child_idx].parent = parent_idx
parent.children.append(child_idx)
elif node_idx == self._root_idx:
# Removing root
if len(children_idxs) > 0:
# Promote highest level child to root
best_child_idx = children_idxs[0]
best_level = self._nodes[best_child_idx].level
for i in range(1, len(children_idxs)):
child_idx = children_idxs[i]
if self._nodes[child_idx].level > best_level:
best_level = self._nodes[child_idx].level
best_child_idx = child_idx
self._root_idx = best_child_idx
self._nodes[best_child_idx].parent = NO_PARENT
# Reattach other children
for child_idx in children_idxs:
if child_idx != best_child_idx:
self._nodes[child_idx].parent = best_child_idx
self._nodes[best_child_idx].children.append(child_idx)
else:
self._root_idx = NO_PARENT
self._max_level = u32(0)
# Free the node
self._free_node(node_idx)
def _find_node_by_id(self, element_id: u32) -> u32:
"""Find node index with given element ID"""
if self._root_idx == NO_PARENT:
return NO_PARENT
stack: list[u32] = [self._root_idx]
while len(stack) > 0:
node_idx = stack.pop()
if node_idx in self._free_nodes:
continue
node = self._nodes[node_idx]
if node.element_id == element_id:
return node_idx
for i in range(len(node.children)):
stack.append(node.children[i])
return NO_PARENT
[docs]
def knn(
self, v: np.ndarray[tuple[S], np.dtype[T]], k: int
) -> typing.Iterator[VecDBElement[T, S, V, T]]:
"""Find k nearest neighbors using cover tree"""
self._do_init()
if self._root_idx == NO_PARENT or k <= 0:
return
# Use a priority queue approach for efficiency
candidates: list[tuple[T, u32]] = [] # Will store (distance, element_id) tuples
# Traverse tree to find candidates
stack: list[u32] = [self._root_idx]
while len(stack) > 0:
node_idx = stack.pop()
if node_idx in self._free_nodes:
continue
node = self._nodes[node_idx]
if node.element_id not in self._free_idx:
dist = self._distance_to_point(node.element_id, v)
if np.isfinite(dist):
candidates.append((dist, node.element_id))
for i in range(len(node.children)):
stack.append(node.children[i])
# Sort by distance
candidates.sort(key=lambda x: x[0])
count = 0
for dist, idx in candidates:
if count >= k:
break
yield VecDBElement(self, idx, dist)
count += 1
[docs]
def __iter__(self):
self._do_init()
for i in range(len(self._keys)):
if u32(i) in self._free_idx:
continue
yield VecDBElement(self, u32(i), None)