mirror of https://github.com/exaloop/codon
stdlib/internal/types/collections/set.codon
parent
2d707e35a2
commit
0b1b6a6450
|
@ -1,14 +1,18 @@
|
|||
# set implementation based on klib's khash
|
||||
# (c) 2022 Exaloop Inc. All rights reserved.
|
||||
|
||||
# set implementation based on klib's khash
|
||||
from internal.attributes import commutative, associative
|
||||
import internal.khash as khash
|
||||
import internal.gc as gc
|
||||
|
||||
def _set_hash(key):
|
||||
|
||||
def _set_hash(key) -> int:
|
||||
k = key.__hash__()
|
||||
return (k >> 33) ^ k ^ (k << 11)
|
||||
|
||||
class Set[K]:
|
||||
|
||||
class Set:
|
||||
K: type
|
||||
_n_buckets: int
|
||||
_size: int
|
||||
_n_occupied: int
|
||||
|
@ -18,7 +22,7 @@ class Set[K]:
|
|||
_keys: Ptr[K]
|
||||
|
||||
# Magic methods
|
||||
def _init(self):
|
||||
def _init(self) -> void:
|
||||
self._n_buckets = 0
|
||||
self._size = 0
|
||||
self._n_occupied = 0
|
||||
|
@ -26,53 +30,53 @@ class Set[K]:
|
|||
self._flags = Ptr[u32]()
|
||||
self._keys = Ptr[K]()
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> void:
|
||||
self._init()
|
||||
|
||||
def __init__(self, g: Generator[K]):
|
||||
def __init__(self, g: Generator[K]) -> void:
|
||||
self._init()
|
||||
for a in g:
|
||||
self.add(a)
|
||||
|
||||
def __sub__(self, other: Set[K]):
|
||||
def __sub__(self, other: Set[K]) -> Set[K]:
|
||||
return self.difference(other)
|
||||
|
||||
def __isub__(self, other: Set[K]):
|
||||
def __isub__(self, other: Set[K]) -> Set[K]:
|
||||
self.difference_update(other)
|
||||
return self
|
||||
|
||||
@commutative
|
||||
@associative
|
||||
def __and__(self, other: Set[K]):
|
||||
def __and__(self, other: Set[K]) -> Set[K]:
|
||||
return self.intersection(other)
|
||||
|
||||
def __iand__(self, other: Set[K]):
|
||||
def __iand__(self, other: Set[K]) -> Set[K]:
|
||||
self.intersection_update(other)
|
||||
return self
|
||||
|
||||
@commutative
|
||||
@associative
|
||||
def __or__(self, other: Set[K]):
|
||||
def __or__(self, other: Set[K]) -> Set[K]:
|
||||
return self.union(other)
|
||||
|
||||
def __ior__(self, other: Set[K]):
|
||||
def __ior__(self, other: Set[K]) -> Set[K]:
|
||||
for a in other:
|
||||
self.add(a)
|
||||
return self
|
||||
|
||||
@commutative
|
||||
@associative
|
||||
def __xor__(self, other: Set[K]):
|
||||
def __xor__(self, other: Set[K]) -> Set[K]:
|
||||
return self.symmetric_difference(other)
|
||||
|
||||
def __ixor__(self, other: Set[K]):
|
||||
def __ixor__(self, other: Set[K]) -> Set[K]:
|
||||
self.symmetric_difference_update(other)
|
||||
return self
|
||||
|
||||
def __contains__(self, key: K):
|
||||
def __contains__(self, key: K) -> bool:
|
||||
return self._kh_get(key) != self._kh_end()
|
||||
|
||||
def __eq__(self, other: Set[K]):
|
||||
def __eq__(self, other: Set[K]) -> bool:
|
||||
if self.__len__() != other.__len__():
|
||||
return False
|
||||
for k in self:
|
||||
|
@ -80,35 +84,35 @@ class Set[K]:
|
|||
return False
|
||||
return True
|
||||
|
||||
def __ne__(self, other: Set[K]):
|
||||
def __ne__(self, other: Set[K]) -> bool:
|
||||
return not (self == other)
|
||||
|
||||
def __le__(self, other: Set[K]):
|
||||
def __le__(self, other: Set[K]) -> bool:
|
||||
return self.issubset(other)
|
||||
|
||||
def __ge__(self, other: Set[K]):
|
||||
def __ge__(self, other: Set[K]) -> bool:
|
||||
return self.issuperset(other)
|
||||
|
||||
def __lt__(self, other: Set[K]):
|
||||
def __lt__(self, other: Set[K]) -> bool:
|
||||
return self != other and self <= other
|
||||
|
||||
def __gt__(self, other: Set[K]):
|
||||
def __gt__(self, other: Set[K]) -> bool:
|
||||
return self != other and self >= other
|
||||
|
||||
def __iter__(self):
|
||||
def __iter__(self) -> Generator[K]:
|
||||
i = self._kh_begin()
|
||||
while i < self._kh_end():
|
||||
if self._kh_exist(i):
|
||||
yield self._keys[i]
|
||||
i += 1
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self._size
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return self.__len__() != 0
|
||||
|
||||
def __copy__(self):
|
||||
def __copy__(self) -> Set[K]:
|
||||
if self.__len__() == 0:
|
||||
return Set[K]()
|
||||
n = self._n_buckets
|
||||
|
@ -117,12 +121,14 @@ class Set[K]:
|
|||
keys_copy = Ptr[K](n)
|
||||
str.memcpy(flags_copy.as_byte(), self._flags.as_byte(), f * gc.sizeof(u32))
|
||||
str.memcpy(keys_copy.as_byte(), self._keys.as_byte(), n * gc.sizeof(K))
|
||||
return Set[K](n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy)
|
||||
return Set[K](
|
||||
n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy
|
||||
)
|
||||
|
||||
def __deepcopy__(self):
|
||||
def __deepcopy__(self) -> Set[K]:
|
||||
return {s.__deepcopy__() for s in self}
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
n = self.__len__()
|
||||
if n == 0:
|
||||
return "{}"
|
||||
|
@ -139,50 +145,49 @@ class Set[K]:
|
|||
lst.append("}")
|
||||
return str.cat(lst)
|
||||
|
||||
|
||||
# Helper methods
|
||||
|
||||
def resize(self, new_n_buckets: int):
|
||||
def resize(self, new_n_buckets: int) -> void:
|
||||
self._kh_resize(new_n_buckets)
|
||||
|
||||
def add(self, key: K):
|
||||
def add(self, key: K) -> void:
|
||||
self._kh_put(key)
|
||||
|
||||
def update(self, other: Set[K]):
|
||||
def update(self, other: Set[K]) -> void:
|
||||
for k in other:
|
||||
self.add(k)
|
||||
|
||||
def remove(self, key: K):
|
||||
def remove(self, key: K) -> void:
|
||||
x = self._kh_get(key)
|
||||
if x != self._kh_end():
|
||||
self._kh_del(x)
|
||||
else:
|
||||
raise KeyError(str(key))
|
||||
|
||||
def pop(self):
|
||||
def pop(self) -> K:
|
||||
if self.__len__() == 0:
|
||||
raise ValueError("empty set")
|
||||
for a in self:
|
||||
self.remove(a)
|
||||
return a
|
||||
|
||||
def discard(self, key: K):
|
||||
def discard(self, key: K) -> void:
|
||||
x = self._kh_get(key)
|
||||
if x != self._kh_end():
|
||||
self._kh_del(x)
|
||||
|
||||
def difference(self, other: Set[K]):
|
||||
def difference(self, other: Set[K]) -> Set[K]:
|
||||
s = Set[K]()
|
||||
for a in self:
|
||||
if a not in other:
|
||||
s.add(a)
|
||||
return s
|
||||
|
||||
def difference_update(self, other: Set[K]):
|
||||
def difference_update(self, other: Set[K]) -> void:
|
||||
for a in other:
|
||||
self.discard(a)
|
||||
|
||||
def intersection(self, other: Set[K]):
|
||||
def intersection(self, other: Set[K]) -> Set[K]:
|
||||
if other.__len__() < self.__len__():
|
||||
self, other = other, self
|
||||
s = Set[K]()
|
||||
|
@ -191,12 +196,12 @@ class Set[K]:
|
|||
s.add(a)
|
||||
return s
|
||||
|
||||
def intersection_update(self, other: Set[K]):
|
||||
def intersection_update(self, other: Set[K]) -> void:
|
||||
for a in self:
|
||||
if a not in other:
|
||||
self.discard(a)
|
||||
|
||||
def symmetric_difference(self, other: Set[K]):
|
||||
def symmetric_difference(self, other: Set[K]) -> Set[K]:
|
||||
s = Set[K]()
|
||||
for a in self:
|
||||
if a not in other:
|
||||
|
@ -206,7 +211,7 @@ class Set[K]:
|
|||
s.add(a)
|
||||
return s
|
||||
|
||||
def symmetric_difference_update(self, other: Set[K]):
|
||||
def symmetric_difference_update(self, other: Set[K]) -> void:
|
||||
for a in other:
|
||||
if a in self:
|
||||
self.discard(a)
|
||||
|
@ -214,16 +219,18 @@ class Set[K]:
|
|||
if a in other:
|
||||
self.discard(a)
|
||||
|
||||
def union(self, other: Set[K]):
|
||||
def union(self, other: Set[K]) -> Set[K]:
|
||||
s = Set[K]()
|
||||
s.resize(self._n_buckets if self._n_buckets >= other._n_buckets else other._n_buckets)
|
||||
s.resize(
|
||||
self._n_buckets if self._n_buckets >= other._n_buckets else other._n_buckets
|
||||
)
|
||||
for a in self:
|
||||
s.add(a)
|
||||
for a in other:
|
||||
s.add(a)
|
||||
return s
|
||||
|
||||
def isdisjoint(self, other: Set[K]):
|
||||
def isdisjoint(self, other: Set[K]) -> bool:
|
||||
if other.__len__() < self.__len__():
|
||||
self, other = other, self
|
||||
for a in self:
|
||||
|
@ -231,7 +238,7 @@ class Set[K]:
|
|||
return False
|
||||
return True
|
||||
|
||||
def issubset(self, other: Set[K]):
|
||||
def issubset(self, other: Set[K]) -> bool:
|
||||
if other.__len__() < self.__len__():
|
||||
return False
|
||||
for a in self:
|
||||
|
@ -239,36 +246,37 @@ class Set[K]:
|
|||
return False
|
||||
return True
|
||||
|
||||
def issuperset(self, other: Set[K]):
|
||||
def issuperset(self, other: Set[K]) -> bool:
|
||||
return other.issubset(self)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> void:
|
||||
self._kh_clear()
|
||||
|
||||
def copy(self):
|
||||
def copy(self) -> Set[K]:
|
||||
return self.__copy__()
|
||||
|
||||
|
||||
# Internal helpers
|
||||
|
||||
def _kh_clear(self):
|
||||
def _kh_clear(self) -> void:
|
||||
if self._flags:
|
||||
i = 0
|
||||
n = khash.__ac_fsize(self._n_buckets)
|
||||
while i < n:
|
||||
self._flags[i] = u32(0xaaaaaaaa)
|
||||
self._flags[i] = u32(0xAAAAAAAA)
|
||||
i += 1
|
||||
self._size = 0
|
||||
self._n_occupied = 0
|
||||
|
||||
def _kh_get(self, key: K):
|
||||
def _kh_get(self, key: K) -> int:
|
||||
if self._n_buckets:
|
||||
step = 0
|
||||
mask = self._n_buckets - 1
|
||||
k = _set_hash(key)
|
||||
i = k & mask
|
||||
last = i
|
||||
while not khash.__ac_isempty(self._flags, i) and (khash.__ac_isdel(self._flags, i) or self._keys[i] != key):
|
||||
while not khash.__ac_isempty(self._flags, i) and (
|
||||
khash.__ac_isdel(self._flags, i) or self._keys[i] != key
|
||||
):
|
||||
step += 1
|
||||
i = (i + step) & mask
|
||||
if i == last:
|
||||
|
@ -277,7 +285,7 @@ class Set[K]:
|
|||
else:
|
||||
return 0
|
||||
|
||||
def _kh_resize(self, new_n_buckets: int):
|
||||
def _kh_resize(self, new_n_buckets: int) -> void:
|
||||
HASH_UPPER = 0.77
|
||||
new_flags = Ptr[u32]()
|
||||
j = 1
|
||||
|
@ -302,11 +310,13 @@ class Set[K]:
|
|||
new_flags = Ptr[u32](fsize)
|
||||
i = 0
|
||||
while i < fsize:
|
||||
new_flags[i] = u32(0xaaaaaaaa)
|
||||
new_flags[i] = u32(0xAAAAAAAA)
|
||||
i += 1
|
||||
|
||||
if self._n_buckets < new_n_buckets:
|
||||
self._keys = Ptr[K](gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)))
|
||||
self._keys = Ptr[K](
|
||||
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
|
||||
)
|
||||
|
||||
if j:
|
||||
j = 0
|
||||
|
@ -326,7 +336,10 @@ class Set[K]:
|
|||
i = (i + step) & new_mask
|
||||
|
||||
khash.__ac_set_isempty_false(new_flags, i)
|
||||
if i < self._n_buckets and khash.__ac_iseither(self._flags, i) == 0:
|
||||
if (
|
||||
i < self._n_buckets
|
||||
and khash.__ac_iseither(self._flags, i) == 0
|
||||
):
|
||||
self._keys[i], key = key, self._keys[i]
|
||||
khash.__ac_set_isdel_true(self._flags, i)
|
||||
else:
|
||||
|
@ -335,14 +348,16 @@ class Set[K]:
|
|||
j += 1
|
||||
|
||||
if self._n_buckets > new_n_buckets:
|
||||
self._keys = Ptr[K](gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K)))
|
||||
self._keys = Ptr[K](
|
||||
gc.realloc(self._keys.as_byte(), new_n_buckets * gc.sizeof(K))
|
||||
)
|
||||
|
||||
self._flags = new_flags
|
||||
self._n_buckets = new_n_buckets
|
||||
self._n_occupied = self._size
|
||||
self._upper_bound = int(self._n_buckets * HASH_UPPER + 0.5)
|
||||
|
||||
def _kh_put(self, key: K):
|
||||
def _kh_put(self, key: K) -> Tuple[int, int]:
|
||||
if self._n_occupied >= self._upper_bound:
|
||||
if self._n_buckets > (self._size << 1):
|
||||
self._kh_resize(self._n_buckets - 1)
|
||||
|
@ -359,7 +374,9 @@ class Set[K]:
|
|||
x = i
|
||||
else:
|
||||
last = i
|
||||
while not khash.__ac_isempty(self._flags, i) and (khash.__ac_isdel(self._flags, i) or self._keys[i] != key):
|
||||
while not khash.__ac_isempty(self._flags, i) and (
|
||||
khash.__ac_isdel(self._flags, i) or self._keys[i] != key
|
||||
):
|
||||
if khash.__ac_isdel(self._flags, i):
|
||||
site = i
|
||||
step += 1
|
||||
|
@ -389,18 +406,19 @@ class Set[K]:
|
|||
|
||||
return (ret, x)
|
||||
|
||||
def _kh_del(self, x: int):
|
||||
def _kh_del(self, x: int) -> void:
|
||||
if x != self._n_buckets and not khash.__ac_iseither(self._flags, x):
|
||||
khash.__ac_set_isdel_true(self._flags, x)
|
||||
self._size -= 1
|
||||
|
||||
def _kh_begin(self):
|
||||
def _kh_begin(self) -> int:
|
||||
return 0
|
||||
|
||||
def _kh_end(self):
|
||||
def _kh_end(self) -> int:
|
||||
return self._n_buckets
|
||||
|
||||
def _kh_exist(self, x: int):
|
||||
def _kh_exist(self, x: int) -> bool:
|
||||
return not khash.__ac_iseither(self._flags, x)
|
||||
|
||||
|
||||
set = Set
|
||||
|
|
Loading…
Reference in New Issue