mirror of https://github.com/exaloop/codon
Misc fixes (#410)
* Fix corner case when typechecking scoped names with static compilation * Undo log * Fix nested loop domination; Minor aestethic fixes * clang-format * Add slice indices() method * Fix overloads with static arguments * Update itertools combinatorics functions * Fix import domination issue (missing stack insert) * Fix itertools * Remove log * Bump version --------- Co-authored-by: Ibrahim Numanagić <ibrahimpasa@gmail.com>pull/420/head
parent
e95f778df1
commit
6bb26e0187
|
@ -1,10 +1,10 @@
|
|||
cmake_minimum_required(VERSION 3.14)
|
||||
project(
|
||||
Codon
|
||||
VERSION "0.16.1"
|
||||
VERSION "0.16.2"
|
||||
HOMEPAGE_URL "https://github.com/exaloop/codon"
|
||||
DESCRIPTION "high-performance, extensible Python compiler")
|
||||
set(CODON_JIT_PYTHON_VERSION "0.1.5")
|
||||
set(CODON_JIT_PYTHON_VERSION "0.1.6")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.h.in"
|
||||
"${PROJECT_SOURCE_DIR}/codon/config/config.h")
|
||||
configure_file("${PROJECT_SOURCE_DIR}/cmake/config.py.in"
|
||||
|
|
|
@ -154,7 +154,8 @@ bool LinkType::isInstantiated() const { return kind == Link && type->isInstantia
|
|||
std::string LinkType::debugString(char mode) const {
|
||||
if (kind == Unbound || kind == Generic) {
|
||||
if (mode == 2) {
|
||||
return fmt::format("{}{}{}", kind == Unbound ? '?' : '#', id,
|
||||
return fmt::format("{}{}{}{}", genericName.empty() ? "" : genericName + ":",
|
||||
kind == Unbound ? '?' : '#', id,
|
||||
trait ? ":" + trait->debugString(mode) : "");
|
||||
}
|
||||
if (trait)
|
||||
|
|
|
@ -38,11 +38,16 @@ void SimplifyVisitor::visit(IdExpr *expr) {
|
|||
// while True:
|
||||
// if x > 10: break
|
||||
// x = x + 1 # x must be dominated after the loop to ensure that it gets updated
|
||||
if (auto loop = ctx->getBase()->getLoop()) {
|
||||
bool inside = val->scope.size() >= loop->scope.size() &&
|
||||
val->scope[loop->scope.size() - 1] == loop->scope.back();
|
||||
if (!inside)
|
||||
loop->seenVars.insert(expr->value);
|
||||
if (ctx->getBase()->getLoop()) {
|
||||
for (size_t li = ctx->getBase()->loops.size(); li-- > 0;) {
|
||||
auto &loop = ctx->getBase()->loops[li];
|
||||
bool inside = val->scope.size() >= loop.scope.size() &&
|
||||
val->scope[loop.scope.size() - 1] == loop.scope.back();
|
||||
if (!inside)
|
||||
loop.seenVars.insert(expr->value);
|
||||
else
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the variable with its canonical name
|
||||
|
|
|
@ -153,6 +153,7 @@ SimplifyContext::Item SimplifyContext::findDominatingBinding(const std::string &
|
|||
(*lastGood)->importPath);
|
||||
item->accessChecked = {(*lastGood)->scope};
|
||||
lastGood = it->second.insert(++lastGood, item);
|
||||
stack.front().push_back(name);
|
||||
// Make sure to prepend a binding declaration: `var` and `var__used__ = False`
|
||||
// to the dominating scope.
|
||||
scope.stmts[scope.blocks[prefix - 1]].push_back(std::make_unique<AssignStmt>(
|
||||
|
|
|
@ -125,8 +125,9 @@ void SimplifyVisitor::visit(ForStmt *stmt) {
|
|||
|
||||
ctx->leaveConditionalBlock(&(stmt->suite->getSuite()->stmts));
|
||||
// Dominate loop variables
|
||||
for (auto &var : ctx->getBase()->getLoop()->seenVars)
|
||||
for (auto &var : ctx->getBase()->getLoop()->seenVars) {
|
||||
ctx->findDominatingBinding(var);
|
||||
}
|
||||
ctx->getBase()->loops.pop_back();
|
||||
}
|
||||
|
||||
|
|
|
@ -284,6 +284,7 @@ void TranslateVisitor::visit(PipeExpr *expr) {
|
|||
simplePipeline &= !isGen(fn);
|
||||
|
||||
std::vector<ir::Value *> args;
|
||||
args.reserve(call->args.size());
|
||||
for (auto &a : call->args)
|
||||
args.emplace_back(a.value->getEllipsis() ? nullptr : transform(a.value));
|
||||
stages.emplace_back(fn, args, isGen(fn), false);
|
||||
|
@ -642,7 +643,7 @@ void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt
|
|||
ltrim(lp);
|
||||
rtrim(lp);
|
||||
// Extract declares and constants.
|
||||
if (isDeclare && !startswith(lp, "declare ")) {
|
||||
if (isDeclare && !startswith(lp, "declare ") && !startswith(lp, "@")) {
|
||||
bool isConst = lp.find("private constant") != std::string::npos;
|
||||
if (!isConst) {
|
||||
isDeclare = false;
|
||||
|
|
|
@ -286,7 +286,8 @@ int TypeContext::reorderNamedArgs(types::FuncType *func,
|
|||
Emsg(Error::CALL_ARGS_MISSING, cache->rev(func->ast->name),
|
||||
cache->reverseIdentifierLookup[func->ast->args[i].name]));
|
||||
}
|
||||
return score + onDone(starArgIndex, kwstarArgIndex, slots, partial);
|
||||
auto s = onDone(starArgIndex, kwstarArgIndex, slots, partial);
|
||||
return s != -1 ? score + s : -1;
|
||||
}
|
||||
|
||||
void TypeContext::dump(int pad) {
|
||||
|
|
|
@ -86,6 +86,9 @@ public:
|
|||
return item;
|
||||
}
|
||||
std::shared_ptr<TypecheckItem> find(const std::string &name) const override;
|
||||
std::shared_ptr<TypecheckItem> find(const char *name) const {
|
||||
return find(std::string(name));
|
||||
}
|
||||
/// Find an internal type. Assumes that it exists.
|
||||
std::shared_ptr<TypecheckItem> forceFind(const std::string &name) const;
|
||||
types::TypePtr getType(const std::string &name) const;
|
||||
|
|
|
@ -103,7 +103,7 @@ void TypecheckVisitor::visit(ForStmt *stmt) {
|
|||
unify(stmt->var->type,
|
||||
iterType ? unify(val->type, iterType->generics[0].type) : val->type);
|
||||
|
||||
ctx->staticLoops.push_back("");
|
||||
ctx->staticLoops.emplace_back();
|
||||
ctx->blockLevel++;
|
||||
transform(stmt->suite);
|
||||
ctx->blockLevel--;
|
||||
|
|
|
@ -310,7 +310,7 @@ void TypecheckVisitor::visit(InstantiateExpr *expr) {
|
|||
} else {
|
||||
if (expr->typeParams[i]->getNone()) // `None` -> `NoneType`
|
||||
transformType(expr->typeParams[i]);
|
||||
if (!expr->typeParams[i]->isType())
|
||||
if (expr->typeParams[i]->type->getClass() && !expr->typeParams[i]->isType())
|
||||
E(Error::EXPECTED_TYPE, expr->typeParams[i], "type");
|
||||
t = ctx->instantiate(expr->typeParams[i]->getSrcInfo(),
|
||||
expr->typeParams[i]->getType());
|
||||
|
|
|
@ -258,8 +258,9 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
|
|||
if (slots[si].empty()) {
|
||||
// is this "real" type?
|
||||
if (in(niGenerics, fn->ast->args[si].name) &&
|
||||
!fn->ast->args[si].defaultValue)
|
||||
!fn->ast->args[si].defaultValue) {
|
||||
return -1;
|
||||
}
|
||||
reordered.push_back({nullptr, 0});
|
||||
} else {
|
||||
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
|
||||
|
|
|
@ -8,11 +8,9 @@ class List:
|
|||
self.arr = Array[T](10)
|
||||
self.len = 0
|
||||
|
||||
def __init__(self, it: Generator[T]):
|
||||
self.arr = Array[T](10)
|
||||
def __init__(self, capacity: int):
|
||||
self.arr = Array[T](capacity)
|
||||
self.len = 0
|
||||
for i in it:
|
||||
self.append(i)
|
||||
|
||||
def __init__(self, other: List[T]):
|
||||
self.arr = Array[T](other.len)
|
||||
|
@ -20,9 +18,11 @@ class List:
|
|||
for i in other:
|
||||
self.append(i)
|
||||
|
||||
def __init__(self, capacity: int):
|
||||
self.arr = Array[T](capacity)
|
||||
def __init__(self, it: Generator[T]):
|
||||
self.arr = Array[T](10)
|
||||
self.len = 0
|
||||
for i in it:
|
||||
self.append(i)
|
||||
|
||||
def __init__(self, arr: Array[T], len: int):
|
||||
self.arr = arr
|
||||
|
|
|
@ -6,6 +6,9 @@ class Slice:
|
|||
stop: Optional[int]
|
||||
step: Optional[int]
|
||||
|
||||
def __new__(stop: Optional[int]):
|
||||
return Slice(None, stop, None)
|
||||
|
||||
def adjust_indices(self, length: int) -> Tuple[int, int, int, int]:
|
||||
step: int = self.step if self.step is not None else 1
|
||||
start: int = 0
|
||||
|
@ -47,6 +50,11 @@ class Slice:
|
|||
|
||||
return start, stop, step, 0
|
||||
|
||||
def indices(self, length: int):
|
||||
if length < 0:
|
||||
raise ValueError("length should not be negative")
|
||||
return self.adjust_indices(length)[:-1]
|
||||
|
||||
def __repr__(self):
|
||||
return f"slice({self.start}, {self.stop}, {self.step})"
|
||||
|
||||
|
|
|
@ -333,140 +333,426 @@ def zip_longest(*args):
|
|||
|
||||
# Combinatoric iterators
|
||||
|
||||
def combinations(pool: Generator[T], r: int, T: type) -> Generator[List[T]]:
|
||||
"""
|
||||
Return successive r-length combinations of elements in the iterable.
|
||||
|
||||
combinations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)
|
||||
"""
|
||||
|
||||
def combinations_helper(pool: List[T], r: int, T: type) -> Generator[List[T]]:
|
||||
n = len(pool)
|
||||
if r > n:
|
||||
return
|
||||
indices = list(range(r))
|
||||
yield [pool[i] for i in indices]
|
||||
while True:
|
||||
b = -1
|
||||
for i in reversed(range(r)):
|
||||
if indices[i] != i + n - r:
|
||||
b = i
|
||||
break
|
||||
if b == -1:
|
||||
return
|
||||
indices[b] += 1
|
||||
for j in range(b + 1, r):
|
||||
indices[j] = indices[j - 1] + 1
|
||||
yield [pool[i] for i in indices]
|
||||
|
||||
if r < 0:
|
||||
raise ValueError("r must be non-negative")
|
||||
if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
|
||||
return combinations_helper(pool, r)
|
||||
def _as_list(x):
|
||||
if isinstance(x, list):
|
||||
return x
|
||||
else:
|
||||
return combinations_helper([a for a in pool], r)
|
||||
return list(x)
|
||||
|
||||
def combinations_with_replacement(
|
||||
pool: Generator[T], r: int, T: type
|
||||
) -> Generator[List[T]]:
|
||||
"""
|
||||
Return successive r-length combinations of elements in the iterable
|
||||
allowing individual elements to have successive repeats.
|
||||
"""
|
||||
|
||||
def combinations_with_replacement_helper(
|
||||
pool: List[T], r: int, T: type
|
||||
) -> Generator[List[T]]:
|
||||
n = len(pool)
|
||||
if not n and r:
|
||||
return
|
||||
indices = [0 for _ in range(r)]
|
||||
yield [pool[i] for i in indices]
|
||||
while True:
|
||||
b = -1
|
||||
for i in reversed(range(r)):
|
||||
if indices[i] != n - 1:
|
||||
b = i
|
||||
break
|
||||
if b == -1:
|
||||
return
|
||||
newval = indices[b] + 1
|
||||
for j in range(r - b):
|
||||
indices[b + j] = newval
|
||||
yield [pool[i] for i in indices]
|
||||
|
||||
if r < 0:
|
||||
raise ValueError("r must be non-negative")
|
||||
if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
|
||||
return combinations_with_replacement_helper(pool, r)
|
||||
else:
|
||||
return combinations_with_replacement_helper([a for a in pool], r)
|
||||
|
||||
def permutations(
|
||||
pool: Generator[T], r: Optional[int] = None, T: type
|
||||
) -> Generator[List[T]]:
|
||||
"""
|
||||
Return successive r-length permutations of elements in the iterable.
|
||||
"""
|
||||
|
||||
def permutations_helper(
|
||||
pool: List[T], r: Optional[int], T: type
|
||||
) -> Generator[List[T]]:
|
||||
n = len(pool)
|
||||
r: int = r if r is not None else n
|
||||
if r > n:
|
||||
return
|
||||
|
||||
indices = list(range(n))
|
||||
cycles = list(range(n, n - r, -1))
|
||||
yield [pool[i] for i in indices[:r]]
|
||||
while n:
|
||||
b = -1
|
||||
for i in reversed(range(r)):
|
||||
cycles[i] -= 1
|
||||
if cycles[i] == 0:
|
||||
indices = indices[:i] + indices[i + 1 :] + indices[i : i + 1]
|
||||
cycles[i] = n - i
|
||||
else:
|
||||
b = i
|
||||
j = cycles[i]
|
||||
indices[i], indices[-j] = indices[-j], indices[i]
|
||||
yield [pool[i] for i in indices[:r]]
|
||||
break
|
||||
if b == -1:
|
||||
return
|
||||
|
||||
if r is not None and r.__val__() < 0:
|
||||
raise ValueError("r must be non-negative")
|
||||
if hasattr(pool, "__getitem__") and hasattr(pool, "__len__"):
|
||||
return permutations_helper(pool, r)
|
||||
else:
|
||||
return permutations_helper([a for a in pool], r)
|
||||
|
||||
@inline
|
||||
def product(*args):
|
||||
"""
|
||||
Cartesian product of input iterables.
|
||||
"""
|
||||
if staticlen(args) == 0:
|
||||
yield ()
|
||||
else:
|
||||
for a in args[0]:
|
||||
rest = args[1:]
|
||||
for b in product(*rest):
|
||||
yield (a, *b)
|
||||
|
||||
@inline
|
||||
@overload
|
||||
def product(*args, repeat: int):
|
||||
"""
|
||||
Cartesian product of input iterables.
|
||||
"""
|
||||
def product(*iterables, repeat: int):
|
||||
if repeat < 0:
|
||||
raise ValueError("repeat argument cannot be negative")
|
||||
pools = [list(pool) for _ in range(repeat) for pool in args]
|
||||
result = [List[type(pools[0][0])]()]
|
||||
for pool in pools:
|
||||
result = [x + [y] for x in result for y in pool]
|
||||
for prod in result:
|
||||
yield prod
|
||||
raise ValueError("repeat must be non-negative")
|
||||
|
||||
if repeat == 0:
|
||||
nargs = 0
|
||||
else:
|
||||
nargs = len(iterables)
|
||||
|
||||
npools = nargs * repeat
|
||||
indices = Ptr[int](npools)
|
||||
|
||||
pools = list(capacity=npools)
|
||||
i = 0
|
||||
|
||||
while i < nargs:
|
||||
p = _as_list(iterables[i])
|
||||
if len(p) == 0:
|
||||
return
|
||||
pools.append(p)
|
||||
indices[i] = 0
|
||||
i += 1
|
||||
|
||||
while i < npools:
|
||||
pools.append(pools[i - nargs])
|
||||
indices[i] = 0
|
||||
i += 1
|
||||
|
||||
result = list(capacity=npools)
|
||||
for i in range(npools):
|
||||
result.append(pools[i][0])
|
||||
|
||||
while True:
|
||||
yield result
|
||||
|
||||
result = result.copy()
|
||||
i = npools - 1
|
||||
while i >= 0:
|
||||
pool = pools[i]
|
||||
indices[i] += 1
|
||||
|
||||
if indices[i] == len(pool):
|
||||
indices[i] = 0
|
||||
result[i] = pool[0]
|
||||
else:
|
||||
result[i] = pool[indices[i]]
|
||||
break
|
||||
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
@overload
|
||||
def product(*iterables, repeat: Static[int] = 1):
|
||||
if repeat < 0:
|
||||
compile_error("repeat must be non-negative")
|
||||
|
||||
# handle some common cases
|
||||
if repeat == 0:
|
||||
yield ()
|
||||
elif repeat == 1 and staticlen(iterables) == 1:
|
||||
it0 = iterables[0]
|
||||
for a in it0:
|
||||
yield (a,)
|
||||
elif repeat == 1 and staticlen(iterables) == 2:
|
||||
it0 = iterables[0]
|
||||
it1 = iterables[1]
|
||||
for a in it0:
|
||||
for b in it1:
|
||||
yield (a, b)
|
||||
elif repeat == 1 and staticlen(iterables) == 3:
|
||||
it0 = iterables[0]
|
||||
it1 = iterables[1]
|
||||
it2 = iterables[2]
|
||||
for a in it0:
|
||||
for b in it1:
|
||||
for c in it2:
|
||||
yield (a, b, c)
|
||||
else:
|
||||
nargs: Static[int] = staticlen(iterables)
|
||||
npools: Static[int] = nargs * repeat
|
||||
indices_tuple = (0,) * npools
|
||||
indices = Ptr[int](__ptr__(indices_tuple).as_byte())
|
||||
pools = tuple(_as_list(it) for it in iterables) * repeat
|
||||
|
||||
for i in staticrange(nargs):
|
||||
if len(pools[i]) == 0:
|
||||
return
|
||||
|
||||
result = tuple(pool[0] for pool in pools)
|
||||
|
||||
while True:
|
||||
yield result
|
||||
|
||||
i = npools - 1
|
||||
while i >= 0:
|
||||
pool = pools[i]
|
||||
indices[i] += 1
|
||||
|
||||
if indices[i] == len(pool):
|
||||
indices[i] = 0
|
||||
else:
|
||||
break
|
||||
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
result = tuple(pools[i][indices[i]] for i in staticrange(npools))
|
||||
|
||||
def combinations(pool, r: int):
|
||||
if r < 0:
|
||||
raise ValueError("r must be non-negative")
|
||||
|
||||
pool_list = _as_list(pool)
|
||||
n = len(pool)
|
||||
|
||||
if r > n:
|
||||
return
|
||||
|
||||
pool = pool_list.arr.ptr
|
||||
indices = Ptr[int](r)
|
||||
result = list(capacity=r)
|
||||
|
||||
for i in range(r):
|
||||
indices[i] = i
|
||||
result.append(pool[i])
|
||||
|
||||
while True:
|
||||
yield result
|
||||
|
||||
i = r - 1
|
||||
while i >= 0 and indices[i] == i + n - r:
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
indices[i] += 1
|
||||
|
||||
for j in range(i + 1, r):
|
||||
indices[j] = indices[j-1] + 1
|
||||
|
||||
result = result.copy()
|
||||
while i < r:
|
||||
result[i] = pool[indices[i]]
|
||||
i += 1
|
||||
|
||||
@overload
|
||||
def combinations(pool, r: Static[int]):
|
||||
def empty(T: type) -> T:
|
||||
pass
|
||||
|
||||
if r < 0:
|
||||
compile_error("r must be non-negative")
|
||||
|
||||
if isinstance(pool, list):
|
||||
pool_list = pool
|
||||
else:
|
||||
pool_list = list(pool)
|
||||
|
||||
n = len(pool)
|
||||
|
||||
if r > n:
|
||||
return
|
||||
|
||||
pool = pool_list.arr.ptr
|
||||
indices_tuple = (0,) * r
|
||||
indices = Ptr[int](__ptr__(indices_tuple).as_byte())
|
||||
result_tuple = (empty(pool.T),) * r
|
||||
result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
|
||||
|
||||
for i in range(r):
|
||||
indices[i] = i
|
||||
result[i] = pool[i]
|
||||
|
||||
while True:
|
||||
yield result_tuple
|
||||
|
||||
i = r - 1
|
||||
while i >= 0 and indices[i] == i + n - r:
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
indices[i] += 1
|
||||
|
||||
for j in range(i + 1, r):
|
||||
indices[j] = indices[j-1] + 1
|
||||
|
||||
while i < r:
|
||||
result[i] = pool[indices[i]]
|
||||
i += 1
|
||||
|
||||
def combinations_with_replacement(pool, r: int):
|
||||
if r < 0:
|
||||
raise ValueError("r must be non-negative")
|
||||
|
||||
pool_list = _as_list(pool)
|
||||
n = len(pool)
|
||||
|
||||
if n == 0:
|
||||
if r == 0:
|
||||
yield List[pool_list.T](capacity=0)
|
||||
return
|
||||
|
||||
pool = pool_list.arr.ptr
|
||||
indices = Ptr[int](r)
|
||||
result = list(capacity=r)
|
||||
|
||||
for i in range(r):
|
||||
indices[i] = 0
|
||||
result.append(pool[0])
|
||||
|
||||
while True:
|
||||
yield result
|
||||
|
||||
i = r - 1
|
||||
while i >= 0 and indices[i] == n - 1:
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
result = result.copy()
|
||||
index = indices[i] + 1
|
||||
elem = pool[index]
|
||||
|
||||
while i < r:
|
||||
indices[i] = index
|
||||
result[i] = elem
|
||||
i += 1
|
||||
|
||||
@overload
|
||||
def combinations_with_replacement(pool, r: Static[int]):
|
||||
def empty(T: type) -> T:
|
||||
pass
|
||||
|
||||
if r < 0:
|
||||
compile_error("r must be non-negative")
|
||||
|
||||
if r == 0:
|
||||
yield ()
|
||||
return
|
||||
|
||||
if isinstance(pool, list):
|
||||
pool_list = pool
|
||||
else:
|
||||
pool_list = list(pool)
|
||||
|
||||
n = len(pool)
|
||||
|
||||
if n == 0:
|
||||
return
|
||||
|
||||
pool = pool_list.arr.ptr
|
||||
indices_tuple = (0,) * r
|
||||
indices = Ptr[int](__ptr__(indices_tuple).as_byte())
|
||||
result_tuple = (empty(pool.T),) * r
|
||||
result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
|
||||
|
||||
for i in range(r):
|
||||
result[i] = pool[0]
|
||||
|
||||
while True:
|
||||
yield result_tuple
|
||||
|
||||
i = r - 1
|
||||
while i >= 0 and indices[i] == n - 1:
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
index = indices[i] + 1
|
||||
elem = pool[index]
|
||||
|
||||
while i < r:
|
||||
indices[i] = index
|
||||
result[i] = elem
|
||||
i += 1
|
||||
|
||||
def _permutations_non_static(pool, r = None):
|
||||
pool_list = _as_list(pool)
|
||||
n = len(pool)
|
||||
|
||||
if r is None:
|
||||
return _permutations_non_static(pool_list, n)
|
||||
elif not isinstance(r, int):
|
||||
compile_error("Expected int as r")
|
||||
|
||||
if r < 0:
|
||||
raise ValueError("r must be non-negative")
|
||||
|
||||
if r > n:
|
||||
return
|
||||
|
||||
indices = Ptr[int](n)
|
||||
cycles = Ptr[int](r)
|
||||
|
||||
for i in range(n):
|
||||
indices[i] = i
|
||||
|
||||
for i in range(r):
|
||||
cycles[i] = n - i
|
||||
|
||||
pool = pool_list.arr.ptr
|
||||
result = list(capacity=r)
|
||||
|
||||
for i in range(r):
|
||||
result.append(pool[i])
|
||||
|
||||
while True:
|
||||
yield result
|
||||
|
||||
if n == 0:
|
||||
break
|
||||
|
||||
result = result.copy()
|
||||
i = r - 1
|
||||
while i >= 0:
|
||||
cycles[i] -= 1
|
||||
if cycles[i] == 0:
|
||||
index = indices[i]
|
||||
for j in range(i, n - 1):
|
||||
indices[j] = indices[j+1]
|
||||
indices[n-1] = index
|
||||
cycles[i] = n - i
|
||||
else:
|
||||
j = cycles[i]
|
||||
index = indices[i]
|
||||
indices[i] = indices[n - j]
|
||||
indices[n - j] = index
|
||||
|
||||
for k in range(i, r):
|
||||
index = indices[k]
|
||||
result[k] = pool[index]
|
||||
|
||||
break
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
def _permutations_static(pool, r: Static[int]):
|
||||
def empty(T: type) -> T:
|
||||
pass
|
||||
|
||||
pool_list = _as_list(pool)
|
||||
n = len(pool)
|
||||
|
||||
if r < 0:
|
||||
raise compile_error("r must be non-negative")
|
||||
|
||||
if r > n:
|
||||
return
|
||||
|
||||
indices = Ptr[int](n)
|
||||
cycles_tuple = (0,) * r
|
||||
cycles = Ptr[int](__ptr__(cycles_tuple).as_byte())
|
||||
|
||||
for i in range(n):
|
||||
indices[i] = i
|
||||
|
||||
for i in range(r):
|
||||
cycles[i] = n - i
|
||||
|
||||
pool = pool_list.arr.ptr
|
||||
result_tuple = (empty(pool.T),) * r
|
||||
result = Ptr[pool.T](__ptr__(result_tuple).as_byte())
|
||||
|
||||
for i in range(r):
|
||||
result[i] = pool[i]
|
||||
|
||||
while True:
|
||||
yield result_tuple
|
||||
|
||||
if n == 0:
|
||||
break
|
||||
|
||||
i = r - 1
|
||||
while i >= 0:
|
||||
cycles[i] -= 1
|
||||
if cycles[i] == 0:
|
||||
index = indices[i]
|
||||
for j in range(i, n - 1):
|
||||
indices[j] = indices[j+1]
|
||||
indices[n-1] = index
|
||||
cycles[i] = n - i
|
||||
else:
|
||||
j = cycles[i]
|
||||
index = indices[i]
|
||||
indices[i] = indices[n - j]
|
||||
indices[n - j] = index
|
||||
|
||||
for k in range(i, r):
|
||||
index = indices[k]
|
||||
result[k] = pool[index]
|
||||
|
||||
break
|
||||
i -= 1
|
||||
|
||||
if i < 0:
|
||||
break
|
||||
|
||||
def permutations(pool, r = None):
|
||||
if isinstance(pool, Tuple) and r is None:
|
||||
return _permutations_static(pool, staticlen(pool))
|
||||
else:
|
||||
return _permutations_non_static(pool, r)
|
||||
|
||||
@overload
|
||||
def permutations(pool, r: Static[int]):
|
||||
return _permutations_static(pool, r)
|
||||
|
|
|
@ -571,6 +571,171 @@ def test_dict():
|
|||
assert repr(Dict[int,int]()) == '{}'
|
||||
test_dict()
|
||||
|
||||
def slice_indices(slice, length):
|
||||
"""
|
||||
Reference implementation for the slice.indices method.
|
||||
|
||||
"""
|
||||
# Compute step and length as integers.
|
||||
#length = operator.index(length)
|
||||
step: int = 1 if slice.step is None else slice.step
|
||||
|
||||
# Raise ValueError for negative length or zero step.
|
||||
if length < 0:
|
||||
raise ValueError("length should not be negative")
|
||||
if step == 0:
|
||||
raise ValueError("slice step cannot be zero")
|
||||
|
||||
# Find lower and upper bounds for start and stop.
|
||||
lower = -1 if step < 0 else 0
|
||||
upper = length - 1 if step < 0 else length
|
||||
|
||||
# Compute start.
|
||||
if slice.start is None:
|
||||
start = upper if step < 0 else lower
|
||||
else:
|
||||
start = slice.start
|
||||
start = max(start + length, lower) if start < 0 else min(start, upper)
|
||||
|
||||
# Compute stop.
|
||||
if slice.stop is None:
|
||||
stop = lower if step < 0 else upper
|
||||
else:
|
||||
stop = slice.stop
|
||||
stop = max(stop + length, lower) if stop < 0 else min(stop, upper)
|
||||
|
||||
return start, stop, step
|
||||
|
||||
def check_indices(slice, length):
|
||||
err1 = False
|
||||
err2 = False
|
||||
|
||||
try:
|
||||
actual = slice.indices(length)
|
||||
except ValueError:
|
||||
err1 = True
|
||||
|
||||
try:
|
||||
expected = slice_indices(slice, length)
|
||||
except ValueError:
|
||||
err2 = True
|
||||
|
||||
if err1 or err2:
|
||||
return err1 and err2
|
||||
|
||||
if actual != expected:
|
||||
return False
|
||||
|
||||
if length >= 0 and slice.step != 0:
|
||||
actual = range(*slice.indices(length))
|
||||
expected = range(length)[slice]
|
||||
if actual != expected:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@test
|
||||
def test_slice():
|
||||
assert repr(slice(1, 2, 3)) == 'slice(1, 2, 3)'
|
||||
|
||||
s1 = slice(1, 2, 3)
|
||||
s2 = slice(1, 2, 3)
|
||||
s3 = slice(1, 2, 4)
|
||||
|
||||
assert s1 == s2
|
||||
assert s1 != s3
|
||||
|
||||
s = slice(1)
|
||||
assert s.start == None
|
||||
assert s.stop == 1
|
||||
assert s.step == None
|
||||
|
||||
s = slice(1, 2)
|
||||
assert s.start == 1
|
||||
assert s.stop == 2
|
||||
assert s.step == None
|
||||
|
||||
s = slice(1, 2, 3)
|
||||
assert s.start == 1
|
||||
assert s.stop == 2
|
||||
assert s.step == 3
|
||||
|
||||
# TODO
|
||||
assert slice(None ).indices(10) == (0, 10, 1)
|
||||
assert slice(None, None, 2).indices(10) == (0, 10, 2)
|
||||
assert slice(1, None, 2).indices(10) == (1, 10, 2)
|
||||
assert slice(None, None, -1).indices(10) == (9, -1, -1)
|
||||
assert slice(None, None, -2).indices(10) == (9, -1, -2)
|
||||
assert slice(3, None, -2).indices(10) == (3, -1, -2)
|
||||
# issue 3004 tests
|
||||
assert slice(None, -9).indices(10) == (0, 1, 1)
|
||||
assert slice(None, -10).indices(10) == (0, 0, 1)
|
||||
assert slice(None, -11).indices(10) == (0, 0, 1)
|
||||
assert slice(None, -10, -1).indices(10) == (9, 0, -1)
|
||||
assert slice(None, -11, -1).indices(10) == (9, -1, -1)
|
||||
assert slice(None, -12, -1).indices(10) == (9, -1, -1)
|
||||
assert slice(None, 9).indices(10) == (0, 9, 1)
|
||||
assert slice(None, 10).indices(10) == (0, 10, 1)
|
||||
assert slice(None, 11).indices(10) == (0, 10, 1)
|
||||
assert slice(None, 8, -1).indices(10) == (9, 8, -1)
|
||||
assert slice(None, 9, -1).indices(10) == (9, 9, -1)
|
||||
assert slice(None, 10, -1).indices(10) == (9, 9, -1)
|
||||
|
||||
assert slice(-100, 100 ).indices(10) == slice(None).indices(10)
|
||||
|
||||
assert slice(100, -100, -1).indices(10) == slice(None, None, -1).indices(10)
|
||||
|
||||
assert slice(-100, 100, 2).indices(10) == (0, 10, 2)
|
||||
|
||||
import sys
|
||||
assert list(range(10))[::sys.maxsize - 1] == [0]
|
||||
|
||||
# Check a variety of start, stop, step and length values, including
|
||||
# values exceeding sys.maxsize (see issue #14794).
|
||||
vals = [None, -2**100, -2**30, -53, -7, -1, 0, 1, 7, 53, 2**30, 2**100]
|
||||
lengths = [0, 1, 7, 53, 2**30, 2**100]
|
||||
#for slice_args in itertools.product(vals, repeat=3):
|
||||
for a in vals:
|
||||
for b in vals:
|
||||
for c in vals:
|
||||
slice_args = (a, b, c)
|
||||
s = slice(*slice_args)
|
||||
for length in lengths:
|
||||
assert check_indices(s, length)
|
||||
assert check_indices(slice(0, 10, 1), -3)
|
||||
|
||||
# Negative length should raise ValueError
|
||||
try:
|
||||
slice(None).indices(-1)
|
||||
assert False
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Zero step should raise ValueError
|
||||
try:
|
||||
slice(0, 10, 0).indices(5)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# ... but it should be fine to use a custom class that provides index.
|
||||
assert slice(0, 10, 1).indices(5) == (0, 5, 1)
|
||||
''' # not yet supported in Codon
|
||||
assert slice(MyIndexable(0), 10, 1).indices(5) == (0, 5, 1)
|
||||
assert slice(0, MyIndexable(10), 1).indices(5) == (0, 5, 1)
|
||||
assert slice(0, 10, MyIndexable(1)).indices(5) == (0, 5, 1)
|
||||
assert slice(0, 10, 1).indices(MyIndexable(5)) == (0, 5, 1)
|
||||
'''
|
||||
tmp = []
|
||||
class X[T](object):
|
||||
tmp: T
|
||||
def __setitem__(self, i, k):
|
||||
self.tmp.append((i, k))
|
||||
|
||||
x = X(tmp)
|
||||
x[1:2] = 42
|
||||
assert tmp == [(slice(1, 2), 42)]
|
||||
test_slice()
|
||||
|
||||
@test
|
||||
def test_deque():
|
||||
from collections import deque
|
||||
|
|
|
@ -61,7 +61,9 @@ def underten(x):
|
|||
|
||||
@test
|
||||
def test_combinations():
|
||||
assert list(itertools.combinations("ABCD", 2)) == [
|
||||
f = lambda x: x # hack to get non-static argument
|
||||
|
||||
assert list(itertools.combinations("ABCD", f(2))) == [
|
||||
["A", "B"],
|
||||
["A", "C"],
|
||||
["A", "D"],
|
||||
|
@ -69,7 +71,7 @@ def test_combinations():
|
|||
["B", "D"],
|
||||
["C", "D"],
|
||||
]
|
||||
test_intermediate = itertools.combinations("ABCD", 2)
|
||||
test_intermediate = itertools.combinations("ABCD", f(2))
|
||||
next(test_intermediate)
|
||||
assert list(test_intermediate) == [
|
||||
["A", "C"],
|
||||
|
@ -78,20 +80,49 @@ def test_combinations():
|
|||
["B", "D"],
|
||||
["C", "D"],
|
||||
]
|
||||
assert list(itertools.combinations(range(4), 3)) == [
|
||||
assert list(itertools.combinations(range(4), f(3))) == [
|
||||
[0, 1, 2],
|
||||
[0, 1, 3],
|
||||
[0, 2, 3],
|
||||
[1, 2, 3],
|
||||
]
|
||||
test_intermediate = itertools.combinations(range(4), 3)
|
||||
test_intermediate = itertools.combinations(range(4), f(3))
|
||||
next(test_intermediate)
|
||||
assert list(test_intermediate) == [[0, 1, 3], [0, 2, 3], [1, 2, 3]]
|
||||
|
||||
assert list(itertools.combinations("ABCD", 2)) == [
|
||||
("A", "B"),
|
||||
("A", "C"),
|
||||
("A", "D"),
|
||||
("B", "C"),
|
||||
("B", "D"),
|
||||
("C", "D"),
|
||||
]
|
||||
test_intermediate = itertools.combinations("ABCD", 2)
|
||||
next(test_intermediate)
|
||||
assert list(test_intermediate) == [
|
||||
("A", "C"),
|
||||
("A", "D"),
|
||||
("B", "C"),
|
||||
("B", "D"),
|
||||
("C", "D"),
|
||||
]
|
||||
assert list(itertools.combinations(range(4), 3)) == [
|
||||
(0, 1, 2),
|
||||
(0, 1, 3),
|
||||
(0, 2, 3),
|
||||
(1, 2, 3),
|
||||
]
|
||||
test_intermediate = itertools.combinations(range(4), 3)
|
||||
next(test_intermediate)
|
||||
assert list(test_intermediate) == [(0, 1, 3), (0, 2, 3), (1, 2, 3)]
|
||||
|
||||
|
||||
@test
|
||||
def test_combinations_with_replacement():
|
||||
assert list(itertools.combinations_with_replacement(range(3), 3)) == [
|
||||
f = lambda x: x # hack to get non-static argument
|
||||
|
||||
assert list(itertools.combinations_with_replacement(range(3), f(3))) == [
|
||||
[0, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 0, 2],
|
||||
|
@ -103,7 +134,7 @@ def test_combinations_with_replacement():
|
|||
[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
]
|
||||
assert list(itertools.combinations_with_replacement("ABC", 2)) == [
|
||||
assert list(itertools.combinations_with_replacement("ABC", f(2))) == [
|
||||
["A", "A"],
|
||||
["A", "B"],
|
||||
["A", "C"],
|
||||
|
@ -111,7 +142,7 @@ def test_combinations_with_replacement():
|
|||
["B", "C"],
|
||||
["C", "C"],
|
||||
]
|
||||
test_intermediate = itertools.combinations_with_replacement("ABC", 2)
|
||||
test_intermediate = itertools.combinations_with_replacement("ABC", f(2))
|
||||
next(test_intermediate)
|
||||
assert list(test_intermediate) == [
|
||||
["A", "B"],
|
||||
|
@ -121,6 +152,35 @@ def test_combinations_with_replacement():
|
|||
["C", "C"],
|
||||
]
|
||||
|
||||
assert list(itertools.combinations_with_replacement(range(3), 3)) == [
|
||||
(0, 0, 0),
|
||||
(0, 0, 1),
|
||||
(0, 0, 2),
|
||||
(0, 1, 1),
|
||||
(0, 1, 2),
|
||||
(0, 2, 2),
|
||||
(1, 1, 1),
|
||||
(1, 1, 2),
|
||||
(1, 2, 2),
|
||||
(2, 2, 2),
|
||||
]
|
||||
assert list(itertools.combinations_with_replacement("ABC", 2)) == [
|
||||
("A", "A"),
|
||||
("A", "B"),
|
||||
("A", "C"),
|
||||
("B", "B"),
|
||||
("B", "C"),
|
||||
("C", "C"),
|
||||
]
|
||||
test_intermediate = itertools.combinations_with_replacement("ABC", 2)
|
||||
next(test_intermediate)
|
||||
assert list(test_intermediate) == [
|
||||
("A", "B"),
|
||||
("A", "C"),
|
||||
("B", "B"),
|
||||
("B", "C"),
|
||||
("C", "C"),
|
||||
]
|
||||
|
||||
@test
|
||||
def test_islice():
|
||||
|
@ -243,7 +303,9 @@ def test_filterfalse():
|
|||
|
||||
@test
|
||||
def test_permutations():
|
||||
assert list(itertools.permutations(range(3), 2)) == [
|
||||
f = lambda x: x # hack to get non-static argument
|
||||
|
||||
assert list(itertools.permutations(range(3), f(2))) == [
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[1, 0],
|
||||
|
@ -255,6 +317,24 @@ def test_permutations():
|
|||
for n in range(3):
|
||||
values = [5 * x - 12 for x in range(n)]
|
||||
for r in range(n + 2):
|
||||
result = list(itertools.permutations(values, f(r)))
|
||||
if r > n: # right number of perms
|
||||
assert len(result) == 0
|
||||
# factorial is not yet implemented in math
|
||||
# else: fact(n) / fact(n - r)
|
||||
|
||||
assert list(itertools.permutations(range(3), 2)) == [
|
||||
(0, 1),
|
||||
(0, 2),
|
||||
(1, 0),
|
||||
(1, 2),
|
||||
(2, 0),
|
||||
(2, 1),
|
||||
]
|
||||
|
||||
for n in staticrange(3):
|
||||
values = [5 * x - 12 for x in range(n)]
|
||||
for r in staticrange(n + 2):
|
||||
result = list(itertools.permutations(values, r))
|
||||
if r > n: # right number of perms
|
||||
assert len(result) == 0
|
||||
|
@ -487,18 +567,19 @@ test_chain_from_iterable_from_cpython()
|
|||
|
||||
@test
|
||||
def test_combinations_from_cpython():
|
||||
f = lambda x: x # hack to get non-static argument
|
||||
from math import factorial as fact
|
||||
|
||||
err = False
|
||||
try:
|
||||
list(combinations("abc", -2))
|
||||
list(combinations("abc", f(-2)))
|
||||
assert False
|
||||
except ValueError:
|
||||
err = True
|
||||
assert err
|
||||
|
||||
assert list(combinations("abc", 32)) == [] # r > n
|
||||
assert list(combinations("ABCD", 2)) == [
|
||||
assert list(combinations("abc", f(32))) == [] # r > n
|
||||
assert list(combinations("ABCD", f(2))) == [
|
||||
["A", "B"],
|
||||
["A", "C"],
|
||||
["A", "D"],
|
||||
|
@ -506,7 +587,7 @@ def test_combinations_from_cpython():
|
|||
["B", "D"],
|
||||
["C", "D"],
|
||||
]
|
||||
assert list(combinations(range(4), 3)) == [
|
||||
assert list(combinations(range(4), f(3))) == [
|
||||
[0, 1, 2],
|
||||
[0, 1, 3],
|
||||
[0, 2, 3],
|
||||
|
@ -516,7 +597,7 @@ def test_combinations_from_cpython():
|
|||
for n in range(7):
|
||||
values = [5 * x - 12 for x in range(n)]
|
||||
for r in range(n + 2):
|
||||
result = list(combinations(values, r))
|
||||
result = list(combinations(values, f(r)))
|
||||
|
||||
assert len(result) == (0 if r > n else fact(n) // fact(r) // fact(n - r))
|
||||
assert len(result) == len(set(result)) # no repeats
|
||||
|
@ -531,21 +612,55 @@ def test_combinations_from_cpython():
|
|||
] # comb is a subsequence of the input iterable
|
||||
|
||||
|
||||
assert list(combinations("abc", 32)) == [] # r > n
|
||||
assert list(combinations("ABCD", 2)) == [
|
||||
("A", "B"),
|
||||
("A", "C"),
|
||||
("A", "D"),
|
||||
("B", "C"),
|
||||
("B", "D"),
|
||||
("C", "D"),
|
||||
]
|
||||
assert list(combinations(range(4), 3)) == [
|
||||
(0, 1, 2),
|
||||
(0, 1, 3),
|
||||
(0, 2, 3),
|
||||
(1, 2, 3),
|
||||
]
|
||||
|
||||
for n in staticrange(7):
|
||||
values = [5 * x - 12 for x in range(n)]
|
||||
for r in staticrange(n + 2):
|
||||
result = list(combinations(values, r))
|
||||
|
||||
assert len(result) == (0 if r > n else fact(n) // fact(r) // fact(n - r))
|
||||
assert len(result) == len(set(result)) # no repeats
|
||||
# assert result == sorted(result) # lexicographic order
|
||||
for c in result:
|
||||
assert len(c) == r # r-length combinations
|
||||
assert len(set(c)) == r # no duplicate elements
|
||||
assert list(c) == sorted(c) # keep original ordering
|
||||
assert all(e in values for e in c) # elements taken from input iterable
|
||||
assert list(c) == [
|
||||
e for e in values if e in c
|
||||
] # comb is a subsequence of the input iterable
|
||||
|
||||
test_combinations_from_cpython()
|
||||
|
||||
|
||||
@test
|
||||
def test_combinations_with_replacement_from_cpython():
|
||||
f = lambda x: x # hack to get non-static argument
|
||||
cwr = combinations_with_replacement
|
||||
err = False
|
||||
try:
|
||||
list(cwr("abc", -2))
|
||||
list(combinations_with_replacement("abc", f(-2)))
|
||||
assert False
|
||||
except ValueError:
|
||||
err = True
|
||||
assert err
|
||||
|
||||
assert list(cwr("ABC", 2)) == [
|
||||
assert list(combinations_with_replacement("ABC", f(2))) == [
|
||||
["A", "A"],
|
||||
["A", "B"],
|
||||
["A", "C"],
|
||||
|
@ -564,7 +679,44 @@ def test_combinations_with_replacement_from_cpython():
|
|||
for n in range(7):
|
||||
values = [5 * x - 12 for x in range(n)]
|
||||
for r in range(n + 2):
|
||||
result = list(cwr(values, r))
|
||||
result = list(combinations_with_replacement(values, r))
|
||||
regular_combs = list(combinations(values, r))
|
||||
|
||||
assert len(result) == numcombs(n, r)
|
||||
assert len(result) == len(set(result)) # no repeats
|
||||
# assert result == sorted(result) # lexicographic order
|
||||
|
||||
if n == 0 or r <= 1:
|
||||
assert result == regular_combs # cases that should be identical
|
||||
else:
|
||||
assert set(result) >= set(regular_combs)
|
||||
|
||||
for c in result:
|
||||
assert len(c) == r # r-length combinations
|
||||
noruns = [k for k, v in groupby(c)] # combo without consecutive repeats
|
||||
assert len(noruns) == len(
|
||||
set(noruns)
|
||||
) # no repeats other than consecutive
|
||||
assert list(c) == sorted(c) # keep original ordering
|
||||
assert all(e in values for e in c) # elements taken from input iterable
|
||||
assert noruns == [
|
||||
e for e in values if e in c
|
||||
] # comb is a subsequence of the input iterable
|
||||
|
||||
|
||||
assert list(combinations_with_replacement("ABC", 2)) == [
|
||||
("A", "A"),
|
||||
("A", "B"),
|
||||
("A", "C"),
|
||||
("B", "B"),
|
||||
("B", "C"),
|
||||
("C", "C"),
|
||||
]
|
||||
|
||||
for n in staticrange(7):
|
||||
values = [5 * x - 12 for x in range(n)]
|
||||
for r in staticrange(n + 2):
|
||||
result = list(combinations_with_replacement(values, r))
|
||||
regular_combs = list(combinations(values, r))
|
||||
|
||||
assert len(result) == numcombs(n, r)
|
||||
|
@ -594,18 +746,19 @@ test_combinations_with_replacement_from_cpython()
|
|||
|
||||
@test
|
||||
def test_permutations_from_cpython():
|
||||
f = lambda x: x # hack to get non-static argument
|
||||
from math import factorial as fact
|
||||
|
||||
err = False
|
||||
try:
|
||||
list(permutations("abc", -2))
|
||||
list(permutations("abc", f(-2)))
|
||||
assert False
|
||||
except ValueError:
|
||||
err = True
|
||||
assert err
|
||||
|
||||
assert list(permutations("abc", 32)) == []
|
||||
assert list(permutations(range(3), 2)) == [
|
||||
assert list(permutations("abc", f(32))) == []
|
||||
assert list(permutations(range(3), f(2))) == [
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[1, 0],
|
||||
|
@ -632,6 +785,33 @@ def test_permutations_from_cpython():
|
|||
assert result == list(permutations(values, None)) # test r as None
|
||||
assert result == list(permutations(values)) # test default r
|
||||
|
||||
assert list(permutations("abc", 32)) == []
|
||||
assert list(permutations(range(3), 2)) == [
|
||||
(0, 1),
|
||||
(0, 2),
|
||||
(1, 0),
|
||||
(1, 2),
|
||||
(2, 0),
|
||||
(2, 1),
|
||||
]
|
||||
|
||||
for n in staticrange(7):
|
||||
values = [5 * x - 12 for x in range(n)]
|
||||
for r in staticrange(n + 2):
|
||||
result = list(permutations(values, r))
|
||||
assert len(result) == (
|
||||
0 if r > n else fact(n) // fact(n - r)
|
||||
) # right number of perms
|
||||
assert len(result) == len(set(result)) # no repeats
|
||||
# assert result == sorted(result) # lexicographic order
|
||||
for p in result:
|
||||
assert len(p) == r # r-length permutations
|
||||
assert len(set(p)) == r # no duplicate elements
|
||||
assert all(e in values for e in p) # elements taken from input iterable
|
||||
|
||||
if r == n:
|
||||
assert result == list(permutations(values, r))
|
||||
|
||||
|
||||
test_permutations_from_cpython()
|
||||
|
||||
|
@ -728,6 +908,49 @@ def test_combinatorics_from_cpython():
|
|||
) # comb: cwr that is a perm
|
||||
assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm
|
||||
|
||||
for n in staticrange(6):
|
||||
s = "ABCDEFG"[:n]
|
||||
for r in staticrange(8):
|
||||
prod = list(product(s, repeat=r))
|
||||
cwr = list(combinations_with_replacement(s, r))
|
||||
perm = list(permutations(s, r))
|
||||
comb = list(combinations(s, r))
|
||||
|
||||
# Check size
|
||||
assert len(prod) == n ** r
|
||||
assert len(cwr) == (
|
||||
(fact(n + r - 1) // fact(r) // fact(n - 1)) if n else (0 if r else 1)
|
||||
)
|
||||
assert len(perm) == (0 if r > n else fact(n) // fact(n - r))
|
||||
assert len(comb) == (0 if r > n else fact(n) // fact(r) // fact(n - r))
|
||||
|
||||
# Check lexicographic order without repeated tuples
|
||||
assert prod == sorted(set(prod))
|
||||
assert cwr == sorted(set(cwr))
|
||||
assert perm == sorted(set(perm))
|
||||
assert comb == sorted(set(comb))
|
||||
|
||||
# Check interrelationships
|
||||
assert cwr == [
|
||||
t for t in prod if sorted(t) == list(t)
|
||||
] # cwr: prods which are sorted
|
||||
assert perm == [
|
||||
t for t in prod if len(set(t)) == r
|
||||
] # perm: prods with no dups
|
||||
assert comb == [
|
||||
t for t in perm if sorted(t) == list(t)
|
||||
] # comb: perms that are sorted
|
||||
assert comb == [
|
||||
t for t in cwr if len(set(t)) == r
|
||||
] # comb: cwrs without dups
|
||||
assert comb == list(
|
||||
filter(set(cwr).__contains__, perm)
|
||||
) # comb: perm that is a cwr
|
||||
assert comb == list(
|
||||
filter(set(perm).__contains__, cwr)
|
||||
) # comb: cwr that is a perm
|
||||
assert comb == sorted(set(cwr) & set(perm)) # comb: both a cwr and a perm
|
||||
|
||||
|
||||
test_combinatorics_from_cpython()
|
||||
|
||||
|
|
Loading…
Reference in New Issue