codon/stdlib/itertools.codon

473 lines
13 KiB
Python

# Copyright (C) 2022 Exaloop Inc. <https://exaloop.io>
from internal.types.optional import unwrap
# Infinite iterators
@inline
def count(start: T = 0, step: T = 1, T: type) -> Generator[T]:
    """
    Yield an endless sequence of values beginning at ``start``.

    Every value after the first is the previous value advanced by ``step``.
    The generator never terminates on its own.
    """
    current = start
    while True:
        yield current
        current += step
@inline
def cycle(iterable: Generator[T], T: type) -> Generator[T]:
    """
    Yield every element of ``iterable`` once, then replay them forever.

    Elements are recorded on the first pass. If the iterable produced
    nothing, the generator simply ends.
    """
    cache = []
    for item in iterable:
        yield item
        cache.append(item)
    # Nothing was recorded: there is nothing to replay.
    if not cache:
        return
    while True:
        for item in cache:
            yield item
@inline
def repeat(object: T, times: Optional[int] = None, T: type) -> Generator[T]:
    """
    Yield ``object`` repeatedly.

    With ``times`` given, the object is yielded ``range(times)`` times;
    with ``times`` left as ``None`` it is yielded without end.
    """
    if times is not None:
        for _ in range(times):
            yield object
        return
    while True:
        yield object
# Iterators terminating on the shortest input sequence
@inline
def accumulate(iterable: Generator[T], func=lambda a, b: a + b, initial=0, T: type):
    """
    Yield running results of folding ``func`` over ``iterable``.

    ``initial`` is yielded first; afterwards each element is combined with
    the running value via ``func(running, element)`` and the new running
    value is yielded. ``func`` defaults to addition.
    """
    running = initial
    yield running
    for item in iterable:
        running = func(running, item)
        yield running
@inline
@overload
def accumulate(iterable: Generator[T], func=lambda a, b: a + b, T: type):
    """
    Yield running results of folding ``func`` over ``iterable`` with no
    initial value.

    The first element is yielded unchanged; each later element is combined
    with the running value via ``func(running, element)``. An empty iterable
    yields nothing. ``func`` defaults to addition.
    """
    running = None
    for item in iterable:
        if running is None:
            # First element seeds the running value.
            running = item
        else:
            running = func(unwrap(running), item)
        yield unwrap(running)
@tuple
class chain:
    """
    Yield the elements of several iterables back to back: everything from the
    first iterable, then everything from the second, and so on until all of
    them are exhausted.
    """

    @inline
    def __new__(*iterables):
        # Each positional argument is one iterable to drain in order.
        for source in iterables:
            yield from source

    @inline
    def from_iterable(iterables):
        # Same flattening, but the iterables arrive as a single iterable.
        for source in iterables:
            yield from source
@inline
def compress(
    data: Generator[T], selectors: Generator[B], T: type, B: type
) -> Generator[T]:
    """
    Yield the elements of ``data`` whose paired element in ``selectors`` is
    truthy.

    ``data`` and ``selectors`` are walked in lockstep with ``zip``.
    """
    for item, keep in zip(data, selectors):
        if not keep:
            continue
        yield item
@inline
def dropwhile(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Skip leading elements for which ``predicate`` is true, then yield the
    first failing element and every element after it.

    ``predicate`` is no longer called once an element has failed it.
    """
    dropping = True
    for item in iterable:
        if dropping:
            if predicate(item):
                continue
            dropping = False
        yield item
@inline
def filterfalse(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Yield the elements of ``iterable`` for which ``predicate`` is false.
    """
    for item in iterable:
        if predicate(item):
            continue
        yield item
# TODO: fix this once Optional[Callable] lands
@inline
def groupby(iterable, key=Optional[int]()):
    """
    Make an iterator that returns consecutive keys and groups from the iterable.

    Yields ``(key, group)`` pairs where ``group`` is a list holding a run of
    consecutive elements that share the same key. When ``key`` is left at its
    default, each element is used as its own key; otherwise ``key(element)``
    is used. An empty iterable yields nothing.
    """
    # Key of the run currently being collected; stays None until the first element.
    currkey = None
    group = []
    for currvalue in iterable:
        # The default `key` is an empty Optional, used here as the
        # "no key function supplied" marker.
        k = currvalue if isinstance(key, Optional) else key(currvalue)
        if currkey is None:
            currkey = k
        if k != unwrap(currkey):
            # Key changed: emit the finished run and start a fresh list.
            yield unwrap(currkey), group
            currkey = k
            group = []
        group.append(currvalue)
    # Emit the final run, if at least one element was seen.
    if currkey is not None:
        yield unwrap(currkey), group
def islice(iterable: Generator[T], stop: Optional[int], T: type) -> Generator[T]:
    """
    Make an iterator that returns the first ``stop`` elements of the iterable.

    Args:
        iterable: source of elements.
        stop: maximum number of elements to yield, or ``None`` to yield
            every element.

    Raises:
        ValueError: if ``stop`` is negative.

    The source is never advanced past the last element that is yielded, so
    ``stop == 0`` pulls nothing from it.
    """
    if stop is None:
        for x in iterable:
            yield x
        return
    limit = stop.__val__()
    if limit < 0:
        raise ValueError(
            "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize."
        )
    if limit == 0:
        return
    taken = 0
    for x in iterable:
        yield x
        taken += 1
        # Stop right after the last requested element instead of fetching one
        # more element just to discover that the limit was reached.
        if taken >= limit:
            break
@overload
def islice(
    iterable: Generator[T],
    start: Optional[int],
    stop: Optional[int],
    step: Optional[int] = None,
    T: type,
) -> Generator[T]:
    """
    Make an iterator that returns selected elements from the iterable.

    Yields the elements whose positions are in ``range(start, stop, step)``.

    Args:
        iterable: source of elements.
        start: first position to yield; ``None`` means 0.
        stop: position to stop before; ``None`` means ``sys.maxsize``.
        step: distance between yielded positions; ``None`` means 1.

    Raises:
        ValueError: if ``start`` or ``stop`` is negative, or if ``step`` is
            not a positive integer.
    """
    from sys import maxsize

    start: int = 0 if start is None else start
    stop: int = maxsize if stop is None else stop
    step: int = 1 if step is None else step
    if start < 0 or stop < 0:
        raise ValueError(
            "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize."
        )
    elif step <= 0:
        # A zero step is rejected here with the islice message rather than
        # being passed on to range().
        raise ValueError("Step for islice() must be a positive integer or None.")
    it = range(start, stop, step)  # positions to yield
    N = len(it)
    idx = 0  # how many positions have been yielded so far
    b = -1  # position of the last yielded element, once all N are done
    if N == 0:
        # Nothing to yield; still step the source over up to `start` elements.
        for i, element in zip(range(start), iterable):
            pass
        return
    nexti = it[0]
    for i, element in enumerate(iterable):
        if i == nexti:
            yield element
            idx += 1
            if idx >= N:
                b = i
                break
            nexti = it[idx]
    if b >= 0:
        # All selected positions were yielded; step the source over the
        # remaining positions below `stop`.
        for i, element in zip(range(b + 1, stop), iterable):
            pass
@inline
def starmap(function, iterable):
    """
    Yield ``function(*args)`` for every argument tuple ``args`` produced by
    ``iterable``.
    """
    for packed in iterable:
        yield function(*packed)
@inline
def takewhile(
    predicate: Callable[[T], bool], iterable: Generator[T], T: type
) -> Generator[T]:
    """
    Yield elements from ``iterable`` until the first one for which
    ``predicate`` is false; that element and everything after it are not
    yielded.
    """
    for item in iterable:
        if not predicate(item):
            break
        yield item
def tee(iterable: Generator[T], n: int = 2, T: type) -> List[Generator[T]]:
    """
    Return n independent iterators from a single iterable.

    The result is a list of ``n`` generators. Each one owns a deque of
    pending values; when a generator finds its deque empty it fetches the
    next value from the shared source and appends it to every deque, then
    yields from its own.
    """
    from collections import deque

    it = iter(iterable)
    # One buffer of not-yet-yielded values per returned generator.
    deques = [deque[T]() for i in range(n)]

    def gen(mydeque: deque[T], T: type) -> Generator[T]:
        while True:
            if not mydeque:  # when the local deque is empty
                # NOTE(review): this drives the source through the low-level
                # generator protocol; presumably __resume__() advances it,
                # __done__() reports exhaustion and next() reads the current
                # value -- confirm against the Generator implementation.
                if it.__done__():
                    return
                it.__resume__()
                if it.__done__():
                    return
                newval = it.next()
                for d in deques:  # load it to all the deques
                    d.append(newval)
            yield mydeque.popleft()

    return [gen(d) for d in deques]
@inline
def zip_longest(*iterables, fillvalue):
    """
    Make an iterator that aggregates elements from each of the iterables.
    If the iterables are of uneven length, missing values are filled-in
    with fillvalue. Iteration continues until the longest iterable is
    exhausted.

    With exactly two iterables each result is a 2-tuple; otherwise each
    result is a list with one entry per iterable. With no iterables nothing
    is yielded.
    """
    if staticlen(iterables) == 2:
        a = iter(iterables[0])
        b = iter(iterables[1])
        # NOTE(review): a_done is assigned but never read in this function.
        a_done = False
        b_done = False
        # Phase 1: walk `a`, pairing each value with the next value of `b`
        # or with fillvalue once `b` has run out.
        while not a.done():
            a_val = a.next()
            b_val = fillvalue
            if not b_done:
                b_done = b.done()
            if not b_done:
                b_val = b.next()
            yield a_val, b_val
        # Phase 2: `a` is finished; emit whatever is left of `b`.
        if not b_done:
            while not b.done():
                yield fillvalue, b.next()
        a.destroy()
        b.destroy()
    else:
        iterators = tuple(iter(it) for it in iterables)
        # Number of iterators that have not yet been seen to finish.
        num_active = len(iterators)
        if not num_active:
            return
        while True:
            values = []
            for it in iterators:
                if it.__done__():  # already done
                    values.append(fillvalue)
                elif it.done():  # resume and check
                    num_active -= 1
                    # The last active iterator just finished: stop without
                    # yielding the partially built row.
                    if not num_active:
                        return
                    values.append(fillvalue)
                else:
                    values.append(it.next())
            yield values
@inline
@overload
def zip_longest(*args):
    """
    Make an iterator that aggregates elements from each of the iterables.
    If the iterables are of uneven length, missing values are filled-in
    with ``None``. Iteration continues until the longest iterable is
    exhausted.

    Each result is a tuple with one entry per argument. Iteration ends on
    the first round in which every entry is ``None``; the underlying
    iterators are destroyed at that point.
    """

    def get_next(it):
        # None stands in for "this iterator has nothing more to give".
        if it.__done__() or it.done():
            return None
        return it.next()

    iters = tuple(iter(arg) for arg in args)
    while True:
        result = tuple(get_next(it) for it in iters)
        all_none = True
        for a in result:
            if a is not None:
                all_none = False
        if all_none:
            # Release the iterators here: the loop only ever exits through
            # this return, so cleanup placed after the loop would never run.
            for it in iters:
                it.destroy()
            return
        yield result
# Combinatoric iterators
def combinations(pool: Generator[T], r: int, T: type) -> Generator[List[T]]:
    """
    Yield every r-length combination of the elements of ``pool`` as a list,
    in the order of their positions in the pool.
    combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)

    Raises ValueError when ``r`` is negative. Yields nothing when ``r``
    exceeds the number of elements.
    """

    def combinations_helper(pool: List[T], r: int, T: type) -> Generator[List[T]]:
        size = len(pool)
        if r > size:
            return
        picks = list(range(r))
        yield [pool[p] for p in picks]
        while True:
            # Scan from the right for a slot that is not yet at its maximum.
            slot = r - 1
            while slot >= 0 and picks[slot] == slot + size - r:
                slot -= 1
            if slot < 0:
                return
            picks[slot] += 1
            # Every slot to the right restarts just above its left neighbour.
            for nxt in range(slot + 1, r):
                picks[nxt] = picks[nxt - 1] + 1
            yield [pool[p] for p in picks]

    if r < 0:
        raise ValueError("r must be non-negative")
    if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
        return combinations_helper(pool, r)
    else:
        return combinations_helper([a for a in pool], r)
def combinations_with_replacement(
    pool: Generator[T], r: int, T: type
) -> Generator[List[T]]:
    """
    Yield every r-length combination of the elements of ``pool`` as a list,
    where the same element may be picked more than once.

    Raises ValueError when ``r`` is negative. Yields nothing when the pool
    is empty and ``r`` is non-zero.
    """

    def combinations_with_replacement_helper(
        pool: List[T], r: int, T: type
    ) -> Generator[List[T]]:
        size = len(pool)
        if r and not size:
            return
        picks = [0] * r
        yield [pool[p] for p in picks]
        while True:
            # Scan from the right for a slot not yet at the last pool index.
            slot = r - 1
            while slot >= 0 and picks[slot] == size - 1:
                slot -= 1
            if slot < 0:
                return
            bumped = picks[slot] + 1
            # That slot and everything to its right take the bumped index.
            for pos in range(slot, r):
                picks[pos] = bumped
            yield [pool[p] for p in picks]

    if r < 0:
        raise ValueError("r must be non-negative")
    if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
        return combinations_with_replacement_helper(pool, r)
    else:
        return combinations_with_replacement_helper([a for a in pool], r)
def permutations(
    pool: Generator[T], r: Optional[int] = None, T: type
) -> Generator[List[T]]:
    """
    Yield every r-length permutation of the elements of ``pool`` as a list.

    ``r`` defaults to the number of elements in the pool. Raises ValueError
    when ``r`` is negative; yields nothing when ``r`` exceeds the pool size.
    """

    def permutations_helper(
        pool: List[T], r: Optional[int], T: type
    ) -> Generator[List[T]]:
        n = len(pool)
        r: int = r if r is not None else n
        if r > n:
            return
        order = list(range(n))
        cycles = list(range(n, n - r, -1))
        yield [pool[k] for k in order[:r]]
        while n:
            advanced = False
            for i in reversed(range(r)):
                cycles[i] -= 1
                if cycles[i] != 0:
                    # Swap in the next candidate for position i and emit.
                    advanced = True
                    j = cycles[i]
                    order[i], order[-j] = order[-j], order[i]
                    yield [pool[k] for k in order[:r]]
                    break
                # Position i has used up its candidates: rotate its entry to
                # the end and reset its counter before moving left.
                order = order[:i] + order[i + 1 :] + order[i : i + 1]
                cycles[i] = n - i
            if not advanced:
                return

    if r is not None and r.__val__() < 0:
        raise ValueError("r must be non-negative")
    if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
        return permutations_helper(pool, r)
    else:
        return permutations_helper([a for a in pool], r)
@inline
def product(*args):
    """
    Yield the Cartesian product of the input iterables as tuples.

    With no arguments a single empty tuple is yielded. Otherwise each element
    of the first iterable is prepended to every product of the remaining
    iterables.
    """
    if staticlen(args) == 0:
        yield ()
    else:
        tail = args[1:]
        for head in args[0]:
            for combo in product(*tail):
                yield (head, *combo)
@inline
@overload
def product(*args, repeat: int):
    """
    Yield the Cartesian product of the input iterables, with the whole
    argument sequence repeated ``repeat`` times. Each result is a list.

    Raises ValueError when ``repeat`` is negative.
    """
    if repeat < 0:
        raise ValueError("repeat argument cannot be negative")
    pools = [list(pool) for _ in range(repeat) for pool in args]
    # Start from one empty prefix and extend it by one pool at a time.
    partials = [List[type(pools[0][0])]()]
    for pool in pools:
        extended = []
        for prefix in partials:
            for item in pool:
                extended.append(prefix + [item])
        partials = extended
    for combo in partials:
        yield combo