codon/test/core/trees.codon

247 lines
6.4 KiB
Python

import sys
def max(a, b):
    """Return the larger of ``a`` and ``b``; on a tie, ``a`` is returned."""
    if a >= b:
        return a
    return b
class AVLNode[K,V]:
    """One node of an AVL tree, holding a single key/value pair.

    Each node keeps a back-pointer to its parent so that successor lookup and
    rebalancing can walk upward. ``height`` starts at -1 and is recomputed by
    the owning tree when it rebalances.
    """
    key: K
    value: V
    parent: Optional[AVLNode[K,V]]
    left: Optional[AVLNode[K,V]]
    right: Optional[AVLNode[K,V]]
    height: int

    def __init__(self: AVLNode[K,V], parent: Optional[AVLNode[K,V]], k: K, v: V):
        self.key = k
        self.value = v
        self.parent = parent
        self.left = None
        self.right = None
        # Placeholder until the tree recomputes heights during rebalancing.
        self.height = -1

    def find(self, k):
        """Return the node in this subtree whose key equals ``k``, or None."""
        if k < self.key:
            if self.left is None:
                return None
            else:
                return self.left.find(k)
        elif k == self.key:
            return self
        else:
            if self.right is None:
                return None
            else:
                return self.right.find(k)

    def find_min(self):
        """Return the node with the smallest key in this subtree."""
        current: Optional[AVLNode[K, V]] = self
        while current.left is not None:
            current = current.left
        return current

    def next_larger(self):
        """Return this node's in-order successor, or None if it is the largest."""
        if self.right is not None:
            return self.right.find_min()
        # No right subtree: climb until we arrive from a left child.
        current: Optional[AVLNode[K, V]] = self
        while current.parent is not None and current is current.parent.right:
            current = current.parent
        return current.parent

    def insert(self, node):
        """Attach ``node`` as a leaf under this subtree by BST ordering.

        Keys equal to an existing key go to the right. Heights are not updated
        here; the caller is expected to rebalance afterwards.
        """
        if node is None:
            return
        if node.key < self.key:
            if self.left is None:
                node.parent = self
                self.left = node
            else:
                self.left.insert(node)
        else:
            if self.right is None:
                node.parent = self
                self.right = node
            else:
                self.right.insert(node)

    def delete(self):
        """Remove this node's entry from the tree and return the unlinked node.

        Requires a non-None parent; the tree supplies a temporary pseudo-root
        when deleting the real root. With two children, this node's key/value
        are exchanged with its in-order successor's, and the successor node
        (which has no left child) is unlinked instead and returned.
        """
        if self.left is None or self.right is None:
            # At most one child: splice that child into our slot.
            child = self.left if self.left is not None else self.right
            if self is self.parent.left:
                self.parent.left = child
            else:
                self.parent.right = child
            if child is not None:
                child.parent = self.parent
            return self
        else:
            s = self.next_larger()
            # Swap the value together with the key. Swapping only the key would
            # leave the successor's key paired with the deleted entry's value.
            self.key, s.key = s.key, self.key
            self.value, s.value = s.value, self.value
            return s.delete()

    def __iter__(self: AVLNode[K,V]):
        """Yield every node of this subtree in order (ascending key order)."""
        if self.left is not None:
            for i in self.left:
                yield i
        yield self
        if self.right is not None:
            for i in self.right:
                yield i
def height(node):
    """Return ``node.height``, treating a missing (None) subtree as height -1."""
    return -1 if node is None else node.height
def update_height(node):
    """Set ``node.height`` to one more than the taller of its two subtrees."""
    left_h = height(node.left)
    right_h = height(node.right)
    node.height = 1 + max(left_h, right_h)
class AVL[K,V]:
    """AVL-balanced binary search tree used as an ordered map from K to V.

    Supports ``t[k] = v``, ``t[k]``, ``del t[k]``, ``k in t`` and iteration
    over ``(key, value)`` pairs in ascending key order.
    """
    root: Optional[AVLNode[K,V]]

    def __init__(self: AVL[K,V]):
        self.root = None

    def find(self, k):
        """Return the node holding key ``k``, or None if it is absent."""
        if self.root is None:
            return None
        return self.root.find(k)

    def find_min(self):
        """Return the node with the smallest key, or None for an empty tree."""
        if self.root is None:
            return None
        return self.root.find_min()

    def next_larger(self, k):
        """Return the in-order successor of the node holding ``k``, or None."""
        node = self.find(k)
        return node.next_larger() if node is not None else None

    def left_rotate(self, x):
        """Rotate the subtree at ``x`` left, promoting ``x.right`` into its place.

        Re-links parent pointers, updates ``self.root`` when ``x`` was the
        root, and recomputes the heights of ``x`` and its new parent.
        """
        y = x.right
        y.parent = x.parent
        if y.parent is None:
            self.root = y
        else:
            if y.parent.left is x:
                y.parent.left = y
            elif y.parent.right is x:
                y.parent.right = y
        x.right = y.left
        if x.right is not None:
            x.right.parent = x
        y.left = x
        x.parent = y
        update_height(x)
        update_height(y)

    def right_rotate(self, x):
        """Rotate the subtree at ``x`` right, promoting ``x.left`` into its place.

        Mirror image of ``left_rotate``.
        """
        y = x.left
        y.parent = x.parent
        if y.parent is None:
            self.root = y
        else:
            if y.parent.left is x:
                y.parent.left = y
            elif y.parent.right is x:
                y.parent.right = y
        x.left = y.right
        if x.left is not None:
            x.left.parent = x
        y.right = x
        x.parent = y
        update_height(x)
        update_height(y)

    def rebalance(self, node: Optional[AVLNode[K, V]]):
        """Walk from ``node`` up to the root, fixing heights and AVL balance.

        Where one child is taller than the other by 2 or more, apply a single
        rotation (outer grandchild heavy) or a double rotation (inner
        grandchild heavy).
        """
        while node is not None:
            update_height(node)
            if height(node.left) >= 2 + height(node.right):
                if height(node.left.left) >= height(node.left.right):
                    self.right_rotate(node)
                else:
                    self.left_rotate(node.left)
                    self.right_rotate(node)
            elif height(node.right) >= 2 + height(node.left):
                if height(node.right.right) >= height(node.right.left):
                    self.left_rotate(node)
                else:
                    self.right_rotate(node.right)
                    self.left_rotate(node)
            node = node.parent

    def insert(self, k, v):
        """Map ``k`` to ``v``, overwriting the value if ``k`` is already present."""
        existing = self.find(k)
        if existing is not None:
            # Keep keys unique. Otherwise a duplicate node would be added and
            # ``find`` would return whichever copy it reaches first, possibly
            # with a stale value.
            existing.value = v
            return
        node = AVLNode[K,V](None, k, v)
        if self.root is None:
            # The root's parent is None.
            self.root = node
        else:
            self.root.insert(node)
        self.rebalance(node)

    def delete(self, k):
        """Remove key ``k`` from the tree; a missing key is a no-op."""
        node = self.find(k)
        if node is None:
            return
        deleted = None
        if node is self.root:
            # AVLNode.delete unlinks from a parent, so hang the root under a
            # temporary pseudo-root. Reusing the root's own key/value keeps
            # this well-typed for any K, V (a literal 0 only works for ints).
            pseudoroot = AVLNode[K,V](None, node.key, node.value)
            pseudoroot.left = self.root
            self.root.parent = pseudoroot
            deleted = self.root.delete()
            self.root = pseudoroot.left
            if self.root is not None:
                self.root.parent = None
            if deleted.parent is pseudoroot:
                # The old root itself was spliced out. Nothing above the new
                # root needs rebalancing, and the pseudo-root must not be
                # touched.
                return
        else:
            deleted = node.delete()
        self.rebalance(deleted.parent)

    def __setitem__(self: AVL[K,V], k: K, v: V):
        self.insert(k, v)

    def __getitem__(self: AVL[K,V], k: K):
        """Return the value for ``k``.

        If ``k`` is absent, print a message and exit the process with status 1.
        """
        nd = self.find(k)
        if nd is None:
            print 'whoops', k, 'not found'
            sys.exit(1)
        return nd.value

    def __delitem__(self: AVL[K,V], k: K):
        self.delete(k)

    def __contains__(self: AVL[K,V], k: K):
        return self.find(k) is not None

    def __iter__(self: AVL[K,V]):
        """Yield ``(key, value)`` tuples in ascending key order."""
        if self.root is not None:
            for i in self.root:
                yield i.key, i.value
# Smoke test: build a tree mapping each key 0..4 to its square, then exercise
# lookup, membership, deletion, re-insertion and in-order iteration.
d1 = AVL[int,int]()
for a in range(5):
    d1[a] = a*a
# EXPECT: 0
# EXPECT: 1
# EXPECT: 4
# EXPECT: 9
# EXPECT: 16
for a in range(5):
    print d1[a]
print 2 in d1 # EXPECT: True
del d1[2]
print 2 in d1 # EXPECT: False
# A key that was deleted must become visible again after re-insertion.
d1[2] = 44
print 2 in d1 # EXPECT: True
print d1[2] # EXPECT: 44
del d1[3]
del d1[4]
# Iteration yields the remaining (key, value) pairs in ascending key order.
# EXPECT: 0 0
# EXPECT: 1 1
# EXPECT: 2 44
for t in d1:
    print t[0], t[1]