stdlib/internal/types/collections/set.codon

pull/13/head
Ishak Numanagić 2022-01-24 09:39:50 +01:00
parent 2d707e35a2
commit 0b1b6a6450
1 changed files with 87 additions and 69 deletions

View File

@ -1,14 +1,18 @@
# set implementation based on klib's khash
# (c) 2022 Exaloop Inc. All rights reserved.
# set implementation based on klib's khash
from internal.attributes import commutative, associative
import internal.khash as khash
import internal.gc as gc
def _set_hash(key) -> int:
    """
    Scramble the hash of `key` before it is masked into a bucket index.

    Takes `key.__hash__()` and XORs it with a copy of itself shifted right
    by 33 bits and with a copy shifted left by 11 bits.
    """
    h = key.__hash__()
    mixed = (h >> 33) ^ h
    return mixed ^ (h << 11)
class Set[K]:
class Set:
K: type
_n_buckets: int
_size: int
_n_occupied: int
@ -17,8 +21,8 @@ class Set[K]:
_flags: Ptr[u32]
_keys: Ptr[K]
# Magic methods
def _init(self):
# Magic methods
def _init(self) -> void:
self._n_buckets = 0
self._size = 0
self._n_occupied = 0
@ -26,53 +30,53 @@ class Set[K]:
self._flags = Ptr[u32]()
self._keys = Ptr[K]()
def __init__(self) -> void:
    """Create an empty set; all field initialisation is delegated to `_init`."""
    self._init()
def __init__(self, g: Generator[K]) -> void:
    """
    Create a set holding every item produced by the generator `g`.

    The fields are first initialised with `_init`; the items are then
    inserted one at a time with `add`.
    """
    self._init()
    for item in g:
        self.add(item)
def __sub__(self, other: Set[K]) -> Set[K]:
    """Implement `self - other` by delegating to `difference`."""
    result = self.difference(other)
    return result
def __isub__(self, other: Set[K]) -> Set[K]:
    """Implement `self -= other`: run `difference_update`, then return `self`."""
    self.difference_update(other)
    return self
@commutative
@associative
def __and__(self, other: Set[K]) -> Set[K]:
    """Implement `self & other` by delegating to `intersection`."""
    result = self.intersection(other)
    return result
def __iand__(self, other: Set[K]) -> Set[K]:
    """Implement `self &= other`: run `intersection_update`, then return `self`."""
    self.intersection_update(other)
    return self
@commutative
@associative
def __or__(self, other: Set[K]) -> Set[K]:
    """Implement `self | other` by delegating to `union`."""
    result = self.union(other)
    return result
def __ior__(self, other: Set[K]) -> Set[K]:
    """Implement `self |= other`: add every key of `other`, then return `self`."""
    for item in other:
        self.add(item)
    return self
@commutative
@associative
def __xor__(self, other: Set[K]) -> Set[K]:
    """Implement `self ^ other` by delegating to `symmetric_difference`."""
    result = self.symmetric_difference(other)
    return result
def __ixor__(self, other: Set[K]) -> Set[K]:
    """
    Implement `self ^= other`: run `symmetric_difference_update`, then
    return `self`.
    """
    self.symmetric_difference_update(other)
    return self
def __contains__(self, key: K) -> bool:
    """
    Implement `key in self`.

    The key counts as present when `_kh_get` returns an index other than
    the end sentinel `_kh_end()`.
    """
    slot = self._kh_get(key)
    return slot != self._kh_end()
def __eq__(self, other: Set[K]):
def __eq__(self, other: Set[K]) -> bool:
if self.__len__() != other.__len__():
return False
for k in self:
@ -80,35 +84,35 @@ class Set[K]:
return False
return True
def __ne__(self, other: Set[K]) -> bool:
    """Implement `self != other` as the negation of `self == other`."""
    equal = self == other
    return not equal
def __le__(self, other: Set[K]) -> bool:
    """Implement `self <= other` by delegating to `issubset`."""
    is_sub = self.issubset(other)
    return is_sub
def __ge__(self, other: Set[K]) -> bool:
    """Implement `self >= other` by delegating to `issuperset`."""
    is_super = self.issuperset(other)
    return is_super
def __lt__(self, other: Set[K]) -> bool:
    """Implement `self < other`: the sets differ and `self <= other` holds."""
    # `<=` is only evaluated once the sets are known to differ.
    if self != other:
        return self <= other
    return False
def __gt__(self, other: Set[K]) -> bool:
    """Implement `self > other`: the sets differ and `self >= other` holds."""
    # `>=` is only evaluated once the sets are known to differ.
    if self != other:
        return self >= other
    return False
def __iter__(self) -> Generator[K]:
    """
    Yield every key stored in the set.

    Walks the slot indices from `_kh_begin()` up to `_kh_end()` and yields
    `_keys[slot]` for each slot that `_kh_exist` accepts.
    """
    slot = self._kh_begin()
    # The end bound is re-read on every pass rather than cached.
    while slot < self._kh_end():
        if not self._kh_exist(slot):
            slot += 1
            continue
        yield self._keys[slot]
        slot += 1
def __len__(self) -> int:
    """Return the number of keys in the set, as tracked by `_size`."""
    count = self._size
    return count
def __bool__(self) -> bool:
    """Return True when `__len__()` is non-zero."""
    count = self.__len__()
    return count != 0
def __copy__(self):
def __copy__(self) -> Set[K]:
if self.__len__() == 0:
return Set[K]()
n = self._n_buckets
@ -117,12 +121,14 @@ class Set[K]:
keys_copy = Ptr[K](n)
str.memcpy(flags_copy.as_byte(), self._flags.as_byte(), f * gc.sizeof(u32))
str.memcpy(keys_copy.as_byte(), self._keys.as_byte(), n * gc.sizeof(K))
return Set[K](n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy)
return Set[K](
n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy
)
def __deepcopy__(self) -> Set[K]:
    """Return a new set built from the `__deepcopy__()` of every key."""
    clones = {item.__deepcopy__() for item in self}
    return clones
def __repr__(self):
def __repr__(self) -> str:
n = self.__len__()
if n == 0:
return "{}"
@ -139,50 +145,49 @@ class Set[K]:
lst.append("}")
return str.cat(lst)
# Helper methods
# Helper methods
def resize(self, new_n_buckets: int) -> void:
    """Forward the requested bucket count `new_n_buckets` to `_kh_resize`."""
    self._kh_resize(new_n_buckets)
def add(self, key: K) -> void:
    """Insert `key` through `_kh_put`; the value `_kh_put` returns is ignored."""
    self._kh_put(key)
def update(self, other: Set[K]) -> void:
    """Add every key of `other` to this set, in place."""
    for item in other:
        self.add(item)
def remove(self, key: K) -> void:
    """
    Delete `key` from the set.

    Raises:
        KeyError: carrying `str(key)`, when `_kh_get` returns the end
            sentinel for `key`.
    """
    slot = self._kh_get(key)
    if slot == self._kh_end():
        raise KeyError(str(key))
    self._kh_del(slot)
def pop(self) -> K:
    """
    Remove and return the first key that iteration yields.

    Raises:
        ValueError: when the set is empty.
    """
    if self.__len__() == 0:
        raise ValueError("empty set")
    # Only the first iteration ever runs: the key is removed and returned.
    for item in self:
        self.remove(item)
        return item
def discard(self, key: K) -> void:
    """Delete `key` if `_kh_get` finds it; do nothing otherwise."""
    slot = self._kh_get(key)
    if slot == self._kh_end():
        return
    self._kh_del(slot)
def difference(self, other: Set[K]) -> Set[K]:
    """Return a new set holding the keys of `self` that are not in `other`."""
    result = Set[K]()
    for item in self:
        if item in other:
            continue
        result.add(item)
    return result
def difference_update(self, other: Set[K]) -> void:
    """Discard from this set, in place, every key that `other` contains."""
    for item in other:
        self.discard(item)
def intersection(self, other: Set[K]):
def intersection(self, other: Set[K]) -> Set[K]:
if other.__len__() < self.__len__():
self, other = other, self
s = Set[K]()
@ -191,12 +196,12 @@ class Set[K]:
s.add(a)
return s
def intersection_update(self, other: Set[K]) -> void:
    """Discard from this set, in place, every key that `other` lacks."""
    # Keys are discarded while `self` is being iterated.
    for item in self:
        if item in other:
            continue
        self.discard(item)
def symmetric_difference(self, other: Set[K]):
def symmetric_difference(self, other: Set[K]) -> Set[K]:
s = Set[K]()
for a in self:
if a not in other:
@ -206,7 +211,7 @@ class Set[K]:
s.add(a)
return s
def symmetric_difference_update(self, other: Set[K]):
def symmetric_difference_update(self, other: Set[K]) -> void:
for a in other:
if a in self:
self.discard(a)
@ -214,16 +219,18 @@ class Set[K]:
if a in other:
self.discard(a)
def union(self, other: Set[K]) -> Set[K]:
    """
    Return a new set holding every key of `self` and of `other`.

    The result is first resized to the larger of the two bucket counts,
    then filled from `self` followed by `other`.
    """
    merged = Set[K]()
    if self._n_buckets >= other._n_buckets:
        merged.resize(self._n_buckets)
    else:
        merged.resize(other._n_buckets)
    for item in self:
        merged.add(item)
    for item in other:
        merged.add(item)
    return merged
def isdisjoint(self, other: Set[K]):
def isdisjoint(self, other: Set[K]) -> bool:
if other.__len__() < self.__len__():
self, other = other, self
for a in self:
@ -231,7 +238,7 @@ class Set[K]:
return False
return True
def issubset(self, other: Set[K]):
def issubset(self, other: Set[K]) -> bool:
if other.__len__() < self.__len__():
return False
for a in self:
@ -239,36 +246,37 @@ class Set[K]:
return False
return True
def issuperset(self, other: Set[K]) -> bool:
    """Return the result of `other.issubset(self)`."""
    contained = other.issubset(self)
    return contained
def clear(self) -> void:
    """Empty the set by delegating to `_kh_clear`."""
    self._kh_clear()
def copy(self) -> Set[K]:
    """Return the set produced by `__copy__`."""
    duplicate = self.__copy__()
    return duplicate
# Internal helpers
# Internal helpers
def _kh_clear(self) -> void:
    """
    Reset the table to empty while keeping its storage.

    When a flag array exists, every one of its
    `khash.__ac_fsize(_n_buckets)` words is overwritten with 0xAAAAAAAA
    and the `_size` / `_n_occupied` counters are zeroed.
    """
    if self._flags:
        n_words = khash.__ac_fsize(self._n_buckets)
        word = 0
        while word < n_words:
            # Presumably the "every slot empty" pattern - confirm in khash.
            self._flags[word] = u32(0xAAAAAAAA)
            word += 1
        self._size = 0
        self._n_occupied = 0
def _kh_get(self, key: K):
def _kh_get(self, key: K) -> int:
if self._n_buckets:
step = 0
mask = self._n_buckets - 1
k = _set_hash(key)
i = k & mask
last = i
while not khash.__ac_isempty(self._flags, i) and (khash.__ac_isdel(self._flags, i) or self._keys[i] != key):
while not khash.__ac_isempty(self._flags, i) and (
khash.__ac_isdel(self._flags, i) or self._keys[i] != key
):
step += 1
i = (i + step) & mask
if i == last:
@ -277,7 +285,7 @@ class Set[K]:
else:
return 0
def _kh_resize(self, new_n_buckets: int):
def _kh_resize(self, new_n_buckets: int) -> void:
HASH_UPPER = 0.77
new_flags = Ptr[u32]()
j = 1
@ -295,18 +303,20 @@ class Set[K]:
if new_n_buckets < 4:
new_n_buckets = 4
if self._size >= int(new_n_buckets*HASH_UPPER + 0.5):
if self._size >= int(new_n_buckets * HASH_UPPER + 0.5):
j = 0
else:
fsize = khash.__ac_fsize(new_n_buckets)
new_flags = Ptr[u32](fsize)
i = 0
while i < fsize:
new_flags[i] = u32(0xaaaaaaaa)
new_flags[i] = u32(0xAAAAAAAA)
i += 1
if self._n_buckets < new_n_buckets:
self._keys = Ptr[K](gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)))
self._keys = Ptr[K](
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
)
if j:
j = 0
@ -326,7 +336,10 @@ class Set[K]:
i = (i + step) & new_mask
khash.__ac_set_isempty_false(new_flags, i)
if i < self._n_buckets and khash.__ac_iseither(self._flags, i) == 0:
if (
i < self._n_buckets
and khash.__ac_iseither(self._flags, i) == 0
):
self._keys[i], key = key, self._keys[i]
khash.__ac_set_isdel_true(self._flags, i)
else:
@ -335,14 +348,16 @@ class Set[K]:
j += 1
if self._n_buckets > new_n_buckets:
self._keys = Ptr[K](gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)))
self._keys = Ptr[K](
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
)
self._flags = new_flags
self._n_buckets = new_n_buckets
self._n_occupied = self._size
self._upper_bound = int(self._n_buckets*HASH_UPPER + 0.5)
self._upper_bound = int(self._n_buckets * HASH_UPPER + 0.5)
def _kh_put(self, key: K):
def _kh_put(self, key: K) -> Tuple[int, int]:
if self._n_occupied >= self._upper_bound:
if self._n_buckets > (self._size << 1):
self._kh_resize(self._n_buckets - 1)
@ -359,7 +374,9 @@ class Set[K]:
x = i
else:
last = i
while not khash.__ac_isempty(self._flags, i) and (khash.__ac_isdel(self._flags, i) or self._keys[i] != key):
while not khash.__ac_isempty(self._flags, i) and (
khash.__ac_isdel(self._flags, i) or self._keys[i] != key
):
if khash.__ac_isdel(self._flags, i):
site = i
step += 1
@ -389,18 +406,19 @@ class Set[K]:
return (ret, x)
def _kh_del(self, x: int) -> void:
    """
    Flag slot `x` as deleted and decrement `_size`.

    Nothing happens when `x` equals the bucket count (the end sentinel) or
    when `khash.__ac_iseither` reports a truthy value for the slot.
    """
    if x == self._n_buckets:
        return
    # Presumably "empty or already deleted" - confirm against khash.
    if khash.__ac_iseither(self._flags, x):
        return
    khash.__ac_set_isdel_true(self._flags, x)
    self._size -= 1
def _kh_begin(self) -> int:
    """Return the first slot index to visit when scanning the table: 0."""
    first_slot = 0
    return first_slot
def _kh_end(self) -> int:
    """Return the end sentinel index, which is the current bucket count."""
    end_slot = self._n_buckets
    return end_slot
def _kh_exist(self, x: int) -> bool:
    """
    Return True when `khash.__ac_iseither` reports a falsy value for
    slot `x`.
    """
    either = khash.__ac_iseither(self._flags, x)
    return not either
# Lowercase alias: `set` refers to the same class as `Set`.
set = Set