"""Storage-backed ordered mapping implemented as an AVL tree."""
__all__ = ('TreeMap',)
import abc
import typing
import collections.abc
from .vec import DynArray
from genlayer.py.types import u32, i8
# Sentinel that can be passed as a node value meaning "do not assign the
# value field": `_Node.__init__` skips the assignment when it sees it, which
# `TreeMap.get_or_insert_default` relies on.
_NO_OBJ = object()
class _Node[K, V]:
key: K
value: V
left: u32
right: u32
balance: i8
def __init__(self, k: K, v: V):
self.key = k
if v is not _NO_OBJ:
self.value = v
self.left = u32(0)
self.right = u32(0)
self.balance = i8(0)
class Comparable(typing.Protocol):
    """
    Structural type for ``TreeMap`` keys.

    A key type only needs ``==`` and ``<``; these are the two operators the
    tree uses to order and locate keys.
    """

    @abc.abstractmethod
    def __lt__(self, rhs: typing.Any, /) -> bool: ...

    @abc.abstractmethod
    def __eq__(self, rhs: typing.Any, /) -> bool: ...
[docs]
class TreeMap[K: Comparable, V](collections.abc.MutableMapping[K, V]):
"""
Represents a mapping from keys to values that can be persisted on the blockchain
"""
_root: u32
_slots: DynArray[_Node[K, V]]
_free_slots: DynArray[u32]
[docs]
def __init__(self):
"""
This class can't be created with ``TreeMap()``
:raises TypeError: always
"""
raise TypeError("this class can't be instantiated by user")
def __len__(self) -> int:
return len(self._slots) - len(self._free_slots)
def _alloc_slot(self) -> tuple[int, _Node[K, V]]:
if len(self._free_slots) > 0:
idx = int(self._free_slots[-1])
self._free_slots.pop()
slot = self._slots[idx]
else:
idx = len(self._slots)
slot = self._slots.append_new_get()
return (idx, slot)
def _free_slot(self, slot: u32):
if slot + 1 == len(self._slots):
self._slots.pop()
else:
self._free_slots.append(slot)
def _rot_left(self, par: int, cur: int):
par_node = self._slots[par - 1]
cur_node = self._slots[cur - 1]
cur_l = cur_node.left
cur_node.left = u32(par)
par_node.right = cur_l
if cur_node.balance == 0:
par_node.balance = i8(+1)
cur_node.balance = i8(-1)
else:
par_node.balance = i8(0)
cur_node.balance = i8(0)
def _rot_right(self, par: int, cur: int):
par_node = self._slots[par - 1]
cur_node = self._slots[cur - 1]
cur_r = cur_node.right
cur_node.right = u32(par)
par_node.left = cur_r
if cur_node.balance == 0:
par_node.balance = i8(-1)
cur_node.balance = i8(+1)
else:
par_node.balance = i8(0)
cur_node.balance = i8(0)
def _rot_right_left(self, gpar: int, par: int, cur: int):
gpar_node = self._slots[gpar - 1]
par_node = self._slots[par - 1]
cur_node = self._slots[cur - 1]
cur_l = cur_node.left
cur_r = cur_node.right
gpar_node.right = cur_l
par_node.left = cur_r
cur_node.left = u32(gpar)
cur_node.right = u32(par)
if cur_node.balance == 0:
par_node.balance = i8(0)
gpar_node.balance = i8(0)
elif cur_node.balance > 0:
gpar_node.balance = i8(-1)
par_node.balance = i8(0)
else:
gpar_node.balance = i8(0)
par_node.balance = i8(1)
cur_node.balance = i8(0)
def _rot_left_right(self, gpar: int, par: int, cur: int):
gpar_node = self._slots[gpar - 1]
par_node = self._slots[par - 1]
cur_node = self._slots[cur - 1]
cur_l = cur_node.left
cur_r = cur_node.right
gpar_node.left = cur_r
par_node.right = cur_l
cur_node.left = u32(par)
cur_node.right = u32(gpar)
if cur_node.balance == 0:
par_node.balance = i8(0)
gpar_node.balance = i8(0)
elif cur_node.balance > 0:
par_node.balance = i8(-1)
gpar_node.balance = i8(0)
else:
par_node.balance = i8(0)
gpar_node.balance = i8(1)
cur_node.balance = i8(0)
def _find_seq(self, k):
seq = []
cur = self._root
is_less = True
while True:
seq.append(cur)
if cur == 0:
break
cur_node = self._slots[cur - 1]
if cur_node.key == k:
break
is_less = k < cur_node.key
if is_less:
cur = cur_node.left
else:
cur = cur_node.right
return (seq, is_less)
def __delitem__(self, k: K):
seq, is_less = self._find_seq(k)
# not found
if seq[-1] == 0:
raise KeyError('key not found')
del_node = self._slots[seq[-1] - 1]
del_left = del_node.left
del_right = del_node.right
del_balance = del_node.balance
del del_node
self._free_slot(seq[-1] - 1)
special_null = False
seq_move_to = len(seq) - 1
# it has <=1 child
if del_left == 0 or del_right == 0:
if del_left == 0:
seq[seq_move_to] = del_right
else:
seq[seq_move_to] = del_left
special_null = True
else:
# we need to go right and then left*
seq.append(del_right)
while True:
cur_node = self._slots[seq[-1] - 1]
lft = cur_node.left
if lft != 0:
seq.append(lft)
else:
break
seq[seq_move_to] = seq[-1]
node_moved_to_deleted = self._slots[seq[-1] - 1]
node_moved_to_deleted.left = del_left
if seq_move_to + 2 != len(seq):
# we moved left
parent_of_node_moved_to_deleted = self._slots[seq[-2] - 1]
parent_of_node_moved_to_deleted.left = node_moved_to_deleted.right
node_moved_to_deleted.right = del_right
seq[-1] = parent_of_node_moved_to_deleted.left
else:
# we moved right once
seq[-1] = node_moved_to_deleted.right
# update parent link
if seq_move_to > 0:
par_node = self._slots[seq[seq_move_to - 1] - 1]
if is_less:
par_node.left = seq[seq_move_to]
else:
par_node.right = seq[seq_move_to]
else:
self._root = seq[seq_move_to]
# patch balance
if seq[seq_move_to] != 0:
seq_move_to_node = self._slots[seq[seq_move_to] - 1]
if special_null:
seq_move_to_node.balance = i8(0)
else:
seq_move_to_node.balance = del_balance
# rebalance
while len(seq) >= 2:
cur = seq[-1]
par = seq[-2]
par_node = self._slots[par - 1]
if special_null:
is_left = is_less
else:
is_left = cur == par_node.left
special_null = False
# we inserted to null place, so we increaced it depth
delta = -(-1 if is_left else 1)
new_b = par_node.balance + delta
if new_b == -2:
gp = 0 if len(seq) == 2 else seq[-3]
sib = par_node.left
sib_node = self._slots[sib - 1]
sib_bal = sib_node.balance
if sib_bal > 0:
right_child = sib_node.right
self._rot_left_right(par, sib, right_child)
seq.pop() # cur
seq.pop() # par
seq.append(right_child)
else:
self._rot_right(par, sib)
seq.pop(-2) # par
seq[-1] = sib
if gp != 0:
gp = self._slots[gp - 1]
if gp.left == par:
gp.left = u32(seq[-1])
else:
assert gp.right == par
gp.right = u32(seq[-1])
if sib_bal == 0:
break
elif new_b == 2:
gp = 0 if len(seq) == 2 else seq[-3]
sib = par_node.right
sib_node = self._slots[sib - 1]
sib_bal = sib_node.balance
if sib_bal < 0:
left_child = sib_node.left
self._rot_right_left(par, sib, left_child)
seq.pop() # cur
seq.pop() # par
seq.append(left_child)
else:
self._rot_left(par, sib)
seq.pop(-2) # par
seq[-1] = sib
if gp != 0:
gp = self._slots[gp - 1]
if gp.left == par:
gp.left = u32(seq[-1])
else:
assert gp.right == par
gp.right = u32(seq[-1])
if sib_bal == 0:
break
else:
par_node.balance = i8(new_b)
if new_b != 0:
break
seq.pop()
if self._root != seq[0]:
self._root = seq[0]
def __setitem__(self, k: K, v: V):
def setter(node: _Node[K, V]):
node.value = v
self._get_set(
k,
setter,
lambda: v,
)
[docs]
def compute_if_absent(self, k: K, supplier: typing.Callable[[], V]) -> V:
"""
:returns: Value associated with `k` if it is present, otherwise get's new value from the supplier, stores it at `k` and returns
"""
res: list[V] = []
def existing(node: _Node[K, V]):
res.append(node.value)
self._get_set(
k,
existing,
supplier,
)
return res[0]
[docs]
def get_or_insert_default(self, k: K) -> V:
return self._get_set(
k,
lambda _k: None,
lambda: _NO_OBJ, # type: ignore
)
def _get_set(
self,
k: K,
exists: typing.Callable[[_Node[K, V]], None],
does_not_exist: typing.Callable[[], V],
) -> V:
seq, is_less = self._find_seq(k)
# exists
if seq[-1] != 0:
slot = self._slots[seq[-1] - 1]
exists(slot)
return slot.value
# patch root
if len(seq) == 1:
idx, cur_node = self._alloc_slot()
self._root = u32(idx + 1)
cur_node.__init__(k, does_not_exist())
return cur_node.value
# alloc new
new_idx, new_slot = self._alloc_slot()
if is_less:
self._slots[seq[-2] - 1].left = u32(new_idx + 1)
else:
self._slots[seq[-2] - 1].right = u32(new_idx + 1)
seq[-1] = new_idx + 1
new_slot.__init__(k, does_not_exist())
# rebalance
while len(seq) >= 2:
cur = seq[-1]
par = seq[-2]
par_node = self._slots[par - 1]
is_left = cur == par_node.left
# we inserted to null place, so we increaced it depth
delta = -1 if is_left else 1
new_b = par_node.balance + delta
if new_b == -2:
gp = 0 if len(seq) == 2 else seq[-3]
cur_node = self._slots[cur - 1]
if cur_node.balance > 0:
right_child = cur_node.right
self._rot_left_right(par, cur, right_child)
seq.pop() # cur
seq.pop() # par
seq.append(right_child)
else:
self._rot_right(par, cur)
seq.pop(-2) # par
if gp != 0:
gp = self._slots[gp - 1]
if gp.left == par:
gp.left = u32(seq[-1])
else:
gp.right = u32(seq[-1])
break
elif new_b == 2:
gp = 0 if len(seq) == 2 else seq[-3]
cur_node = self._slots[cur - 1]
if cur_node.balance < 0:
left_child = cur_node.left
self._rot_right_left(par, cur, left_child)
seq.pop() # cur
seq.pop() # par
seq.append(left_child)
else:
self._rot_left(par, cur)
seq.pop(-2) # par
if gp != 0:
gp = self._slots[gp - 1]
if gp.left == par:
gp.left = u32(seq[-1])
else:
gp.right = u32(seq[-1])
break
else:
par_node.balance = i8(new_b)
if new_b == 0:
break
seq.pop()
if self._root != seq[0]:
self._root = seq[0]
return new_slot.value
def _get_fn[T](
self,
k: K,
found: collections.abc.Callable[[_Node[K, V]], T],
not_found: collections.abc.Callable[[], T],
) -> T:
idx = self._root
while idx != 0:
_Node = self._slots[idx - 1]
if _Node.key == k:
return found(_Node)
if k < _Node.key:
idx = _Node.left
else:
idx = _Node.right
return not_found()
[docs]
def get[G](self, k: K, default: G = None) -> V | G:
"""
:returns: Value associated with `k` or `default` if there is no such value
"""
return self._get_fn(k, lambda n: n.value, lambda: default)
def __getitem__(self, k: K) -> V:
def not_found() -> V:
raise KeyError()
return self._get_fn(k, lambda x: x.value, not_found)
def __contains__(self, k: K) -> bool:
return self._get_fn(k, lambda _: True, lambda: False)
def _visit[T](
self, cb: collections.abc.Callable[[_Node[K, V]], T]
) -> typing.Generator[T, None, None]:
def go(idx) -> typing.Generator[T, None, None]:
if idx == 0:
return
slot = self._slots[idx - 1]
yield from go(slot.left)
yield cb(slot)
yield from go(slot.right)
yield from go(self._root)
[docs]
def __repr__(self) -> str:
import json
ret: list[str] = []
ret.append('{')
comma = False
for k, v in self.items():
if comma:
ret.append(',')
comma = True
ret.append(json.dumps(k))
ret.append(':')
ret.append(repr(v))
ret.append('}')
return ''.join(ret)
def __iter__(self):
yield from self._visit(lambda n: n.key)
[docs]
def items(self) -> collections.abc.ItemsView[K, V]:
return _ItemsView(self)
class _ItemsView[K: Comparable, V](collections.abc.ItemsView):
__slots__ = ('_parent',)
def __init__(self, parent: TreeMap[K, V]):
self._parent = parent
def __iter__(self):
yield from self._parent._visit(lambda n: (n.key, n.value))
def __contains__(self, item: object) -> bool:
return any(item == x for x in iter(self))
def __len__(self):
return len(self._parent)