codon/stdlib/itertools.codon

759 lines
18 KiB
Python

# Copyright (C) 2022-2023 Exaloop Inc. <https://exaloop.io>
from internal.types.optional import unwrap
# Infinite iterators
@inline
def count(start: T = 0, step: T = 1, T: type) -> Generator[T]:
    """
    Return a count object whose ``__next__`` method returns consecutive values.

    Produces ``start``, ``start + step``, ``start + 2*step``, ... without end;
    the running value is advanced with ``+=`` after each yield.
    """
    value = start
    while True:
        yield value
        value += step
@inline
def cycle(iterable: Generator[T], T: type) -> Generator[T]:
    """
    Cycles repeatedly through an iterable.

    The first pass forwards elements straight from ``iterable`` while
    recording them; after that the recorded elements are replayed forever.
    An empty input produces nothing.
    """
    history = []
    for item in iterable:
        yield item
        history.append(item)
    # Nothing was recorded: there is nothing to replay.
    if not history:
        return
    while True:
        for item in history:
            yield item
@inline
def repeat(object: T, times: Optional[int] = None, T: type) -> Generator[T]:
    """
    Make an iterator that returns a given object over and over again.

    When ``times`` is given, the object is produced ``times`` times (nothing
    for a non-positive count). When ``times`` is None, it is produced forever.
    """
    if times is not None:
        for _ in range(times):
            yield object
        return
    while True:
        yield object
# Iterators terminating on the shortest input sequence
@inline
def accumulate(iterable: Generator[T], func=lambda a, b: a + b, initial=0, T: type):
    """
    Make an iterator that returns accumulated sums, or accumulated results
    of other binary functions (specified via the optional func argument).

    This variant always emits ``initial`` first, then one running result per
    input element, folded as ``func(running, element)``.
    """
    running = initial
    yield running
    for item in iterable:
        running = func(running, item)
        yield running
@inline
@overload
def accumulate(iterable: Generator[T], func=lambda a, b: a + b, T: type):
    """
    Make an iterator that returns accumulated sums, or accumulated results
    of other binary functions (specified via the optional func argument).

    Without an initial value, the first element is emitted as-is and seeds
    the fold; an empty input produces nothing.
    """
    acc: Optional[T] = None
    for item in iterable:
        if acc is None:
            # First element: it becomes the starting accumulator.
            acc = item
        else:
            acc = func(unwrap(acc), item)
        yield unwrap(acc)
@tuple
class chain:
    """
    Make an iterator that returns elements from the first iterable until it is exhausted,
    then proceeds to the next iterable, until all of the iterables are exhausted.
    """

    @inline
    def __new__(*iterables):
        # Forward every element of each positional iterable, in argument order.
        for source in iterables:
            for value in source:
                yield value

    @inline
    def from_iterable(iterables):
        # Same as the constructor, but the sources come from one outer iterable.
        for source in iterables:
            for value in source:
                yield value
@inline
def compress(
    data: Generator[T], selectors: Generator[B], T: type, B: type
) -> Generator[T]:
    """
    Return data elements corresponding to true selector elements.
    Forms a shorter iterator from selected data elements using the selectors to
    choose the data elements.

    Iteration stops as soon as either ``data`` or ``selectors`` runs out.
    """
    for item, keep in zip(data, selectors):
        if not keep:
            continue
        yield item
@inline
def dropwhile(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Drop items from the iterable while predicate(item) is true.
    Afterwards, return every element until the iterable is exhausted.

    ``predicate`` is no longer consulted once the first failing item is seen.
    """
    dropping = True
    for item in iterable:
        if dropping:
            if predicate(item):
                continue
            dropping = False
        yield item
@inline
def filterfalse(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Return those items of iterable for which function(item) is false.
    """
    for item in iterable:
        if predicate(item):
            continue
        yield item
# TODO: fix this once Optional[Callable] lands
@inline
def groupby(iterable, key=Optional[int]()):
    """
    Make an iterator that returns consecutive keys and groups from the iterable.

    Each yielded item is a ``(key, group)`` pair, where ``group`` is a list of
    the consecutive input elements sharing that key. ``key`` is a callable
    applied to each element. When it is left at its default (an empty
    Optional placeholder), the elements themselves are used as keys. The key
    reported for a group is the key of its first element.
    """
    currkey = None
    group = []
    for currvalue in iterable:
        k = currvalue if isinstance(key, Optional) else key(currvalue)
        if currkey is None:
            # The first element just opens the first group. Comparing k with
            # itself here would emit a spurious empty group for keys that are
            # not equal to themselves (e.g. float NaN).
            currkey = k
        elif k != unwrap(currkey):
            yield unwrap(currkey), group
            currkey = k
            group = []
        group.append(currvalue)
    if currkey is not None:
        yield unwrap(currkey), group
def islice(iterable: Generator[T], stop: Optional[int], T: type) -> Generator[T]:
    """
    Make an iterator that returns selected elements from the iterable.

    Yields the first ``stop`` elements of ``iterable`` (all of them when
    ``stop`` is None). No more than ``stop`` elements are pulled from the
    source, so a shared generator is left positioned right after the last
    element returned.

    Raises ValueError if ``stop`` is negative.
    """
    if stop is not None and stop.__val__() < 0:
        raise ValueError(
            "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize."
        )
    if stop is not None and stop.__val__() == 0:
        # Nothing to return: do not touch the source at all.
        return
    i = 0
    for x in iterable:
        yield x
        i += 1
        # Stop as soon as the quota is met, before pulling another element.
        if stop is not None and i >= stop.__val__():
            break
@overload
def islice(
    iterable: Generator[T],
    start: Optional[int],
    stop: Optional[int],
    step: Optional[int] = None,
    T: type,
) -> Generator[T]:
    """
    Make an iterator that returns selected elements from the iterable.

    Yields the elements at positions ``start, start + step, ...`` that are
    below ``stop``. A None ``start`` means 0, a None ``stop`` means
    ``sys.maxsize``, and a None ``step`` means 1. After the last selected
    element, the source is advanced up to position ``stop``. When nothing is
    selected, the first ``start`` elements are skipped.

    Raises ValueError if ``start`` or ``stop`` is negative, or if ``step`` is
    not positive.
    """
    from sys import maxsize
    start: int = 0 if start is None else start
    stop: int = maxsize if stop is None else stop
    step: int = 1 if step is None else step
    if start < 0 or stop < 0:
        raise ValueError(
            "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize."
        )
    elif step <= 0:
        # step == 0 must be rejected here too; otherwise it would surface as
        # range()'s own error instead of this message.
        raise ValueError("Step for islice() must be a positive integer or None.")
    it = range(start, stop, step)
    N = len(it)
    idx = 0
    b = -1
    if N == 0:
        # Nothing is selected, but still skip over the first `start` elements.
        for i, element in zip(range(start), iterable):
            pass
        return
    nexti = it[0]
    for i, element in enumerate(iterable):
        if i == nexti:
            yield element
            idx += 1
            if idx >= N:
                # Remember where the final selected element was found.
                b = i
                break
            nexti = it[idx]
    if b >= 0:
        # Advance the source up to `stop`; range() is zipped first so no
        # element beyond `stop` is pulled.
        for i, element in zip(range(b + 1, stop), iterable):
            pass
@inline
def starmap(function, iterable):
    """
    Return an iterator whose values are returned from the function
    evaluated with an argument tuple taken from the given sequence.
    """
    for arg_tuple in iterable:
        # Each element is unpacked into positional arguments.
        yield function(*arg_tuple)
@inline
def takewhile(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Return successive entries from an iterable as long as the predicate evaluates to true for each entry.

    The first entry that fails the predicate ends iteration and is not returned.
    """
    for item in iterable:
        if not predicate(item):
            break
        yield item
def tee(iterable: Generator[T], n: int = 2, T: type) -> List[Generator[T]]:
    """
    Return n independent iterators from a single iterable.

    Each returned generator owns a deque buffer. When a generator's buffer is
    empty, it advances the shared source once and appends the new value to
    every buffer, so the other generators see it later.
    """
    from collections import deque
    source = iter(iterable)
    buffers = [deque[T]() for _ in range(n)]

    def gen(own: deque[T], T: type) -> Generator[T]:
        while True:
            if not own:
                # Buffer drained: stop if the source was already finished,
                # otherwise resume it once and stop if that finished it.
                if source.__done__() or source.done():
                    return
                value = source.next()
                for buf in buffers:
                    buf.append(value)
            yield own.popleft()

    return [gen(buf) for buf in buffers]
@inline
def zip_longest(*iterables, fillvalue):
    """
    Make an iterator that aggregates elements from each of the iterables.
    If the iterables are of uneven length, missing values are filled-in
    with fillvalue. Iteration continues until the longest iterable is
    exhausted.

    With exactly two iterables, each item is a 2-tuple. Otherwise, each item
    is a list holding one value per iterable, so those values (and
    fillvalue) are collected into a single list. The internal iterators are
    destroyed once all of them are exhausted.
    """
    if staticlen(iterables) == 2:
        a = iter(iterables[0])
        b = iter(iterables[1])
        a_done = False
        b_done = False
        while not a.done():
            a_val = a.next()
            b_val = fillvalue
            if not b_done:
                b_done = b.done()
            if not b_done:
                b_val = b.next()
            yield a_val, b_val
        # `a` is exhausted; drain whatever is left of `b`.
        if not b_done:
            while not b.done():
                yield fillvalue, b.next()
        a.destroy()
        b.destroy()
    else:
        iterators = tuple(iter(it) for it in iterables)
        num_active = len(iterators)
        if not num_active:
            return
        while num_active:
            values = []
            for it in iterators:
                if it.__done__():  # already done
                    values.append(fillvalue)
                elif it.done():  # resume and check
                    num_active -= 1
                    values.append(fillvalue)
                else:
                    values.append(it.next())
            # The round in which the last active iterator runs dry contains
            # only fill values, so it is not yielded.
            if num_active:
                yield values
        # Release the generators, as the two-iterable branch does.
        for it in iterators:
            it.destroy()
@inline
@overload
def zip_longest(*args):
    """
    Make an iterator that aggregates elements from each of the iterables.
    If the iterables are of uneven length, missing values are filled-in
    with None. Iteration continues until the longest iterable is
    exhausted.

    Each item is a tuple of Optional values, one per iterable. The internal
    iterators are destroyed once every one of them is exhausted.
    """
    def get_next(it):
        # None once `it` is exhausted (resuming it first if needed), else its next value.
        if it.__done__() or it.done():
            return None
        return it.next()

    iters = tuple(iter(arg) for arg in args)
    while True:
        result = tuple(get_next(it) for it in iters)
        all_none = True
        for a in result:
            if a is not None:
                all_none = False
        if all_none:
            # `break` (not `return`) so the cleanup below actually runs.
            break
        yield result
    for it in iters:
        it.destroy()
# Combinatoric iterators
def _as_list(x):
    # Materialize `x` as a list. An argument that is already a list is
    # returned as-is (no copy), so callers share it with their own caller.
    if isinstance(x, list):
        return x
    else:
        return list(x)
def product(*iterables, repeat: int):
    """
    Cartesian product of the input iterables, each repeated ``repeat`` times
    (runtime ``repeat`` variant).

    Yields lists in odometer order (rightmost position varies fastest). Each
    yielded list is a fresh copy, so earlier results are never mutated. When
    ``repeat`` is 0, a single empty list is produced. If any input is empty,
    nothing is produced.

    Raises ValueError if ``repeat`` is negative.
    """
    if repeat < 0:
        raise ValueError("repeat must be non-negative")
    nargs = len(iterables) if repeat != 0 else 0
    npools = nargs * repeat
    # indices[k] is the current position inside pools[k] (one odometer digit).
    indices = Ptr[int](npools)
    pools = list(capacity=npools)
    for k in range(nargs):
        materialized = _as_list(iterables[k])
        if len(materialized) == 0:
            # Any empty pool makes the whole product empty.
            return
        pools.append(materialized)
        indices[k] = 0
    for k in range(nargs, npools):
        # The remaining pools repeat the first `nargs` ones.
        pools.append(pools[k - nargs])
        indices[k] = 0
    result = list(capacity=npools)
    for k in range(npools):
        result.append(pools[k][0])
    while True:
        yield result
        result = result.copy()
        # Advance the odometer starting from the rightmost digit.
        pos = npools - 1
        while pos >= 0:
            pool = pools[pos]
            indices[pos] += 1
            if indices[pos] == len(pool):
                # This digit wraps around; carry into the next one to the left.
                indices[pos] = 0
                result[pos] = pool[0]
                pos -= 1
            else:
                result[pos] = pool[indices[pos]]
                break
        if pos < 0:
            break
@overload
def product(*iterables, repeat: Static[int] = 1):
    """
    Cartesian product of the input iterables, each repeated ``repeat`` times
    (compile-time ``repeat`` variant). Yields tuples.

    Small cases (one to three iterables, repeat == 1) use nested loops. The
    inner iterables of those loops are traversed once per outer element, so
    they are materialized first (lists pass through unchanged). Otherwise,
    all inputs are materialized into pools and an odometer over them is
    advanced.

    ``repeat`` < 0 is a compile-time error.
    """
    if repeat < 0:
        compile_error("repeat must be non-negative")
    # handle some common cases
    if repeat == 0:
        yield ()
    elif repeat == 1 and staticlen(iterables) == 1:
        it0 = iterables[0]
        for a in it0:
            yield (a,)
    elif repeat == 1 and staticlen(iterables) == 2:
        it0 = iterables[0]
        # Re-iterated for each `a`: a one-shot iterable (e.g. a generator)
        # would be exhausted after the first pass, so materialize it.
        it1 = _as_list(iterables[1])
        for a in it0:
            for b in it1:
                yield (a, b)
    elif repeat == 1 and staticlen(iterables) == 3:
        it0 = iterables[0]
        # Both inner iterables are re-iterated; materialize them.
        it1 = _as_list(iterables[1])
        it2 = _as_list(iterables[2])
        for a in it0:
            for b in it1:
                for c in it2:
                    yield (a, b, c)
    else:
        nargs: Static[int] = staticlen(iterables)
        npools: Static[int] = nargs * repeat
        # Odometer digits live inside a tuple and are mutated through a
        # pointer to its storage.
        indices_tuple = (0,) * npools
        indices = Ptr[int](__ptr__(indices_tuple).as_byte())
        pools = tuple(_as_list(it) for it in iterables) * repeat
        for i in staticrange(nargs):
            if len(pools[i]) == 0:
                return
        result = tuple(pool[0] for pool in pools)
        while True:
            yield result
            # Advance the odometer from the rightmost digit, carrying left.
            i = npools - 1
            while i >= 0:
                pool = pools[i]
                indices[i] += 1
                if indices[i] == len(pool):
                    indices[i] = 0
                else:
                    break
                i -= 1
            if i < 0:
                break
            result = tuple(pools[i][indices[i]] for i in staticrange(npools))
def combinations(pool, r: int):
    """
    Return successive r-length combinations of elements from ``pool``
    (runtime ``r`` variant), yielded as lists.

    ``pool`` may be any iterable; it is materialized into a list first.
    Position tuples are produced in increasing lexicographic order, and each
    yielded list is a fresh copy. Nothing is produced when ``r`` exceeds the
    pool size.

    Raises ValueError if ``r`` is negative.
    """
    if r < 0:
        raise ValueError("r must be non-negative")
    pool_list = _as_list(pool)
    # Measure the materialized list: `pool` itself may be a one-shot
    # iterable (e.g. a generator) that has no len().
    n = len(pool_list)
    if r > n:
        return
    pool = pool_list.arr.ptr
    # indices holds the pool positions of the current combination.
    indices = Ptr[int](r)
    result = list(capacity=r)
    for i in range(r):
        indices[i] = i
        result.append(pool[i])
    while True:
        yield result
        # Find the rightmost position that has not reached its maximum.
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            break
        indices[i] += 1
        for j in range(i + 1, r):
            indices[j] = indices[j-1] + 1
        result = result.copy()
        while i < r:
            result[i] = pool[indices[i]]
            i += 1
@overload
def combinations(pool, r: Static[int]):
    """
    Return successive r-length combinations of elements from ``pool``
    (compile-time ``r`` variant), yielded as r-tuples.

    ``pool`` may be any iterable; it is materialized into a list first.
    Nothing is produced when ``r`` exceeds the pool size. ``r`` < 0 is a
    compile-time error.
    """
    def empty(T: type) -> T:
        # Placeholder of type T used only to build a tuple of the right shape;
        # every slot is overwritten before the first yield.
        pass
    if r < 0:
        compile_error("r must be non-negative")
    pool_list = _as_list(pool)
    # Measure the materialized list: `pool` may be an iterable without len().
    n = len(pool_list)
    if r > n:
        return
    pool = pool_list.arr.ptr
    indices_tuple = (0,) * r
    indices = Ptr[int](__ptr__(indices_tuple).as_byte())
    result_tuple = (empty(pool.T),) * r
    # Writes through `result` update result_tuple in place between yields.
    result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
    for i in range(r):
        indices[i] = i
        result[i] = pool[i]
    while True:
        yield result_tuple
        # Find the rightmost position that has not reached its maximum.
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            break
        indices[i] += 1
        for j in range(i + 1, r):
            indices[j] = indices[j-1] + 1
        while i < r:
            result[i] = pool[indices[i]]
            i += 1
def combinations_with_replacement(pool, r: int):
    """
    Return r-length combinations of elements from ``pool`` in which
    individual elements may repeat (runtime ``r`` variant), yielded as lists.

    ``pool`` may be any iterable; it is materialized into a list first. Each
    yielded list is a fresh copy. For an empty pool, a single empty list is
    produced when ``r`` is 0, and nothing otherwise.

    Raises ValueError if ``r`` is negative.
    """
    if r < 0:
        raise ValueError("r must be non-negative")
    pool_list = _as_list(pool)
    # Measure the materialized list: `pool` may be an iterable without len().
    n = len(pool_list)
    if n == 0:
        if r == 0:
            yield List[pool_list.T](capacity=0)
        return
    pool = pool_list.arr.ptr
    # indices holds a non-decreasing sequence of pool positions.
    indices = Ptr[int](r)
    result = list(capacity=r)
    for i in range(r):
        indices[i] = 0
        result.append(pool[0])
    while True:
        yield result
        # Find the rightmost position that can still be incremented.
        i = r - 1
        while i >= 0 and indices[i] == n - 1:
            i -= 1
        if i < 0:
            break
        result = result.copy()
        # Increment it and reset every position to its right to the same value.
        index = indices[i] + 1
        elem = pool[index]
        while i < r:
            indices[i] = index
            result[i] = elem
            i += 1
@overload
def combinations_with_replacement(pool, r: Static[int]):
    """
    Return r-length combinations of elements from ``pool`` in which
    individual elements may repeat (compile-time ``r`` variant), yielded as
    r-tuples.

    When ``r`` is 0, a single empty tuple is produced without reading
    ``pool``. Otherwise ``pool`` is materialized into a list, and an empty
    pool produces nothing. ``r`` < 0 is a compile-time error.
    """
    def empty(T: type) -> T:
        # Placeholder of type T used only to build a tuple of the right shape;
        # every slot is overwritten before the first yield.
        pass
    if r < 0:
        compile_error("r must be non-negative")
    if r == 0:
        yield ()
        return
    pool_list = _as_list(pool)
    # Measure the materialized list: `pool` may be an iterable without len().
    n = len(pool_list)
    if n == 0:
        return
    pool = pool_list.arr.ptr
    indices_tuple = (0,) * r
    indices = Ptr[int](__ptr__(indices_tuple).as_byte())
    result_tuple = (empty(pool.T),) * r
    # Writes through `result` update result_tuple in place between yields.
    result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
    for i in range(r):
        result[i] = pool[0]
    while True:
        yield result_tuple
        # Find the rightmost position that can still be incremented.
        i = r - 1
        while i >= 0 and indices[i] == n - 1:
            i -= 1
        if i < 0:
            break
        # Increment it and reset every position to its right to the same value.
        index = indices[i] + 1
        elem = pool[index]
        while i < r:
            indices[i] = index
            result[i] = elem
            i += 1
def _permutations_non_static(pool, r = None):
    """
    Generate r-length permutations of ``pool`` as lists (runtime ``r``).

    ``pool`` may be any iterable; it is materialized into a list first. When
    ``r`` is None, it defaults to the pool length. Each yielded list is a
    fresh copy. Nothing is produced when ``r`` exceeds the pool size.

    Raises ValueError if ``r`` is negative. An ``r`` that is neither None nor
    int is a compile-time error.
    """
    pool_list = _as_list(pool)
    # Measure the materialized list: `pool` may be an iterable without len().
    n = len(pool_list)
    if r is None:
        return _permutations_non_static(pool_list, n)
    elif not isinstance(r, int):
        compile_error("Expected int as r")
    if r < 0:
        raise ValueError("r must be non-negative")
    if r > n:
        return
    # indices: current ordering of pool positions; the first r are emitted.
    # cycles[i]: steps remaining at position i before it rotates back.
    indices = Ptr[int](n)
    cycles = Ptr[int](r)
    for i in range(n):
        indices[i] = i
    for i in range(r):
        cycles[i] = n - i
    pool = pool_list.arr.ptr
    result = list(capacity=r)
    for i in range(r):
        result.append(pool[i])
    while True:
        yield result
        if n == 0:
            break
        result = result.copy()
        i = r - 1
        while i >= 0:
            cycles[i] -= 1
            if cycles[i] == 0:
                # Rotate indices[i:] left by one and reset this counter.
                index = indices[i]
                for j in range(i, n - 1):
                    indices[j] = indices[j+1]
                indices[n-1] = index
                cycles[i] = n - i
            else:
                # Swap in the next candidate, refresh result[i:], and emit.
                j = cycles[i]
                index = indices[i]
                indices[i] = indices[n - j]
                indices[n - j] = index
                for k in range(i, r):
                    index = indices[k]
                    result[k] = pool[index]
                break
            i -= 1
        if i < 0:
            break
def _permutations_static(pool, r: Static[int]):
    """
    Generate r-length permutations of ``pool`` as r-tuples (compile-time ``r``).

    ``pool`` may be any iterable; it is materialized into a list first.
    Nothing is produced when ``r`` exceeds the pool size. ``r`` < 0 is a
    compile-time error.
    """
    def empty(T: type) -> T:
        # Placeholder of type T used only to build a tuple of the right shape;
        # every slot is overwritten before the first yield.
        pass
    if r < 0:
        compile_error("r must be non-negative")
    pool_list = _as_list(pool)
    # Measure the materialized list: `pool` may be an iterable without len().
    n = len(pool_list)
    if r > n:
        return
    # indices: current ordering of pool positions; the first r are emitted.
    # cycles[i]: steps remaining at position i before it rotates back.
    indices = Ptr[int](n)
    cycles_tuple = (0,) * r
    cycles = Ptr[int](__ptr__(cycles_tuple).as_byte())
    for i in range(n):
        indices[i] = i
    for i in range(r):
        cycles[i] = n - i
    pool = pool_list.arr.ptr
    result_tuple = (empty(pool.T),) * r
    # Writes through `result` update result_tuple in place between yields.
    result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
    for i in range(r):
        result[i] = pool[i]
    while True:
        yield result_tuple
        if n == 0:
            break
        i = r - 1
        while i >= 0:
            cycles[i] -= 1
            if cycles[i] == 0:
                # Rotate indices[i:] left by one and reset this counter.
                index = indices[i]
                for j in range(i, n - 1):
                    indices[j] = indices[j+1]
                indices[n-1] = index
                cycles[i] = n - i
            else:
                # Swap in the next candidate, refresh result[i:], and emit.
                j = cycles[i]
                index = indices[i]
                indices[i] = indices[n - j]
                indices[n - j] = index
                for k in range(i, r):
                    index = indices[k]
                    result[k] = pool[index]
                break
            i -= 1
        if i < 0:
            break
def permutations(pool, r = None):
    """
    Return successive r-length permutations of elements in ``pool``.

    A tuple ``pool`` with no ``r`` is dispatched to ``_permutations_static``
    with ``r`` set to the tuple's length (known at compile time), which yields
    tuples. Every other call goes to ``_permutations_non_static``, which
    yields lists.
    """
    if isinstance(pool, Tuple) and r is None:
        return _permutations_static(pool, staticlen(pool))
    else:
        return _permutations_non_static(pool, r)
@overload
def permutations(pool, r: Static[int]):
    """
    Return successive r-length permutations of ``pool`` for an ``r`` known at
    compile time; delegates to ``_permutations_static``, which yields tuples.
    """
    return _permutations_static(pool, r)