From 18d1d9b51e55cf3e21a03940591a0ee9c854b3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ibrahim=20Numanagi=C4=87?= Date: Thu, 30 Dec 2021 01:46:15 +0100 Subject: [PATCH] Seq backports (#8) * Backport seq-lang/seq@develop fixes * Backport seq-lang/seq@develop fixes * Resolve review issues * Resolve File.__copy__() issue * Resolve incorrect partial handling of object methods * Use ints and floats for __suffix__ methods * Fix float test * Update complex tests * Fix float constructor * Fix cmath test Co-authored-by: A. R. Shajii --- codon/parser/ast/expr.cpp | 2 +- codon/parser/ast/types.cpp | 12 ++- codon/parser/peg/grammar.peg | 5 +- codon/parser/visitors/simplify/simplify.h | 4 +- .../visitors/simplify/simplify_expr.cpp | 28 ++--- .../visitors/simplify/simplify_stmt.cpp | 4 +- .../visitors/typecheck/typecheck_ctx.cpp | 3 +- .../parser/visitors/typecheck/typecheck_ctx.h | 5 + .../visitors/typecheck/typecheck_expr.cpp | 72 ++++++++----- .../visitors/typecheck/typecheck_infer.cpp | 6 ++ .../visitors/typecheck/typecheck_stmt.cpp | 20 +++- docs/sphinx/primer.rst | 7 +- stdlib/algorithms/pdqsort.codon | 12 +-- stdlib/algorithms/qsort.codon | 12 ++- stdlib/collections.codon | 3 + stdlib/heapq.codon | 20 ++-- stdlib/internal/file.codon | 4 +- stdlib/internal/gc.codon | 9 +- stdlib/internal/python.codon | 10 +- stdlib/internal/str.codon | 6 -- stdlib/internal/types/array.codon | 7 ++ stdlib/internal/types/bool.codon | 2 + stdlib/internal/types/byte.codon | 2 + stdlib/internal/types/collections/dict.codon | 3 + stdlib/internal/types/collections/list.codon | 3 + stdlib/internal/types/collections/set.codon | 3 + stdlib/internal/types/complex.codon | 10 ++ stdlib/internal/types/float.codon | 2 + stdlib/internal/types/int.codon | 2 + stdlib/internal/types/intn.codon | 4 + stdlib/internal/types/str.codon | 7 ++ stdlib/itertools.codon | 5 +- test/parser/simplify_expr.codon | 15 ++- test/parser/simplify_stmt.codon | 20 +++- test/parser/typecheck_expr.codon | 1 + test/parser/typecheck_stmt.codon | 30 ++++++ test/parser/types.codon | 35 +++++- test/stdlib/cmath_test.codon | 101 +++++++++--------- 38 files changed, 335 insertions(+), 161 deletions(-) diff --git a/codon/parser/ast/expr.cpp b/codon/parser/ast/expr.cpp index 19f2a48c..ec5768a3 100644 --- a/codon/parser/ast/expr.cpp +++ b/codon/parser/ast/expr.cpp @@ -100,7 +100,7 @@ std::string IntExpr::toString() const { ACCEPT_IMPL(IntExpr, ASTVisitor); FloatExpr::FloatExpr(double floatValue) - : Expr(), value(std::to_string(floatValue)), floatValue(floatValue) {} + : Expr(), value(fmt::format("{:g}", floatValue)), floatValue(floatValue) {} FloatExpr::FloatExpr(const std::string &value, std::string suffix) : Expr(), value(value), suffix(std::move(suffix)), floatValue(0.0) {} std::string FloatExpr::toString() const { diff --git a/codon/parser/ast/types.cpp b/codon/parser/ast/types.cpp index e932851b..a08abed4 100644 --- a/codon/parser/ast/types.cpp +++ b/codon/parser/ast/types.cpp @@ -562,10 +562,12 @@ std::string PartialType::debugString(bool debug) const { std::vector as; int i, gi; for (i = 0, gi = 0; i < known.size(); i++) - if (!known[i]) - as.emplace_back("..."); - else - as.emplace_back(gs[gi++]); + if (!func->ast->args[i].generic) { + if (!known[i]) + as.emplace_back("..."); + else + as.emplace_back(gs[gi++]); + } return fmt::format("{}[{}]", !debug ? func->ast->name : func->debugString(debug), join(as, ",")); } @@ -756,7 +758,7 @@ int CallableTrait::unify(Type *typ, Unification *us) { if (args[0]->unify(pt->func->args[0].get(), us) == -1) return -1; for (int pi = 0, gi = 1; pi < pt->known.size(); pi++) - if (!pt->known[pi]) + if (!pt->known[pi] && !pt->func->ast->args[pi].generic) if (args[gi++]->unify(pt->func->args[pi + 1].get(), us) == -1) return -1; return 1; diff --git a/codon/parser/peg/grammar.peg b/codon/parser/peg/grammar.peg index 0f8cf6e2..ad292a2b 100644 --- a/codon/parser/peg/grammar.peg +++ b/codon/parser/peg/grammar.peg @@ -555,9 +555,8 @@ atom <- ast(LOC, ac(V0)), ast(LOC, ac(V1)) ); } - / FLOAT { - // Right now suffixes are _not_ supported - return ast(LOC, ac(V0), ""); + / FLOAT NAME? { + return ast(LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); } / INT NAME? { return ast(LOC, ac(V0), VS.size() > 1 ? ac(V1) : ""); diff --git a/codon/parser/visitors/simplify/simplify.h b/codon/parser/visitors/simplify/simplify.h index 7060ad46..df4a8317 100644 --- a/codon/parser/visitors/simplify/simplify.h +++ b/codon/parser/visitors/simplify/simplify.h @@ -343,8 +343,8 @@ private: /// (XXXuN and XXXiN), and other suffix integers to a corresponding integer value or a /// constructor: /// 123u -> UInt[64](123) - /// 123i56 -> Int[56]("123") (same for UInt) - /// 123pf -> int.__suffix_pf__("123") + /// 123i56 -> Int[56](123) (same for UInt) + /// 123pf -> int.__suffix_pf__(123) ExprPtr transformInt(const std::string &value, const std::string &suffix); /// Converts a float string to double. ExprPtr transformFloat(const std::string &value, const std::string &suffix); diff --git a/codon/parser/visitors/simplify/simplify_expr.cpp b/codon/parser/visitors/simplify/simplify_expr.cpp index a7242651..053e384b 100644 --- a/codon/parser/visitors/simplify/simplify_expr.cpp +++ b/codon/parser/visitors/simplify/simplify_expr.cpp @@ -691,48 +691,48 @@ ExprPtr SimplifyVisitor::transformInt(const std::string &value, return std::stoull(s.substr(2), nullptr, 2); return std::stoull(s, nullptr, 0); }; + int64_t val; try { - if (suffix.empty()) { - auto expr = N(to_int(value)); - return expr; - } + val = to_int(value); + if (suffix.empty()) + return N(val); /// Unsigned numbers: use UInt[64] for that if (suffix == "u") return transform(N(N(N("UInt"), N(64)), - N(to_int(value)))); + N(val))); /// Fixed-precision numbers (uXXX and iXXX) /// NOTE: you cannot use binary (0bXXX) format with those numbers. /// TODO: implement non-string constructor for these cases. if (suffix[0] == 'u' && isdigit(suffix.substr(1))) return transform(N( N(N("UInt"), N(std::stoi(suffix.substr(1)))), - N(value))); + N(val))); if (suffix[0] == 'i' && isdigit(suffix.substr(1))) return transform(N( N(N("Int"), N(std::stoi(suffix.substr(1)))), - N(value))); + N(val))); } catch (std::out_of_range &) { error("integer {} out of range", value); } /// Custom suffix sfx: use int.__suffix_sfx__(str) call. /// NOTE: you cannot neither use binary (0bXXX) format here. return transform(N(N("int", format("__suffix_{}__", suffix)), - N(value))); + N(val))); } ExprPtr SimplifyVisitor::transformFloat(const std::string &value, const std::string &suffix) { + double val; try { - if (suffix.empty()) { - auto expr = N(std::stod(value)); - return expr; - } + val = std::stod(value); } catch (std::out_of_range &) { - error("integer {} out of range", value); + error("float {} out of range", value); } + if (suffix.empty()) + return N(val); /// Custom suffix sfx: use float.__suffix_sfx__(str) call. return transform(N(N("float", format("__suffix_{}__", suffix)), - N(value))); + N(value))); } ExprPtr SimplifyVisitor::transformFString(std::string value) { diff --git a/codon/parser/visitors/simplify/simplify_stmt.cpp b/codon/parser/visitors/simplify/simplify_stmt.cpp index 53edd73c..a3158edc 100644 --- a/codon/parser/visitors/simplify/simplify_stmt.cpp +++ b/codon/parser/visitors/simplify/simplify_stmt.cpp @@ -786,8 +786,8 @@ void SimplifyVisitor::visit(ClassStmt *stmt) { if (bcName.empty() || !in(ctx->cache->classes, bcName)) error(baseClass.get(), "invalid base class"); baseASTs.push_back(ctx->cache->classes[bcName].ast.get()); - if (baseASTs.back()->attributes.has(Attr::Tuple) != isRecord) - error("tuples cannot inherit reference classes (and vice versa)"); + if (!isRecord && baseASTs.back()->attributes.has(Attr::Tuple)) + error("reference classes cannot inherit by-value classes"); if (baseASTs.back()->attributes.has(Attr::Internal)) error("cannot inherit internal types"); int si = 0; diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.cpp b/codon/parser/visitors/typecheck/typecheck_ctx.cpp index ef2c0971..31297307 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.cpp +++ b/codon/parser/visitors/typecheck/typecheck_ctx.cpp @@ -16,7 +16,8 @@ namespace ast { TypeContext::TypeContext(Cache *cache) : Context(""), cache(move(cache)), typecheckLevel(0), - allowActivation(true), age(0), realizationDepth(0) { + allowActivation(true), age(0), realizationDepth(0), blockLevel(0), + returnEarly(false) { stack.push_front(std::vector()); bases.push_back({"", nullptr, nullptr}); } diff --git a/codon/parser/visitors/typecheck/typecheck_ctx.h b/codon/parser/visitors/typecheck/typecheck_ctx.h index b0f6a597..f6852c5c 100644 --- a/codon/parser/visitors/typecheck/typecheck_ctx.h +++ b/codon/parser/visitors/typecheck/typecheck_ctx.h @@ -67,6 +67,11 @@ struct TypeContext : public Context { /// (e.g. class A: def __init__(a: A = A())). std::set defaultCallDepth; + /// Number of nested blocks (0 for toplevel) + int blockLevel; + /// True if an early return is found (anything afterwards won't be typechecked) + bool returnEarly; + public: explicit TypeContext(Cache *cache); diff --git a/codon/parser/visitors/typecheck/typecheck_expr.cpp b/codon/parser/visitors/typecheck/typecheck_expr.cpp index a0c46863..25cae4c6 100644 --- a/codon/parser/visitors/typecheck/typecheck_expr.cpp +++ b/codon/parser/visitors/typecheck/typecheck_expr.cpp @@ -709,10 +709,20 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic, if (method) swap(expr->lexpr, expr->rexpr); } - if (!method) - error("cannot find magic '{}' in {}", magic, lt->toString()); - - return transform(N(N(method->ast->name), expr->lexpr, expr->rexpr)); + if (method) { + return transform( + N(N(method->ast->name), expr->lexpr, expr->rexpr)); + } else if (lt->is("pyobj")) { + return transform(N(N(N(expr->lexpr, "_getattr"), + N(format("__{}__", magic))), + expr->rexpr)); + } else if (rt->is("pyobj")) { + return transform(N(N(N(expr->rexpr, "_getattr"), + N(format("__r{}__", magic))), + expr->lexpr)); + } + error("cannot find magic '{}' in {}", magic, lt->toString()); + return nullptr; } ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &expr, @@ -932,13 +942,13 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr, // Object access: y.method. Transform y.method to a partial call // typeof(t).foo(y, ...). std::vector methodArgs{expr->expr}; - for (int i = 0; i < std::max(1, (int)bestMethod->args.size() - 2); i++) - methodArgs.push_back(N()); + methodArgs.push_back(N()); // Handle @property methods. if (bestMethod->ast->attributes.has(Attr::Property)) methodArgs.pop_back(); ExprPtr e = N(N(bestMethod->ast->name), methodArgs); - return transform(e, false, allowVoidExpr); + ExprPtr r = transform(e, false, allowVoidExpr); + return r; } } @@ -1046,8 +1056,11 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in // c = t.__new__(); c.__init__(args); c ExprPtr var = N(ctx->cache->getTemporaryVar("v")); return transform(N( - N(clone(var), N(N(expr->expr, "__new__"))), - N(N(N(clone(var), "__init__"), expr->args)), + N( + N(clone(var), N(N(expr->expr, "__new__"))), + N(N(N("std.internal.gc.register_finalizer"), + clone(var))), + N(N(N(clone(var), "__init__"), expr->args))), clone(var))); } } else if (auto pc = callee->getPartial()) { @@ -1055,8 +1068,13 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in expr->expr = transform(N(N(clone(var), expr->expr), N(pc->func->ast->name))); calleeFn = expr->expr->type->getFunc(); - for (int i = 0, j = 0; i < calleeFn->ast->args.size(); i++) - known.push_back(calleeFn->ast->args[i].generic ? 0 : pc->known[j++]); + for (int i = 0, j = 0; i < pc->known.size(); i++) + if (pc->func->ast->args[i].generic) { + if (pc->known[i]) + unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type); + j++; + } + known = pc->known; seqassert(calleeFn, "not a function: {}", expr->expr->type->toString()); } else if (!callee->getFunc()) { // Case 3: callee is not a named function. Route it through a __call__ method. @@ -1070,6 +1088,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in int typeArgCount = 0; bool isPartial = false; int ellipsisStage = -1; + auto newMask = std::vector(calleeFn->ast->args.size(), 1); if (expr->ordered) args = expr->args; else @@ -1085,6 +1104,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in typeArgs.push_back(slots[si].empty() ? nullptr : expr->args[slots[si][0]].value); typeArgCount += typeArgs.back() != nullptr; + newMask[si] = slots[si].empty() ? 0 : 1; } else if (si == starArgIndex && !(partial && slots[si].empty())) { std::vector extra; for (auto &e : slots[si]) { @@ -1115,6 +1135,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in args.push_back({"", ex}); } else if (partial) { args.push_back({"", transform(N())}); + newMask[si] = 0; } else { auto es = calleeFn->ast->args[si].deflt->toString(); if (in(ctx->defaultCallDepth, es)) @@ -1207,7 +1228,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in // Handle default generics (calleeFn.g. foo[S, T=int]) only if all arguments were // unified. // TODO: remove once the proper partial handling of overloaded functions land - if (unificationsDone) + if (unificationsDone) { for (int i = 0, j = 0; i < calleeFn->ast->args.size(); i++) if (calleeFn->ast->args[i].generic) { if (calleeFn->ast->args[i].deflt && @@ -1222,6 +1243,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in } j++; } + } for (int si = 0; si < replacements.size(); si++) if (replacements[si]) { if (replacements[si]->getFunc()) @@ -1238,13 +1260,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in expr->done &= expr->expr->done; // Emit the final call. - std::vector newMask; - if (isPartial) - newMask = std::vector(calleeFn->args.size() - 1, 1); - for (int si = 0; si < calleeFn->args.size() - 1; si++) - if (args[si].value->getEllipsis() && !args[si].value->getEllipsis()->isPipeArg) - newMask[si] = 0; - if (!newMask.empty()) { + if (isPartial) { // Case 1: partial call. // Transform calleeFn(args...) to Partial.N.(args...). auto partialTypeName = generatePartialStub(newMask, calleeFn->getFunc().get()); @@ -1507,14 +1523,15 @@ std::string TypecheckVisitor::generateFunctionStub(int n) { std::string TypecheckVisitor::generatePartialStub(const std::vector &mask, types::FuncType *fn) { std::string strMask(mask.size(), '1'); + int tupleSize = 0; for (int i = 0; i < mask.size(); i++) if (!mask[i]) strMask[i] = '0'; - auto typeName = format(TYPE_PARTIAL "{}", strMask); - if (!ctx->find(typeName)) { - auto tupleSize = std::count_if(mask.begin(), mask.end(), [](char c) { return c; }); + else if (!fn->ast->args[i].generic) + tupleSize++; + auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name); + if (!ctx->find(typeName)) generateTupleStub(tupleSize, typeName, {}, false); - } return typeName; } @@ -1569,7 +1586,14 @@ void TypecheckVisitor::generateFnCall(int n) { ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) { auto fn = expr->getType()->getFunc(); seqassert(fn, "not a function: {}", expr->getType()->toString()); - std::vector mask(fn->args.size() - 1, 0); + std::vector mask(fn->ast->args.size(), 0); + for (int i = 0, j = 0; i < fn->ast->args.size(); i++) + if (fn->ast->args[i].generic) { + // TODO: better detection of user-provided args...? + if (!fn->funcGenerics[j].type->getUnbound()) + mask[i] = 1; + j++; + } auto partialTypeName = generatePartialStub(mask, fn.get()); deactivateUnbounds(fn.get()); std::string var = ctx->cache->getTemporaryVar("partial"); diff --git a/codon/parser/visitors/typecheck/typecheck_infer.cpp b/codon/parser/visitors/typecheck/typecheck_infer.cpp index 329b51e2..2acb96ef 100644 --- a/codon/parser/visitors/typecheck/typecheck_infer.cpp +++ b/codon/parser/visitors/typecheck/typecheck_infer.cpp @@ -188,7 +188,13 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { StmtPtr realized = nullptr; if (!isInternal) { + auto oldBlockLevel = ctx->blockLevel; + auto oldReturnEarly = ctx->returnEarly; + ctx->blockLevel = 0; + ctx->returnEarly = false; realized = inferTypes(ast->suite, false, type->realizedName()).second; + ctx->blockLevel = oldBlockLevel; + ctx->returnEarly = oldReturnEarly; if (ast->attributes.has(Attr::LLVM)) { auto s = realized->getSuite(); for (int i = 1; i < s->stmts.size(); i++) { diff --git a/codon/parser/visitors/typecheck/typecheck_stmt.cpp b/codon/parser/visitors/typecheck/typecheck_stmt.cpp index 307377b4..432857c0 100644 --- a/codon/parser/visitors/typecheck/typecheck_stmt.cpp +++ b/codon/parser/visitors/typecheck/typecheck_stmt.cpp @@ -49,11 +49,16 @@ void TypecheckVisitor::defaultVisit(Stmt *s) { void TypecheckVisitor::visit(SuiteStmt *stmt) { std::vector stmts; stmt->done = true; - for (auto &s : stmt->stmts) + ctx->blockLevel += int(stmt->ownBlock); + for (auto &s : stmt->stmts) { + if (ctx->returnEarly) + break; if (auto t = transform(s)) { stmts.push_back(t); stmt->done &= stmts.back()->done; } + } + ctx->blockLevel -= int(stmt->ownBlock); stmt->stmts = stmts; } @@ -205,16 +210,18 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) { stmt->expr = transform(stmt->expr); if (stmt->expr) { auto &base = ctx->bases.back(); - wrapExpr(stmt->expr, base.returnType, nullptr); + if (!base.returnType->getUnbound()) + wrapExpr(stmt->expr, base.returnType, nullptr); if (stmt->expr->getType()->getFunc() && !(base.returnType->getClass() && startswith(base.returnType->getClass()->name, TYPE_FUNCTION))) stmt->expr = partializeFunction(stmt->expr); unify(base.returnType, stmt->expr->type); - auto retTyp = stmt->expr->getType()->getClass(); stmt->done = stmt->expr->done; } else { + if (ctx->blockLevel == 1) + ctx->returnEarly = true; stmt->done = true; } } @@ -307,7 +314,12 @@ void TypecheckVisitor::visit(IfStmt *stmt) { isTrue = !stmt->cond->staticValue.getString().empty(); else isTrue = stmt->cond->staticValue.getInt(); - resultStmt = transform(isTrue ? stmt->ifSuite : stmt->elseSuite); + resultStmt = isTrue ? stmt->ifSuite : stmt->elseSuite; + bool isOwn = // these blocks will not be a real owning blocks after inlining + resultStmt && resultStmt->getSuite() && resultStmt->getSuite()->ownBlock; + ctx->blockLevel -= isOwn; + resultStmt = transform(resultStmt); + ctx->blockLevel += isOwn; if (!resultStmt) resultStmt = transform(N()); return; diff --git a/docs/sphinx/primer.rst b/docs/sphinx/primer.rst index 80af9e80..13efa016 100644 --- a/docs/sphinx/primer.rst +++ b/docs/sphinx/primer.rst @@ -811,7 +811,12 @@ You can also add methods to types: Type extensions ~~~~~~~~~~~~~~~ -Suppose you have a class that lacks a method or an operator that might be really useful. Codon provides an ``@extend`` annotation that allows you to add and modify methods of various types at compile time, including built-in types like ``int`` or ``str``. This allows much of the functionality of built-in types to be implemented in Codon as type extensions in the standard library. +Suppose you have a class that lacks a method or an operator that might be really useful. + +Codon provides an ``@extend`` annotation that allows programmers to add and modify +methods of various types at compile time, including built-in types like ``int`` or ``str``. +This actually allows much of the functionality of built-in types to be implemented in +Codon as type extensions in the standard library. .. code:: python diff --git a/stdlib/algorithms/pdqsort.codon b/stdlib/algorithms/pdqsort.codon index 9607893e..0c3c5756 100644 --- a/stdlib/algorithms/pdqsort.codon +++ b/stdlib/algorithms/pdqsort.codon @@ -35,7 +35,7 @@ def _partial_insertion_sort[S,T](arr: Array[T], begin: int, end: int, keyf: Call arr[sift] = arr[sift_1] sift -= 1 sift_1 -= 1 - if sift == begin or keyf(tmp) >= keyf(arr[sift_1]): + if sift == begin or not keyf(tmp) < keyf(arr[sift_1]): break arr[sift] = tmp @@ -52,7 +52,7 @@ def _partition_left[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T] while True: last -= 1 - if keyf(pivot) >= keyf(arr[last]): + if not keyf(pivot) < keyf(arr[last]): break if (last + 1 == end): @@ -71,7 +71,7 @@ def _partition_left[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T] arr[first], arr[last] = arr[last], arr[first] while True: last -= 1 - if keyf(pivot) >= keyf(arr[last]): + if not keyf(pivot) < keyf(arr[last]): break while True: first += 1 @@ -91,7 +91,7 @@ def _partition_right[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T while True: first += 1 - if keyf(arr[first]) >= keyf(pivot): + if not keyf(arr[first]) < keyf(pivot): break if first - 1 == begin: @@ -115,7 +115,7 @@ def _partition_right[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T while True: first += 1 - if keyf(arr[first]) >= keyf(pivot): + if not keyf(arr[first]) < keyf(pivot): break while True: @@ -155,7 +155,7 @@ def _pdq_sort[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T], S], else: _sort3(arr, begin + size_2, begin, end - 1, keyf) - if not leftmost and keyf(arr[begin - 1]) >= keyf(arr[begin]): + if not leftmost and not keyf(arr[begin - 1]) < keyf(arr[begin]): begin = _partition_left(arr, begin, end, keyf) + 1 continue diff --git a/stdlib/algorithms/qsort.codon b/stdlib/algorithms/qsort.codon index 26c0a0b2..69ccf57f 100644 --- a/stdlib/algorithms/qsort.codon +++ b/stdlib/algorithms/qsort.codon @@ -1,6 +1,8 @@ def _med3[S,T](a: int, b: int, c: int, d: Array[T], k: Callable[[T], S]): - return ((b if (k(d[b]) < k(d[c])) else (c if k(d[a]) < k(d[c]) else a)) - if (k(d[a]) < k(d[b])) else (b if (k(d[b]) > k(d[c])) else (c if k(d[a]) > k(d[c]) else a))) + if k(d[a]) < k(d[b]): + return b if (k(d[b]) < k(d[c])) else (c if k(d[a]) < k(d[c]) else a) + else: + return b if not (k(d[b]) < k(d[c]) or k(d[b]) == k(d[c])) else (c if not (k(d[a]) < k(d[c]) or k(d[a]) == k(d[c])) else a) def _swap[T](i: int, j: int, a: Array[T]): a[i], a[j] = a[j], a[i] @@ -17,7 +19,7 @@ def _qsort[S,T](arr: Array[T], frm: int, cnt: int, keyf: Callable[[T], S]): i = frm + 1 while i < frm + cnt: j = i - while j > frm and keyf(arr[j - 1]) > keyf(arr[j]): + while j > frm and not (keyf(arr[j - 1]) < keyf(arr[j]) or keyf(arr[j - 1]) == keyf(arr[j])): _swap(j, j - 1, arr) j -= 1 i += 1 @@ -41,13 +43,13 @@ def _qsort[S,T](arr: Array[T], frm: int, cnt: int, keyf: Callable[[T], S]): d = c while True: - while b <= c and keyf(arr[b]) <= keyf(arr[frm]): + while b <= c and (keyf(arr[b]) < keyf(arr[frm]) or keyf(arr[b]) == keyf(arr[frm])): if keyf(arr[b]) == keyf(arr[frm]): _swap(a, b, arr) a += 1 b += 1 - while c >= b and keyf(arr[c]) >= keyf(arr[frm]): + while c >= b and not keyf(arr[c]) < keyf(arr[frm]): if keyf(arr[c]) == keyf(arr[frm]): _swap(c, d, arr) d -= 1 diff --git a/stdlib/collections.codon b/stdlib/collections.codon index 50097155..fc18a627 100644 --- a/stdlib/collections.codon +++ b/stdlib/collections.codon @@ -104,6 +104,9 @@ class deque[T]: return True return False + def __deepcopy__(self): + return deque(i.__deepcopy__() for i in self) + def __copy__(self): return deque[T](copy(self._arr), self._head, self._tail, self._maxlen) diff --git a/stdlib/heapq.codon b/stdlib/heapq.codon index 25e40d9b..6d6216d0 100644 --- a/stdlib/heapq.codon +++ b/stdlib/heapq.codon @@ -148,7 +148,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()): Equivalent to: sorted(iterable, key=key)[:n] """ if n == 1: - v = List[T](1) + v = List(1) for a in iterable: if not v: v.append(a) @@ -165,7 +165,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()): it = iter(iterable) # put the range(n) first so that zip() doesn't # consume one too many elements from the iterator - result = List[Tuple[T,int]](n) + result = List(n) done = False for i in range(n): if it.done(): @@ -174,7 +174,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()): result.append((it.next(), i)) if not result: it.destroy() - return List[T](0) + return [] _heapify_max(result) top = result[0][0] order = n @@ -191,7 +191,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()): else: # General case, slowest method it = iter(iterable) - result = List[Tuple[type(key(T())),int,T]](n) + result = List(n) done = False for i in range(n): if it.done(): @@ -201,7 +201,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()): result.append((key(elem), i, elem)) if not result: it.destroy() - return List[T](0) + return [] _heapify_max(result) top = result[0][0] order = n @@ -222,7 +222,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()): Equivalent to: sorted(iterable, key=key, reverse=True)[:n] """ if n == 1: - v = List[T](1) + v = List(1) for a in iterable: if not v: v.append(a) @@ -237,7 +237,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()): # When key is none, use simpler decoration if isinstance(key, Optional): it = iter(iterable) - result = List[Tuple[T,int]](n) + result = List(n) done = False for i in range(0, -n, -1): if it.done(): @@ -246,7 +246,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()): result.append((it.next(), i)) if not result: it.destroy() - return List[T](0) + return [] heapify(result) top = result[0][0] order = -n @@ -263,7 +263,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()): else: # General case, slowest method it = iter(iterable) - result = List[Tuple[type(key(T())),int,T]](n) + result = List(n) done = False for i in range(0, -n, -1): if it.done(): @@ -272,7 +272,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()): elem = it.next() result.append((key(elem), i, elem)) if not result: - return List[T](0) + return [] heapify(result) top = result[0][0] order = -n diff --git a/stdlib/internal/file.codon b/stdlib/internal/file.codon index 8f319d62..5e792494 100644 --- a/stdlib/internal/file.codon +++ b/stdlib/internal/file.codon @@ -27,7 +27,7 @@ class File: def __iter__(self): for a in self._iter(): - yield copy(a) + yield a.__ptrcopy__() def readlines(self): return [l for l in self] @@ -142,7 +142,7 @@ class gzFile: def __iter__(self): for a in self._iter(): - yield copy(a) + yield a.__ptrcopy__() def __enter__(self): pass diff --git a/stdlib/internal/gc.codon b/stdlib/internal/gc.codon index ec2fe228..596f9210 100644 --- a/stdlib/internal/gc.codon +++ b/stdlib/internal/gc.codon @@ -35,9 +35,6 @@ def realloc(p: cobj, sz: int): def free(p: cobj): seq_free(p) -def register_finalizer(p: cobj, f: Function[[cobj, cobj], void]): - seq_register_finalizer(p, f.__raw__()) - def add_roots(start: cobj, end: cobj): seq_gc_add_roots(start, end) @@ -49,3 +46,9 @@ def clear_roots(): def exclude_static_roots(start: cobj, end: cobj): seq_gc_exclude_static_roots(start, end) + +def register_finalizer(p): + if hasattr(p, '__del__'): + def f(x: cobj, data: cobj, T: type): + Ptr[T](__ptr__(x).as_byte())[0].__del__() + seq_register_finalizer(p.__raw__(), f(T=type(p), ...).__raw__()) diff --git a/stdlib/internal/python.codon b/stdlib/internal/python.codon index c5614649..3301f318 100644 --- a/stdlib/internal/python.codon +++ b/stdlib/internal/python.codon @@ -132,12 +132,6 @@ class pyobj: def _getattr(self, name: str): return pyobj.exc_wrap(pyobj(PyObject_GetAttrString(self.p, name.c_str()))) - def __getitem__(self, t): - return self._getattr("__getitem__")(t) - - def __add__(self, t): - return self._getattr("__add__")(t) - def __setitem__(self, name: str, val: pyobj): return pyobj.exc_wrap(pyobj(PyObject_SetAttrString(self.p, name.c_str(), val.p))) @@ -232,7 +226,7 @@ class pyobj: ensure_initialized() PyRun_SimpleString(code.c_str()) - def get[T](self) -> T: + def get(self, T: type) -> T: return T.__from_py__(self) def none(): @@ -244,7 +238,7 @@ def none(): def py(x) -> pyobj: return x.__to_py__() -def get[T](x: pyobj) -> T: +def get(x: pyobj, T: type) -> T: return T.__from_py__(x) @extend diff --git a/stdlib/internal/str.codon b/stdlib/internal/str.codon index ffd7e17a..790f1b39 100644 --- a/stdlib/internal/str.codon +++ b/stdlib/internal/str.codon @@ -106,12 +106,6 @@ class str: n += self.len return str(p, total) - def __copy__(self): - n = len(self) - p = cobj(n) - str.memcpy(p, self.ptr, n) - return str(p, n) - def _cmp(self, other: str): n = min(self.len, other.len) i = 0 diff --git a/stdlib/internal/types/array.codon b/stdlib/internal/types/array.codon index e052f5ce..834db41b 100644 --- a/stdlib/internal/types/array.codon +++ b/stdlib/internal/types/array.codon @@ -10,6 +10,13 @@ class Array: p = Ptr[T](self.len) str.memcpy(p.as_byte(), self.ptr.as_byte(), self.len * sizeof(T)) return (self.len, p) + def __deepcopy__(self) -> Array[T]: + p = Ptr[T](self.len) + i = 0 + while i < self.len: + p[i] = self.ptr[i].__deepcopy__() + i += 1 + return (self.len, p) def __len__(self) -> int: return self.len def __bool__(self) -> bool: diff --git a/stdlib/internal/types/bool.codon b/stdlib/internal/types/bool.codon index 81ba8c5f..de66f660 100644 --- a/stdlib/internal/types/bool.codon +++ b/stdlib/internal/types/bool.codon @@ -10,6 +10,8 @@ class bool: return "True" if self else "False" def __copy__(self) -> bool: return self + def __deepcopy__(self) -> bool: + return self def __bool__(self) -> bool: return self def __hash__(self): diff --git a/stdlib/internal/types/byte.codon b/stdlib/internal/types/byte.codon index fa22afc9..764c056b 100644 --- a/stdlib/internal/types/byte.codon +++ b/stdlib/internal/types/byte.codon @@ -21,6 +21,8 @@ class byte: ret i8 %0 def __copy__(self) -> byte: return self + def __deepcopy__(self) -> byte: + return self @pure @llvm def __bool__(self) -> bool: diff --git a/stdlib/internal/types/collections/dict.codon b/stdlib/internal/types/collections/dict.codon index fc7a177b..6a742fb0 100644 --- a/stdlib/internal/types/collections/dict.codon +++ b/stdlib/internal/types/collections/dict.codon @@ -113,6 +113,9 @@ class Dict[K,V]: str.memcpy(vals_copy.as_byte(), self._vals.as_byte(), n * gc.sizeof(V)) return Dict[K,V](n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy, vals_copy) + def __deepcopy__(self): + return {k.__deepcopy__(): v.__deepcopy__() for k, v in self.items()} + def __repr__(self): n = self.__len__() if n == 0: diff --git a/stdlib/internal/types/collections/list.codon b/stdlib/internal/types/collections/list.codon index d4ce27ba..dd490b10 100644 --- a/stdlib/internal/types/collections/list.codon +++ b/stdlib/internal/types/collections/list.codon @@ -173,6 +173,9 @@ class List: def __copy__(self): return List[T](self.arr.__copy__(), self.len) + def __deepcopy__(self): + return [l.__deepcopy__() for l in self] + def __iter__(self): i = 0 N = self.len diff --git a/stdlib/internal/types/collections/set.codon b/stdlib/internal/types/collections/set.codon index 8e26f965..c045fc78 100644 --- a/stdlib/internal/types/collections/set.codon +++ b/stdlib/internal/types/collections/set.codon @@ -119,6 +119,9 @@ class Set[K]: str.memcpy(keys_copy.as_byte(), self._keys.as_byte(), n * gc.sizeof(K)) return Set[K](n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy) + def __deepcopy__(self): + return {s.__deepcopy__() for s in self} + def __repr__(self): n = self.__len__() if n == 0: diff --git a/stdlib/internal/types/complex.codon b/stdlib/internal/types/complex.codon index ef4502b4..c3e3d8de 100644 --- a/stdlib/internal/types/complex.codon +++ b/stdlib/internal/types/complex.codon @@ -266,3 +266,13 @@ class complex: declare double @llvm.log.f64(double) %y = call double @llvm.log.f64(double %x) ret double %y + +@extend +class int: + def __suffix_j__(x: int): + return complex(0, x) + +@extend +class float: + def __suffix_j__(x: float): + return complex(0, x) diff --git a/stdlib/internal/types/float.codon b/stdlib/internal/types/float.codon index 572a6fbf..54b358e0 100644 --- a/stdlib/internal/types/float.codon +++ b/stdlib/internal/types/float.codon @@ -17,6 +17,8 @@ class float: return s if s != "-nan" else "nan" def __copy__(self) -> float: return self + def __deepcopy__(self) -> float: + return self @pure @llvm def __int__(self) -> int: diff --git a/stdlib/internal/types/int.codon b/stdlib/internal/types/int.codon index 42002ff8..5e7be7db 100644 --- a/stdlib/internal/types/int.codon +++ b/stdlib/internal/types/int.codon @@ -31,6 +31,8 @@ class int: return seq_str_int(self) def __copy__(self) -> int: return self + def __deepcopy__(self) -> int: + return self def __hash__(self) -> int: return self @pure diff --git a/stdlib/internal/types/intn.codon b/stdlib/internal/types/intn.codon index 94c9a3d3..4ed993da 100644 --- a/stdlib/internal/types/intn.codon +++ b/stdlib/internal/types/intn.codon @@ -50,6 +50,8 @@ class Int: return int(self) def __copy__(self) -> Int[N]: return self + def __deepcopy__(self) -> Int[N]: + return self def __hash__(self) -> int: return int(self) @pure @@ -255,6 +257,8 @@ class UInt: return int(self) def __copy__(self) -> UInt[N]: return self + def __deepcopy__(self) -> UInt[N]: + return self def __hash__(self) -> int: return int(self) @pure diff --git a/stdlib/internal/types/str.codon b/stdlib/internal/types/str.codon index d28dff66..4a9e638c 100644 --- a/stdlib/internal/types/str.codon +++ b/stdlib/internal/types/str.codon @@ -24,6 +24,13 @@ class str: return self.len != 0 def __copy__(self) -> str: return self + def __deepcopy__(self) -> str: + return self + def __ptrcopy__(self) -> str: + n = self.len + p = cobj(n) + str.memcpy(p, self.ptr, n) + return str(p, n) @llvm def memcpy(dest: Ptr[byte], src: Ptr[byte], len: int) -> void: declare void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile) diff --git a/stdlib/itertools.codon b/stdlib/itertools.codon index 9461c316..bc987669 100644 --- a/stdlib/itertools.codon +++ b/stdlib/itertools.codon @@ -283,8 +283,9 @@ def starmap(function, iterable): for args in iterable: yield function(*args) +# TODO: fix this once Optional[Callable] lands @inline -def groupby[T](iterable: Generator[T], key: Callable[[T], S] = NoneType, S: type = NoneType): +def groupby(iterable, key = Optional[int]()): """ Make an iterator that returns consecutive keys and groups from the iterable. """ @@ -292,7 +293,7 @@ def groupby[T](iterable: Generator[T], key: Callable[[T], S] = NoneType, S: type group = [] for currvalue in iterable: - k = currvalue if isinstance(key, NoneType) else key(currvalue) + k = currvalue if isinstance(key, Optional) else key(currvalue) if currkey is None: currkey = k if k != ~currkey: diff --git a/test/parser/simplify_expr.codon b/test/parser/simplify_expr.codon index 6be91e16..e2638d8b 100644 --- a/test/parser/simplify_expr.codon +++ b/test/parser/simplify_expr.codon @@ -27,8 +27,8 @@ print -1u7 #: 127 @extend class int: - def __suffix_test__(s: str): - return 'TEST: ' + s + def __suffix_test__(s): + return 'TEST: ' + str(s) print 123_456test #: TEST: 123456 #%% int_error,barebones @@ -39,8 +39,13 @@ print 5.15 #: 5.15 print 2e2 #: 200 print 2.e-2 #: 0.02 -#%% float_error,barebones -print 5.__str__() #! syntax error +#%% float_suffix,barebones +@extend +class float: + def __suffix_zoo__(x): + return str(x) + '_zoo' + +print 1.2e-1zoo #: 0.12_zoo #%% string,barebones print 'kthxbai', "kthxbai" #: kthxbai kthxbai @@ -269,7 +274,7 @@ f[:,0] #: ((start: None, stop: None, step: None), 0) Ptr[9.99] #! expected type or static parameters #%% index_error_b,barebones -Ptr['s'] #! cannot unify T and "s" +Ptr['s'] #! cannot unify T and 's' #%% index_error_static,barebones Ptr[1] #! cannot unify T and 1 diff --git a/test/parser/simplify_stmt.codon b/test/parser/simplify_stmt.codon index b43a7988..3fe932bf 100644 --- a/test/parser/simplify_stmt.codon +++ b/test/parser/simplify_stmt.codon @@ -851,6 +851,11 @@ print np.transpose(a) n = np.array([[1, 2], [3, 4]]) print n[0], n[0][0] + 1 #: [1 2] 2 +a = np.array([1,2,3]) +print(a + 1) #: [2 3 4] +print(a - 1) #: [0 1 2] +print(1 - a) #: [ 0 -1 -2] + #%% python_import_fn from python import re.split(str, str) -> List[str] as rs print rs(r'\W+', 'Words, words, words.') #: ['Words', 'words', 'words', ''] @@ -943,10 +948,23 @@ class defdict(Dict[str,float]): z = defdict() z[1.1] #! cannot unify float and str +#%% inherit_tuple,barebones +class Foo: + a: int + b: str + def __init__(self, a: int): + self.a, self.b = a, 'yoo' +@tuple +class FooTup(Foo): pass + +f = Foo(5) +print f.a, f.b #: 5 yoo +fp = FooTup(6, 's') +print fp #: (a: 6, b: 's') #%% inherit_class_err_1,barebones class defdict(Array[int]): - pass #! tuples cannot inherit reference classes (and vice versa) + pass #! reference classes cannot inherit by-value classes #%% inherit_class_err_2,barebones @tuple diff --git a/test/parser/typecheck_expr.codon b/test/parser/typecheck_expr.codon index beefbfeb..dedc61de 100644 --- a/test/parser/typecheck_expr.codon +++ b/test/parser/typecheck_expr.codon @@ -441,6 +441,7 @@ def foo(x, *args, **kwargs): print x, args, kwargs p = foo(...) p(1, z=5) #: 1 () (z: 5) +p('s', zh=65) #: s () (zh: 65) q = p(zh=43, ...) q(1) #: 1 () (zh: 43) r = q(5, 38, ...) diff --git a/test/parser/typecheck_stmt.codon b/test/parser/typecheck_stmt.codon index 3e8ab0b8..af6cbe3c 100644 --- a/test/parser/typecheck_stmt.codon +++ b/test/parser/typecheck_stmt.codon @@ -282,3 +282,33 @@ class int: yield self self -= 1 print list((5).run_lola_run()) #: [5, 4, 3, 2, 1] + + +#%% early_return,barebones +def foo(x): + print x-1 + return + print len(x) +foo(5) #: 4 + +def foo(x): + if isinstance(x, int): + print x+1 + return + print len(x) +foo(1) #: 2 +foo('s') #: 1 + +def foo(x, y: Static[int] = 5): + if y < 3: + if y > 1: + if isinstance(x, int): + print x+1 + return + if isinstance(x, int): + return + print len(x) +foo(1, 1) +foo(1, 2) #: 2 +foo(1) +foo('s') #: 1 diff --git a/test/parser/types.codon b/test/parser/types.codon index 53e14255..a90894cb 100644 --- a/test/parser/types.codon +++ b/test/parser/types.codon @@ -765,10 +765,7 @@ def bar[T](x): foo(bar, 1) #: 1 int #: bar[...] -foo(bar(T=int,...), 1) -#: 1 int -#: bar[...] -foo(bar(T=str,...), 's') +foo(bar(...), 's') #: s str #: bar[...] z = bar @@ -780,7 +777,16 @@ z(1, T=str) zz = bar(T=int,...) zz(1) #: 1 int -# zz('s') # TODO: zz = foo[int] is update stmt, messes up everything... :/ + +#%% forward_error,barebones +def foo(f, x): + f(x, type(x)) + print f.__class__ +def bar[T](x): + print x, T.__class__ +foo(bar(T=int,...), 1) +#! too many arguments for bar[T1,int] (expected maximum 2, got 2) +#! while realizing foo (arguments foo[bar[...],int]) #%% sort_partial def foo(x, y): @@ -1082,3 +1088,22 @@ def test(v: Optional[int]): print v.__class__ test(5) #: Optional[int] test(None) #: Optional[int] + +#%% methodcaller,barebones +def foo(): + def bar(a, b): + print 'bar', a, b + return bar +foo()(1, 2) #: bar 1 2 + +def methodcaller(foo: Static[str]): + def caller(foo: Static[str], obj, *args, **kwargs): + if isinstance(getattr(obj, foo)(*args, **kwargs), void): + getattr(obj, foo)(*args, **kwargs) + else: + return getattr(obj, foo)(*args, **kwargs) + return caller(foo=foo, ...) +v = [1] +methodcaller('append')(v, 42) +print v #: [1, 42] +print methodcaller('index')(v, 42) #: 1 diff --git a/test/stdlib/cmath_test.codon b/test/stdlib/cmath_test.codon index c46933b4..313cfd84 100644 --- a/test/stdlib/cmath_test.codon +++ b/test/stdlib/cmath_test.codon @@ -3,7 +3,6 @@ import cmath INF = float('inf') NAN = float('nan') -j = complex(0, 1) def float_identical(x, y): if math.isnan(x) or math.isnan(y): @@ -27,11 +26,11 @@ def complex_identical(x, y): ########### ZERO_DIVISION = ( - (1+1*j, 0+0*j), - (1+1*j, 0.0+0*j), - (1+1*j, 0+0*j), - (1.0+0*j, 0+0*j), - (1+0*j, 0+0*j), + (1+1j, 0+0j), + (1+1j, 0.0+0j), + (1+1j, 0+0j), + (1.0+0j, 0+0j), + (1+0j, 0+0j), ) def close_abs(x, y, eps=1e-9): @@ -81,14 +80,14 @@ def test_truediv(): # A naive complex division algorithm (such as in 2.0) is very prone to # nonsense errors for these (overflows and underflows). - assert check_div(complex(1e200, 1e200), 1+0*j) - assert check_div(complex(1e-200, 1e-200), 1+0*j) + assert check_div(complex(1e200, 1e200), 1+0j) + assert check_div(complex(1e-200, 1e-200), 1+0j) # Just for fun. for i in range(100): check_div(complex(random(), random()), complex(random(), random())) - assert close_complex(complex.__truediv__(2+0*j, 1+1*j), 1-1*j) + assert close_complex(complex.__truediv__(2+0j, 1+1j), 1-1j) for denom_real, denom_imag in [(0., NAN), (NAN, 0.), (NAN, NAN)]: z = complex(0, 0) / complex(denom_real, denom_imag) @@ -98,51 +97,51 @@ test_truediv() @test def test_richcompare(): - assert not complex.__eq__(1+1*j, 1<<10000) - assert complex.__eq__(1+1*j, 1+1*j) - assert not complex.__eq__(1+1*j, 2+2*j) - assert not complex.__ne__(1+1*j, 1+1*j) - assert complex.__ne__(1+1*j, 2+2*j), True + assert not complex.__eq__(1+1j, 1<<10000) + assert complex.__eq__(1+1j, 1+1j) + assert not complex.__eq__(1+1j, 2+2j) + assert not complex.__ne__(1+1j, 1+1j) + assert complex.__ne__(1+1j, 2+2j), True for i in range(1, 100): f = i / 100.0 - assert complex.__eq__(f+0*j, f) - assert not complex.__ne__(f+0*j, f) + assert complex.__eq__(f+0j, f) + assert not complex.__ne__(f+0j, f) assert not complex.__eq__(complex(f, f), f) assert complex.__ne__(complex(f, f), f) import operator - assert operator.eq(1+1*j, 1+1*j) == True - assert operator.eq(1+1*j, 2+2*j) == False - assert operator.ne(1+1*j, 1+1*j) == False - assert operator.ne(1+1*j, 2+2*j) == True + assert operator.eq(1+1j, 1+1j) == True + assert operator.eq(1+1j, 2+2j) == False + assert operator.ne(1+1j, 1+1j) == False + assert operator.ne(1+1j, 2+2j) == True test_richcompare() @test def test_pow(): def pow(a, b): return a ** b - assert close_complex(pow(1+1*j, 0+0*j), 1.0) - assert close_complex(pow(0+0*j, 2+0*j), 0.0) - assert close_complex(pow(1*j, -1), 1/(1*j)) - assert close_complex(pow(1*j, 200), 1) + assert close_complex(pow(1+1j, 0+0j), 1.0) + assert close_complex(pow(0+0j, 2+0j), 0.0) + assert close_complex(pow(1j, -1), 1/(1j)) + assert close_complex(pow(1j, 200), 1) - a = 3.33+4.43*j - assert a ** (0*j) == 1 - assert a ** (0.+0.*j) == 1 + a = 3.33+4.43j + assert a ** (0j) == 1 + assert a ** (0.+0.j) == 1 - assert (3*j) ** (0*j) == 1 - assert (3*j) ** 0 == 1 + assert (3j) ** (0j) == 1 + assert (3j) ** 0 == 1 # The following is used to exercise certain code paths assert a ** 105 == a ** 105 assert a ** -105 == a ** -105 assert a ** -30 == a ** -30 - assert (0.0*j) ** 0 == 1 + assert (0.0j) ** 0 == 1 test_pow() @test def test_conjugate(): - assert close_complex(complex(5.3, 9.8).conjugate(), 5.3-9.8*j) + assert close_complex(complex(5.3, 9.8).conjugate(), 5.3-9.8j) test_conjugate() @test @@ -419,8 +418,8 @@ def test_polar(): assert check(0, (0., 0.)) assert check(1, (1., 0.)) assert check(-1, (1., pi)) - assert check(1*j, (1., pi / 2)) - assert check(-3*j, (3., -pi / 2)) + assert check(1j, (1., pi / 2)) + assert check(-3j, (3., -pi / 2)) inf = float('inf') assert check(complex(inf, 0), (inf, 0.)) assert check(complex(-inf, 0), (inf, pi)) @@ -446,10 +445,10 @@ def test_phase(): assert almost_equal(phase(0), 0.) assert almost_equal(phase(1.), 0.) assert almost_equal(phase(-1.), pi) - assert almost_equal(phase(-1.+1E-300*j), pi) - assert almost_equal(phase(-1.-1E-300*j), -pi) - assert almost_equal(phase(1*j), pi/2) - assert almost_equal(phase(-1*j), -pi/2) + assert almost_equal(phase(-1.+1E-300j), pi) + assert almost_equal(phase(-1.-1E-300j), -pi) + assert almost_equal(phase(1j), pi/2) + assert almost_equal(phase(-1j), -pi/2) # zeros assert phase(complex(0.0, 0.0)) == 0.0 @@ -539,7 +538,7 @@ test_isfinite() @test def test_isnan(): assert not cmath.isnan(1) - assert not cmath.isnan(1*j) + assert not cmath.isnan(1j) assert not cmath.isnan(INF) assert cmath.isnan(NAN) assert cmath.isnan(complex(NAN, 0)) @@ -552,7 +551,7 @@ test_isnan() @test def test_isinf(): assert not cmath.isinf(1) - assert not cmath.isinf(1*j) + assert not cmath.isinf(1j) assert not cmath.isinf(NAN) assert cmath.isinf(INF) assert cmath.isinf(complex(INF, 0)) @@ -583,10 +582,10 @@ test_atanh_sign() @test def test_is_close(): # test complex values that are close to within 12 decimal places - complex_examples = [(1.0+1.0*j, 1.000000000001+1.0*j), - (1.0+1.0*j, 1.0+1.000000000001*j), - (-1.0+1.0*j, -1.000000000001+1.0*j), - (1.0-1.0*j, 1.0-0.999999999999*j), + complex_examples = [(1.0+1.0j, 1.000000000001+1.0j), + (1.0+1.0j, 1.0+1.000000000001j), + (-1.0+1.0j, -1.000000000001+1.0j), + (1.0-1.0j, 1.0-0.999999999999j), ] for a,b in complex_examples: @@ -594,20 +593,20 @@ def test_is_close(): assert not cmath.isclose(a, b, rel_tol=1e-13) # test values near zero that are near to within three decimal places - near_zero_examples = [(0.001*j, 0), - (0.001 + 0*j, 0), - (0.001+0.001*j, 0), - (-0.001+0.001*j, 0), - (0.001-0.001*j, 0), - (-0.001-0.001*j, 0), + near_zero_examples = [(0.001j, 0), + (0.001 + 0j, 0), + (0.001+0.001j, 0), + (-0.001+0.001j, 0), + (0.001-0.001j, 0), + (-0.001-0.001j, 0), ] for a,b in near_zero_examples: assert cmath.isclose(a, b, abs_tol=1.5e-03) assert not cmath.isclose(a, b, abs_tol=0.5e-03) - assert cmath.isclose(0.001-0.001*j, 0.001+0.001*j, abs_tol=2e-03) - assert not cmath.isclose(0.001-0.001*j, 0.001+0.001*j, abs_tol=1e-03) + assert cmath.isclose(0.001-0.001j, 0.001+0.001j, abs_tol=2e-03) + assert not cmath.isclose(0.001-0.001j, 0.001+0.001j, abs_tol=1e-03) test_is_close() @test