Misc fixes (#410)

* Fix corner case when typechecking scoped names with static compilation

* Undo log

* Fix nested loop domination; Minor aestethic fixes

* clang-format

* Add slice indices() method

* Fix overloads with static arguments

* Update itertools combinatorics functions

* Fix import domination issue (missing stack insert)

* Fix itertools

* Remove log

* Bump version

---------

Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>
pull/420/head
A. R. Shajii 2023-07-02 18:50:43 -04:00 committed by GitHub
parent e95f778df1
commit 6bb26e0187
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 869 additions and 173 deletions

View File

@ -1,10 +1,10 @@
cmake_minimum_required(VERSION 3.14)
project(
Codon
VERSION "0.16.1"
VERSION "0.16.2"
HOMEPAGE_URL "https://github.com/exaloop/codon"
DESCRIPTION "high-performance, extensible Python compiler")
set(CODON_JIT_PYTHON_VERSION "0.1.5")
set(CODON_JIT_PYTHON_VERSION "0.1.6")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
"${PROJECT_SOURCE_DIR}/codon/config/config.h")
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"

View File

@ -154,7 +154,8 @@ bool LinkType::isInstantiated() const { return kind == Link && type->isInstantia
std::string LinkType::debugString(char mode) const {
if (kind == Unbound || kind == Generic) {
if (mode == 2) {
return fmt::format("{}{}{}", kind == Unbound ? '?' : '#', id,
return fmt::format("{}{}{}{}", genericName.empty() ? "" : genericName + ":",
kind == Unbound ? '?' : '#', id,
trait ? ":" + trait->debugString(mode) : "");
}
if (trait)

View File

@ -38,11 +38,16 @@ void SimplifyVisitor::visit(IdExpr *expr) {
// while True:
// if x > 10: break
// x = x + 1 # x must be dominated after the loop to ensure that it gets updated
if (auto loop = ctx->getBase()->getLoop()) {
bool inside = val->scope.size() >= loop->scope.size() &&
val->scope[loop->scope.size() - 1] == loop->scope.back();
if (!inside)
loop->seenVars.insert(expr->value);
if (ctx->getBase()->getLoop()) {
for (size_t li = ctx->getBase()->loops.size(); li-- > 0;) {
auto &loop = ctx->getBase()->loops[li];
bool inside = val->scope.size() >= loop.scope.size() &&
val->scope[loop.scope.size() - 1] == loop.scope.back();
if (!inside)
loop.seenVars.insert(expr->value);
else
break;
}
}
// Replace the variable with its canonical name

View File

@ -153,6 +153,7 @@ SimplifyContext::Item SimplifyContext::findDominatingBinding(const std::string &
(*lastGood)->importPath);
item->accessChecked = {(*lastGood)->scope};
lastGood = it->second.insert(++lastGood, item);
stack.front().push_back(name);
// Make sure to prepend a binding declaration: `var` and `var__used__ = False`
// to the dominating scope.
scope.stmts[scope.blocks[prefix - 1]].push_back(std::make_unique<AssignStmt>(

View File

@ -125,8 +125,9 @@ void SimplifyVisitor::visit(ForStmt *stmt) {
ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts));
// Dominate loop variables
for (auto &var : ctx->getBase()->getLoop()->seenVars)
for (auto &var : ctx->getBase()->getLoop()->seenVars) {
ctx->findDominatingBinding(var);
}
ctx->getBase()->loops.pop_back();
}

View File

@ -284,6 +284,7 @@ void TranslateVisitor::visit(PipeExpr *expr) {
simplePipeline &= !isGen(fn);
std::vector<ir::Value *> args;
args.reserve(call->args.size());
for (auto &a : call->args)
args.emplace_back(a.value->getEllipsis() ? nullptr : transform(a.value));
stages.emplace_back(fn, args, isGen(fn), false);
@ -642,7 +643,7 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt
ltrim(lp);
rtrim(lp);
// Extract declares and constants.
if (isDeclare && !startswith(lp, "declare ")) {
if (isDeclare && !startswith(lp, "declare ") && !startswith(lp, "@")) {
bool isConst = lp.find("private constant") != std::string::npos;
if (!isConst) {
isDeclare = false;

View File

@ -286,7 +286,8 @@ int TypeContext::reorderNamedArgs(types::FuncType *func,
Emsg(Error::CALL_ARGS_MISSING, cache->rev(func->ast->name),
cache->reverseIdentifierLookup[func->ast->args[i].name]));
}
return score + onDone(starArgIndex, kwstarArgIndex, slots, partial);
auto s = onDone(starArgIndex, kwstarArgIndex, slots, partial);
return s != -1 ? score + s : -1;
}
void TypeContext::dump(int pad) {

View File

@ -86,6 +86,9 @@ public:
return item;
}
std::shared_ptr<TypecheckItem> find(const std::string &name) const override;
std::shared_ptr<TypecheckItem> find(const char *name) const {
return find(std::string(name));
}
/// Find an internal type. Assumes that it exists.
std::shared_ptr<TypecheckItem> forceFind(const std::string &name) const;
types::TypePtr getType(const std::string &name) const;

View File

@ -103,7 +103,7 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
unify(stmt->var->type,
iterType ? unify(val->type, iterType->generics[0].type) : val->type);
ctx->staticLoops.push_back("");
ctx->staticLoops.emplace_back();
ctx->blockLevel++;
transform(stmt->suite);
ctx->blockLevel--;

View File

@ -310,7 +310,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
} else {
if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
transformType(expr->typeParams[i]);
if (!expr->typeParams[i]->isType())
if (expr->typeParams[i]->type->getClass() && !expr->typeParams[i]->isType())
E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(),
expr->typeParams[i]->getType());

View File

@ -258,8 +258,9 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
if (slots[si].empty()) {
// is this "real" type?
if (in(niGenerics, fn->ast->args[si].name) &&
!fn->ast->args[si].defaultValue)
!fn->ast->args[si].defaultValue) {
return -1;
}
reordered.push_back({nullptr, 0});
} else {
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});

View File

@ -8,11 +8,9 @@ class List:
self.arr = Array[T](10)
self.len = 0
def __init__(self, it: Generator[T]):
self.arr = Array[T](10)
def __init__(self, capacity: int):
self.arr = Array[T](capacity)
self.len = 0
for i in it:
self.append(i)
def __init__(self, other: List[T]):
self.arr = Array[T](other.len)
@ -20,9 +18,11 @@ class List:
for i in other:
self.append(i)
def __init__(self, capacity: int):
self.arr = Array[T](capacity)
def __init__(self, it: Generator[T]):
self.arr = Array[T](10)
self.len = 0
for i in it:
self.append(i)
def __init__(self, arr: Array[T], len: int):
self.arr = arr

View File

@ -6,6 +6,9 @@ class Slice:
stop: Optional[int]
step: Optional[int]
def __new__(stop: Optional[int]):
return Slice(None, stop, None)
def adjust_indices(self, length: int) -> Tuple[int, int, int, int]:
step: int = self.step if self.step is not None else 1
start: int = 0
@ -47,6 +50,11 @@ class Slice:
return start, stop, step, 0
def indices(self, length: int):
if length < 0:
raise ValueError("length should not be negative")
return self.adjust_indices(length)[:-1]
def __repr__(self):
return f"slice({self.start}, {self.stop}, {self.step})"

View File

@ -333,140 +333,426 @@ def zip_longest(*args):
# Combinatoric iterators
def combinations(pool: Generator[T], r: int, T: type) -> Generator[List[T]]:
"""
Return successive r-length combinations of elements in the iterable.
combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)
"""
def combinations_helper(pool: List[T], r: int, T: type) -> Generator[List[T]]:
n = len(pool)
if r > n:
return
indices = list(range(r))
yield [pool[i] for i in indices]
while True:
b = -1
for i in reversed(range(r)):
if indices[i] != i + n - r:
b = i
break
if b == -1:
return
indices[b] += 1
for j in range(b + 1, r):
indices[j] = indices[j - 1] + 1
yield [pool[i] for i in indices]
if r < 0:
raise ValueError("r must be non-negative")
if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
return combinations_helper(pool, r)
def _as_list(x):
if isinstance(x, list):
return x
else:
return combinations_helper([a for a in pool], r)
return list(x)
def combinations_with_replacement(
pool: Generator[T], r: int, T: type
) -> Generator[List[T]]:
"""
Return successive r-length combinations of elements in the iterable
allowing individual elements to have successive repeats.
"""
def combinations_with_replacement_helper(
pool: List[T], r: int, T: type
) -> Generator[List[T]]:
n = len(pool)
if not n and r:
return
indices = [0 for _ in range(r)]
yield [pool[i] for i in indices]
while True:
b = -1
for i in reversed(range(r)):
if indices[i] != n - 1:
b = i
break
if b == -1:
return
newval = indices[b] + 1
for j in range(r - b):
indices[b + j] = newval
yield [pool[i] for i in indices]
if r < 0:
raise ValueError("r must be non-negative")
if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
return combinations_with_replacement_helper(pool, r)
else:
return combinations_with_replacement_helper([a for a in pool], r)
def permutations(
pool: Generator[T], r: Optional[int] = None, T: type
) -> Generator[List[T]]:
"""
Return successive r-length permutations of elements in the iterable.
"""
def permutations_helper(
pool: List[T], r: Optional[int], T: type
) -> Generator[List[T]]:
n = len(pool)
r: int = r if r is not None else n
if r > n:
return
indices = list(range(n))
cycles = list(range(n, n - r, -1))
yield [pool[i] for i in indices[:r]]
while n:
b = -1
for i in reversed(range(r)):
cycles[i] -= 1
if cycles[i] == 0:
indices = indices[:i] + indices[i + 1 :] + indices[i : i + 1]
cycles[i] = n - i
else:
b = i
j = cycles[i]
indices[i], indices[-j] = indices[-j], indices[i]
yield [pool[i] for i in indices[:r]]
break
if b == -1:
return
if r is not None and r.__val__() < 0:
raise ValueError("r must be non-negative")
if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
return permutations_helper(pool, r)
else:
return permutations_helper([a for a in pool], r)
@inline
def product(*args):
"""
Cartesian product of input iterables.
"""
if staticlen(args) == 0:
yield ()
else:
for a in args[0]:
rest = args[1:]
for b in product(*rest):
yield (a, *b)
@inline
@overload
def product(*args, repeat: int):
"""
Cartesian product of input iterables.
"""
def product(*iterables, repeat: int):
if repeat < 0:
raise ValueError("repeat argument cannot be negative")
pools = [list(pool) for _ in range(repeat) for pool in args]
result = [List[type(pools[0][0])]()]
for pool in pools:
result = [x + [y] for x in result for y in pool]
for prod in result:
yield prod
raise ValueError("repeat must be non-negative")
if repeat == 0:
nargs = 0
else:
nargs = len(iterables)
npools = nargs * repeat
indices = Ptr[int](npools)
pools = list(capacity=npools)
i = 0
while i < nargs:
p = _as_list(iterables[i])
if len(p) == 0:
return
pools.append(p)
indices[i] = 0
i += 1
while i < npools:
pools.append(pools[i - nargs])
indices[i] = 0
i += 1
result = list(capacity=npools)
for i in range(npools):
result.append(pools[i][0])
while True:
yield result
result = result.copy()
i = npools - 1
while i >= 0:
pool = pools[i]
indices[i] += 1
if indices[i] == len(pool):
indices[i] = 0
result[i] = pool[0]
else:
result[i] = pool[indices[i]]
break
i -= 1
if i < 0:
break
@overload
def product(*iterables, repeat: Static[int] = 1):
if repeat < 0:
compile_error("repeat must be non-negative")
# handle some common cases
if repeat == 0:
yield ()
elif repeat == 1 and staticlen(iterables) == 1:
it0 = iterables[0]
for a in it0:
yield (a,)
elif repeat == 1 and staticlen(iterables) == 2:
it0 = iterables[0]
it1 = iterables[1]
for a in it0:
for b in it1:
yield (a, b)
elif repeat == 1 and staticlen(iterables) == 3:
it0 = iterables[0]
it1 = iterables[1]
it2 = iterables[2]
for a in it0:
for b in it1:
for c in it2:
yield (a, b, c)
else:
nargs: Static[int] = staticlen(iterables)
npools: Static[int] = nargs * repeat
indices_tuple = (0,) * npools
indices = Ptr[int](__ptr__(indices_tuple).as_byte())
pools = tuple(_as_list(it) for it in iterables) * repeat
for i in staticrange(nargs):
if len(pools[i]) == 0:
return
result = tuple(pool[0] for pool in pools)
while True:
yield result
i = npools - 1
while i >= 0:
pool = pools[i]
indices[i] += 1
if indices[i] == len(pool):
indices[i] = 0
else:
break
i -= 1
if i < 0:
break
result = tuple(pools[i][indices[i]] for i in staticrange(npools))
def combinations(pool, r: int):
if r < 0:
raise ValueError("r must be non-negative")
pool_list = _as_list(pool)
n = len(pool)
if r > n:
return
pool = pool_list.arr.ptr
indices = Ptr[int](r)
result = list(capacity=r)
for i in range(r):
indices[i] = i
result.append(pool[i])
while True:
yield result
i = r - 1
while i >= 0 and indices[i] == i + n - r:
i -= 1
if i < 0:
break
indices[i] += 1
for j in range(i + 1, r):
indices[j] = indices[j-1] + 1
result = result.copy()
while i < r:
result[i] = pool[indices[i]]
i += 1
@overload
def combinations(pool, r: Static[int]):
def empty(T: type) -> T:
pass
if r < 0:
compile_error("r must be non-negative")
if isinstance(pool, list):
pool_list = pool
else:
pool_list = list(pool)
n = len(pool)
if r > n:
return
pool = pool_list.arr.ptr
indices_tuple = (0,) * r
indices = Ptr[int](__ptr__(indices_tuple).as_byte())
result_tuple = (empty(pool.T),) * r
result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
for i in range(r):
indices[i] = i
result[i] = pool[i]
while True:
yield result_tuple
i = r - 1
while i >= 0 and indices[i] == i + n - r:
i -= 1
if i < 0:
break
indices[i] += 1
for j in range(i + 1, r):
indices[j] = indices[j-1] + 1
while i < r:
result[i] = pool[indices[i]]
i += 1
def combinations_with_replacement(pool, r: int):
if r < 0:
raise ValueError("r must be non-negative")
pool_list = _as_list(pool)
n = len(pool)
if n == 0:
if r == 0:
yield List[pool_list.T](capacity=0)
return
pool = pool_list.arr.ptr
indices = Ptr[int](r)
result = list(capacity=r)
for i in range(r):
indices[i] = 0
result.append(pool[0])
while True:
yield result
i = r - 1
while i >= 0 and indices[i] == n - 1:
i -= 1
if i < 0:
break
result = result.copy()
index = indices[i] + 1
elem = pool[index]
while i < r:
indices[i] = index
result[i] = elem
i += 1
@overload
def combinations_with_replacement(pool, r: Static[int]):
def empty(T: type) -> T:
pass
if r < 0:
compile_error("r must be non-negative")
if r == 0:
yield ()
return
if isinstance(pool, list):
pool_list = pool
else:
pool_list = list(pool)
n = len(pool)
if n == 0:
return
pool = pool_list.arr.ptr
indices_tuple = (0,) * r
indices = Ptr[int](__ptr__(indices_tuple).as_byte())
result_tuple = (empty(pool.T),) * r
result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
for i in range(r):
result[i] = pool[0]
while True:
yield result_tuple
i = r - 1
while i >= 0 and indices[i] == n - 1:
i -= 1
if i < 0:
break
index = indices[i] + 1
elem = pool[index]
while i < r:
indices[i] = index
result[i] = elem
i += 1
def _permutations_non_static(pool, r = None):
pool_list = _as_list(pool)
n = len(pool)
if r is None:
return _permutations_non_static(pool_list, n)
elif not isinstance(r, int):
compile_error("Expected int as r")
if r < 0:
raise ValueError("r must be non-negative")
if r > n:
return
indices = Ptr[int](n)
cycles = Ptr[int](r)
for i in range(n):
indices[i] = i
for i in range(r):
cycles[i] = n - i
pool = pool_list.arr.ptr
result = list(capacity=r)
for i in range(r):
result.append(pool[i])
while True:
yield result
if n == 0:
break
result = result.copy()
i = r - 1
while i >= 0:
cycles[i] -= 1
if cycles[i] == 0:
index = indices[i]
for j in range(i, n - 1):
indices[j] = indices[j+1]
indices[n-1] = index
cycles[i] = n - i
else:
j = cycles[i]
index = indices[i]
indices[i] = indices[n - j]
indices[n - j] = index
for k in range(i, r):
index = indices[k]
result[k] = pool[index]
break
i -= 1
if i < 0:
break
def _permutations_static(pool, r: Static[int]):
def empty(T: type) -> T:
pass
pool_list = _as_list(pool)
n = len(pool)
if r < 0:
raise compile_error("r must be non-negative")
if r > n:
return
indices = Ptr[int](n)
cycles_tuple = (0,) * r
cycles = Ptr[int](__ptr__(cycles_tuple).as_byte())
for i in range(n):
indices[i] = i
for i in range(r):
cycles[i] = n - i
pool = pool_list.arr.ptr
result_tuple = (empty(pool.T),) * r
result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
for i in range(r):
result[i] = pool[i]
while True:
yield result_tuple
if n == 0:
break
i = r - 1
while i >= 0:
cycles[i] -= 1
if cycles[i] == 0:
index = indices[i]
for j in range(i, n - 1):
indices[j] = indices[j+1]
indices[n-1] = index
cycles[i] = n - i
else:
j = cycles[i]
index = indices[i]
indices[i] = indices[n - j]
indices[n - j] = index
for k in range(i, r):
index = indices[k]
result[k] = pool[index]
break
i -= 1
if i < 0:
break
def permutations(pool, r = None):
if isinstance(pool, Tuple) and r is None:
return _permutations_static(pool, staticlen(pool))
else:
return _permutations_non_static(pool, r)
@overload
def permutations(pool, r: Static[int]):
return _permutations_static(pool, r)

View File

@ -571,6 +571,171 @@ def test_dict():
assert repr(Dict[int,int]()) == '{}'
test_dict()
def slice_indices(slice, length):
"""
Reference implementation for the slice.indices method.
"""
# Compute step and length as integers.
#length = operator.index(length)
step: int = 1 if slice.step is None else slice.step
# Raise ValueError for negative length or zero step.
if length < 0:
raise ValueError("length should not be negative")
if step == 0:
raise ValueError("slice step cannot be zero")
# Find lower and upper bounds for start and stop.
lower = -1 if step < 0 else 0
upper = length - 1 if step < 0 else length
# Compute start.
if slice.start is None:
start = upper if step < 0 else lower
else:
start = slice.start
start = max(start + length, lower) if start < 0 else min(start, upper)
# Compute stop.
if slice.stop is None:
stop = lower if step < 0 else upper
else:
stop = slice.stop
stop = max(stop + length, lower) if stop < 0 else min(stop, upper)
return start, stop, step
def check_indices(slice, length):
err1 = False
err2 = False
try:
actual = slice.indices(length)
except ValueError:
err1 = True
try:
expected = slice_indices(slice, length)
except ValueError:
err2 = True
if err1 or err2:
return err1 and err2
if actual != expected:
return False
if length >= 0 and slice.step != 0:
actual = range(*slice.indices(length))
expected = range(length)[slice]
if actual != expected:
return False
return True
@test
def test_slice():
assert repr(slice(1, 2, 3)) == 'slice(1, 2, 3)'
s1 = slice(1, 2, 3)
s2 = slice(1, 2, 3)
s3 = slice(1, 2, 4)
assert s1 == s2
assert s1 != s3
s = slice(1)
assert s.start == None
assert s.stop == 1
assert s.step == None
s = slice(1, 2)
assert s.start == 1
assert s.stop == 2
assert s.step == None
s = slice(1, 2, 3)
assert s.start == 1
assert s.stop == 2
assert s.step == 3
# TODO
assert slice(None ).indices(10) == (0, 10, 1)
assert slice(None, None, 2).indices(10) == (0, 10, 2)
assert slice(1, None, 2).indices(10) == (1, 10, 2)
assert slice(None, None, -1).indices(10) == (9, -1, -1)
assert slice(None, None, -2).indices(10) == (9, -1, -2)
assert slice(3, None, -2).indices(10) == (3, -1, -2)
# issue 3004 tests
assert slice(None, -9).indices(10) == (0, 1, 1)
assert slice(None, -10).indices(10) == (0, 0, 1)
assert slice(None, -11).indices(10) == (0, 0, 1)
assert slice(None, -10, -1).indices(10) == (9, 0, -1)
assert slice(None, -11, -1).indices(10) == (9, -1, -1)
assert slice(None, -12, -1).indices(10) == (9, -1, -1)
assert slice(None, 9).indices(10) == (0, 9, 1)
assert slice(None, 10).indices(10) == (0, 10, 1)
assert slice(None, 11).indices(10) == (0, 10, 1)
assert slice(None, 8, -1).indices(10) == (9, 8, -1)
assert slice(None, 9, -1).indices(10) == (9, 9, -1)
assert slice(None, 10, -1).indices(10) == (9, 9, -1)
assert slice(-100, 100 ).indices(10) == slice(None).indices(10)
assert slice(100, -100, -1).indices(10) == slice(None, None, -1).indices(10)
assert slice(-100, 100, 2).indices(10) == (0, 10, 2)
import sys
assert list(range(10))[::sys.maxsize - 1] == [0]
# Check a variety of start, stop, step and length values, including
# values exceeding sys.maxsize (see issue #14794).
vals = [None, -2**100, -2**30, -53, -7, -1, 0, 1, 7, 53, 2**30, 2**100]
lengths = [0, 1, 7, 53, 2**30, 2**100]
#for slice_args in itertools.product(vals, repeat=3):
for a in vals:
for b in vals:
for c in vals:
slice_args = (a, b, c)
s = slice(*slice_args)
for length in lengths:
assert check_indices(s, length)
assert check_indices(slice(0, 10, 1), -3)
# Negative length should raise ValueError
try:
slice(None).indices(-1)
assert False
except ValueError:
pass
# Zero step should raise ValueError
try:
slice(0, 10, 0).indices(5)
except ValueError:
pass
# ... but it should be fine to use a custom class that provides index.
assert slice(0, 10, 1).indices(5) == (0, 5, 1)
''' # not yet supported in Codon
assert slice(MyIndexable(0), 10, 1).indices(5) == (0, 5, 1)
assert slice(0, MyIndexable(10), 1).indices(5) == (0, 5, 1)
assert slice(0, 10, MyIndexable(1)).indices(5) == (0, 5, 1)
assert slice(0, 10, 1).indices(MyIndexable(5)) == (0, 5, 1)
'''
tmp = []
class X[T](object):
tmp: T
def __setitem__(self, i, k):
self.tmp.append((i, k))
x = X(tmp)
x[1:2] = 42
assert tmp == [(slice(1, 2), 42)]
test_slice()
@test
def test_deque():
from collections import deque

View File

@ -61,7 +61,9 @@ def underten(x):
@test
def test_combinations():
assert list(itertools.combinations("ABCD", 2)) == [
f = lambda x: x # hack to get non-static argument
assert list(itertools.combinations("ABCD", f(2))) == [
["A", "B"],
["A", "C"],
["A", "D"],
@ -69,7 +71,7 @@ def test_combinations():
["B", "D"],
["C", "D"],
]
test_intermediate = itertools.combinations("ABCD", 2)
test_intermediate = itertools.combinations("ABCD", f(2))
next(test_intermediate)
assert list(test_intermediate) == [
["A", "C"],
@ -78,20 +80,49 @@ def test_combinations():
["B", "D"],
["C", "D"],
]
assert list(itertools.combinations(range(4), 3)) == [
assert list(itertools.combinations(range(4), f(3))) == [
[0, 1, 2],
[0, 1, 3],
[0, 2, 3],
[1, 2, 3],
]
test_intermediate = itertools.combinations(range(4), 3)
test_intermediate = itertools.combinations(range(4), f(3))
next(test_intermediate)
assert list(test_intermediate) == [[0, 1, 3], [0, 2, 3], [1, 2, 3]]
assert list(itertools.combinations("ABCD", 2)) == [
("A", "B"),
("A", "C"),
("A", "D"),
("B", "C"),
("B", "D"),
("C", "D"),
]
test_intermediate = itertools.combinations("ABCD", 2)
next(test_intermediate)
assert list(test_intermediate) == [
("A", "C"),
("A", "D"),
("B", "C"),
("B", "D"),
("C", "D"),
]
assert list(itertools.combinations(range(4), 3)) == [
(0, 1, 2),
(0, 1, 3),
(0, 2, 3),
(1, 2, 3),
]
test_intermediate = itertools.combinations(range(4), 3)
next(test_intermediate)
assert list(test_intermediate) == [(0, 1, 3), (0, 2, 3), (1, 2, 3)]
@test
def test_combinations_with_replacement():
assert list(itertools.combinations_with_replacement(range(3), 3)) == [
f = lambda x: x # hack to get non-static argument
assert list(itertools.combinations_with_replacement(range(3), f(3))) == [
[0, 0, 0],
[0, 0, 1],
[0, 0, 2],
@ -103,7 +134,7 @@ def test_combinations_with_replacement():
[1, 2, 2],
[2, 2, 2],
]
assert list(itertools.combinations_with_replacement("ABC", 2)) == [
assert list(itertools.combinations_with_replacement("ABC", f(2))) == [
["A", "A"],
["A", "B"],
["A", "C"],
@ -111,7 +142,7 @@ def test_combinations_with_replacement():
["B", "C"],
["C", "C"],
]
test_intermediate = itertools.combinations_with_replacement("ABC", 2)
test_intermediate = itertools.combinations_with_replacement("ABC", f(2))
next(test_intermediate)
assert list(test_intermediate) == [
["A", "B"],
@ -121,6 +152,35 @@ def test_combinations_with_replacement():
["C", "C"],
]
assert list(itertools.combinations_with_replacement(range(3), 3)) == [
(0, 0, 0),
(0, 0, 1),
(0, 0, 2),
(0, 1, 1),
(0, 1, 2),
(0, 2, 2),
(1, 1, 1),
(1, 1, 2),
(1, 2, 2),
(2, 2, 2),
]
assert list(itertools.combinations_with_replacement("ABC", 2)) == [
("A", "A"),
("A", "B"),
("A", "C"),
("B", "B"),
("B", "C"),
("C", "C"),
]
test_intermediate = itertools.combinations_with_replacement("ABC", 2)
next(test_intermediate)
assert list(test_intermediate) == [
("A", "B"),
("A", "C"),
("B", "B"),
("B", "C"),
("C", "C"),
]
@test
def test_islice():
@ -243,7 +303,9 @@ def test_filterfalse():
@test
def test_permutations():
assert list(itertools.permutations(range(3), 2)) == [
f = lambda x: x # hack to get non-static argument
assert list(itertools.permutations(range(3), f(2))) == [
[0, 1],
[0, 2],
[1, 0],
@ -255,6 +317,24 @@ def test_permutations():
for n in range(3):
values = [5 * x - 12 for x in range(n)]
for r in range(n + 2):
result = list(itertools.permutations(values, f(r)))
if r > n: # right number of perms
assert len(result) == 0
# factorial is not yet implemented in math
# else: fact(n) / fact(n - r)
assert list(itertools.permutations(range(3), 2)) == [
(0, 1),
(0, 2),
(1, 0),
(1, 2),
(2, 0),
(2, 1),
]
for n in staticrange(3):
values = [5 * x - 12 for x in range(n)]
for r in staticrange(n + 2):
result = list(itertools.permutations(values, r))
if r > n: # right number of perms
assert len(result) == 0
@ -487,18 +567,19 @@ test_chain_from_iterable_from_cpython()
@test
def test_combinations_from_cpython():
f = lambda x: x # hack to get non-static argument
from math import factorial as fact
err = False
try:
list(combinations("abc", -2))
list(combinations("abc", f(-2)))
assert False
except ValueError:
err = True
assert err
assert list(combinations("abc", 32)) == [] # r > n
assert list(combinations("ABCD", 2)) == [
assert list(combinations("abc", f(32))) == [] # r > n
assert list(combinations("ABCD", f(2))) == [
["A", "B"],
["A", "C"],
["A", "D"],
@ -506,7 +587,7 @@ def test_combinations_from_cpython():
["B", "D"],
["C", "D"],
]
assert list(combinations(range(4), 3)) == [
assert list(combinations(range(4), f(3))) == [
[0, 1, 2],
[0, 1, 3],
[0, 2, 3],
@ -516,7 +597,7 @@ def test_combinations_from_cpython():
for n in range(7):
values = [5 * x - 12 for x in range(n)]
for r in range(n + 2):
result = list(combinations(values, r))
result = list(combinations(values, f(r)))
assert len(result) == (0 if r > n else fact(n) // fact(r) // fact(n - r))
assert len(result) == len(set(result)) # no repeats
@ -531,21 +612,55 @@ def test_combinations_from_cpython():
] # comb is a subsequence of the input iterable
assert list(combinations("abc", 32)) == [] # r > n
assert list(combinations("ABCD", 2)) == [
("A", "B"),
("A", "C"),
("A", "D"),
("B", "C"),
("B", "D"),
("C", "D"),
]
assert list(combinations(range(4), 3)) == [
(0, 1, 2),
(0, 1, 3),
(0, 2, 3),
(1, 2, 3),
]
for n in staticrange(7):
values = [5 * x - 12 for x in range(n)]
for r in staticrange(n + 2):
result = list(combinations(values, r))
assert len(result) == (0 if r > n else fact(n) // fact(r) // fact(n - r))
assert len(result) == len(set(result)) # no repeats
# assert result == sorted(result) # lexicographic order
for c in result:
assert len(c) == r # r-length combinations
assert len(set(c)) == r # no duplicate elements
assert list(c) == sorted(c) # keep original ordering
assert all(e in values for e in c) # elements taken from input iterable
assert list(c) == [
e for e in values if e in c
] # comb is a subsequence of the input iterable
test_combinations_from_cpython()
@test
def test_combinations_with_replacement_from_cpython():
f = lambda x: x # hack to get non-static argument
cwr = combinations_with_replacement
err = False
try:
list(cwr("abc", -2))
list(combinations_with_replacement("abc", f(-2)))
assert False
except ValueError:
err = True
assert err
assert list(cwr("ABC", 2)) == [
assert list(combinations_with_replacement("ABC", f(2))) == [
["A", "A"],
["A", "B"],
["A", "C"],
@ -564,7 +679,44 @@ def test_combinations_with_replacement_from_cpython():
for n in range(7):
values = [5 * x - 12 for x in range(n)]
for r in range(n + 2):
result = list(cwr(values, r))
result = list(combinations_with_replacement(values, r))
regular_combs = list(combinations(values, r))
assert len(result) == numcombs(n, r)
assert len(result) == len(set(result)) # no repeats
# assert result == sorted(result) # lexicographic order
if n == 0 or r <= 1:
assert result == regular_combs # cases that should be identical
else:
assert set(result) >= set(regular_combs)
for c in result:
assert len(c) == r # r-length combinations
noruns = [k for k, v in groupby(c)] # combo without consecutive repeats
assert len(noruns) == len(
set(noruns)
) # no repeats other than consecutive
assert list(c) == sorted(c) # keep original ordering
assert all(e in values for e in c) # elements taken from input iterable
assert noruns == [
e for e in values if e in c
] # comb is a subsequence of the input iterable
assert list(combinations_with_replacement("ABC", 2)) == [
("A", "A"),
("A", "B"),
("A", "C"),
("B", "B"),
("B", "C"),
("C", "C"),
]
for n in staticrange(7):
values = [5 * x - 12 for x in range(n)]
for r in staticrange(n + 2):
result = list(combinations_with_replacement(values, r))
regular_combs = list(combinations(values, r))
assert len(result) == numcombs(n, r)
@ -594,18 +746,19 @@ test_combinations_with_replacement_from_cpython()
@test
def test_permutations_from_cpython():
f = lambda x: x # hack to get non-static argument
from math import factorial as fact
err = False
try:
list(permutations("abc", -2))
list(permutations("abc", f(-2)))
assert False
except ValueError:
err = True
assert err
assert list(permutations("abc", 32)) == []
assert list(permutations(range(3), 2)) == [
assert list(permutations("abc", f(32))) == []
assert list(permutations(range(3), f(2))) == [
[0, 1],
[0, 2],
[1, 0],
@ -632,6 +785,33 @@ def test_permutations_from_cpython():
assert result == list(permutations(values, None)) # test r as None
assert result == list(permutations(values)) # test default r
assert list(permutations("abc", 32)) == []
assert list(permutations(range(3), 2)) == [
(0, 1),
(0, 2),
(1, 0),
(1, 2),
(2, 0),
(2, 1),
]
for n in staticrange(7):
values = [5 * x - 12 for x in range(n)]
for r in staticrange(n + 2):
result = list(permutations(values, r))
assert len(result) == (
0 if r > n else fact(n) // fact(n - r)
) # right number of perms
assert len(result) == len(set(result)) # no repeats
# assert result == sorted(result) # lexicographic order
for p in result:
assert len(p) == r # r-length permutations
assert len(set(p)) == r # no duplicate elements
assert all(e in values for e in p) # elements taken from input iterable
if r == n:
assert result == list(permutations(values, r))
test_permutations_from_cpython()
@ -728,6 +908,49 @@ def test_combinatorics_from_cpython():
) # comb: cwr that is a perm
assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm
for n in staticrange(6):
s = "ABCDEFG"[:n]
for r in staticrange(8):
prod = list(product(s, repeat=r))
cwr = list(combinations_with_replacement(s, r))
perm = list(permutations(s, r))
comb = list(combinations(s, r))
# Check size
assert len(prod) == n ** r
assert len(cwr) == (
(fact(n + r - 1) // fact(r) // fact(n - 1)) if n else (0 if r else 1)
)
assert len(perm) == (0 if r > n else fact(n) // fact(n - r))
assert len(comb) == (0 if r > n else fact(n) // fact(r) // fact(n - r))
# Check lexicographic order without repeated tuples
assert prod == sorted(set(prod))
assert cwr == sorted(set(cwr))
assert perm == sorted(set(perm))
assert comb == sorted(set(comb))
# Check interrelationships
assert cwr == [
t for t in prod if sorted(t) == list(t)
] # cwr: prods which are sorted
assert perm == [
t for t in prod if len(set(t)) == r
] # perm: prods with no dups
assert comb == [
t for t in perm if sorted(t) == list(t)
] # comb: perms that are sorted
assert comb == [
t for t in cwr if len(set(t)) == r
] # comb: cwrs without dups
assert comb == list(
filter(set(cwr).__contains__, perm)
) # comb: perm that is a cwr
assert comb == list(
filter(set(perm).__contains__, cwr)
) # comb: cwr that is a perm
assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm
test_combinatorics_from_cpython()