Merge branch 'develop' of https://github.com/exaloop/codon into develop

pull/10/head
A. R. Shajii 2021-12-29 20:06:24 -05:00
commit 240f2947c5
38 changed files with 334 additions and 160 deletions

View File

@ -100,7 +100,7 @@ std::string IntExpr::toString() const {
ACCEPT_IMPL(IntExpr, ASTVisitor);
FloatExpr::FloatExpr(double floatValue)
: Expr(), value(std::to_string(floatValue)), floatValue(floatValue) {}
: Expr(), value(fmt::format("{:g}", floatValue)), floatValue(floatValue) {}
FloatExpr::FloatExpr(const std::string &value, std::string suffix)
: Expr(), value(value), suffix(std::move(suffix)), floatValue(0.0) {}
std::string FloatExpr::toString() const {

View File

@ -562,10 +562,12 @@ std::string PartialType::debugString(bool debug) const {
std::vector<std::string> as;
int i, gi;
for (i = 0, gi = 0; i < known.size(); i++)
if (!known[i])
as.emplace_back("...");
else
as.emplace_back(gs[gi++]);
if (!func->ast->args[i].generic) {
if (!known[i])
as.emplace_back("...");
else
as.emplace_back(gs[gi++]);
}
return fmt::format("{}[{}]", !debug ? func->ast->name : func->debugString(debug),
join(as, ","));
}
@ -756,7 +758,7 @@ int CallableTrait::unify(Type *typ, Unification *us) {
if (args[0]->unify(pt->func->args[0].get(), us) == -1)
return -1;
for (int pi = 0, gi = 1; pi < pt->known.size(); pi++)
if (!pt->known[pi])
if (!pt->known[pi] && !pt->func->ast->args[pi].generic)
if (args[gi++]->unify(pt->func->args[pi + 1].get(), us) == -1)
return -1;
return 1;

View File

@ -555,9 +555,8 @@ atom <-
ast<IntExpr>(LOC, ac<string>(V0)), ast<IntExpr>(LOC, ac<string>(V1))
);
}
/ FLOAT {
// Right now suffixes are _not_ supported
return ast<FloatExpr>(LOC, ac<string>(V0), "");
/ FLOAT NAME? {
return ast<FloatExpr>(LOC, ac<string>(V0), VS.size() > 1 ? ac<string>(V1) : "");
}
/ INT NAME? {
return ast<IntExpr>(LOC, ac<string>(V0), VS.size() > 1 ? ac<string>(V1) : "");

View File

@ -343,8 +343,8 @@ private:
/// (XXXuN and XXXiN), and other suffix integers to a corresponding integer value or a
/// constructor:
/// 123u -> UInt[64](123)
/// 123i56 -> Int[56]("123") (same for UInt)
/// 123pf -> int.__suffix_pf__("123")
/// 123i56 -> Int[56](123) (same for UInt)
/// 123pf -> int.__suffix_pf__(123)
ExprPtr transformInt(const std::string &value, const std::string &suffix);
/// Converts a float string to double.
ExprPtr transformFloat(const std::string &value, const std::string &suffix);

View File

@ -691,48 +691,48 @@ ExprPtr SimplifyVisitor::transformInt(const std::string &value,
return std::stoull(s.substr(2), nullptr, 2);
return std::stoull(s, nullptr, 0);
};
int64_t val;
try {
if (suffix.empty()) {
auto expr = N<IntExpr>(to_int(value));
return expr;
}
val = to_int(value);
if (suffix.empty())
return N<IntExpr>(val);
/// Unsigned numbers: use UInt[64] for that
if (suffix == "u")
return transform(N<CallExpr>(N<IndexExpr>(N<IdExpr>("UInt"), N<IntExpr>(64)),
N<IntExpr>(to_int(value))));
N<IntExpr>(val)));
/// Fixed-precision numbers (uXXX and iXXX)
/// NOTE: you cannot use binary (0bXXX) format with those numbers.
/// TODO: implement non-string constructor for these cases.
if (suffix[0] == 'u' && isdigit(suffix.substr(1)))
return transform(N<CallExpr>(
N<IndexExpr>(N<IdExpr>("UInt"), N<IntExpr>(std::stoi(suffix.substr(1)))),
N<StringExpr>(value)));
N<IntExpr>(val)));
if (suffix[0] == 'i' && isdigit(suffix.substr(1)))
return transform(N<CallExpr>(
N<IndexExpr>(N<IdExpr>("Int"), N<IntExpr>(std::stoi(suffix.substr(1)))),
N<StringExpr>(value)));
N<IntExpr>(val)));
} catch (std::out_of_range &) {
error("integer {} out of range", value);
}
/// Custom suffix sfx: use int.__suffix_sfx__(str) call.
/// NOTE: you cannot neither use binary (0bXXX) format here.
return transform(N<CallExpr>(N<DotExpr>("int", format("__suffix_{}__", suffix)),
N<StringExpr>(value)));
N<IntExpr>(val)));
}
ExprPtr SimplifyVisitor::transformFloat(const std::string &value,
const std::string &suffix) {
double val;
try {
if (suffix.empty()) {
auto expr = N<FloatExpr>(std::stod(value));
return expr;
}
val = std::stod(value);
} catch (std::out_of_range &) {
error("integer {} out of range", value);
error("float {} out of range", value);
}
if (suffix.empty())
return N<FloatExpr>(val);
/// Custom suffix sfx: use float.__suffix_sfx__(str) call.
return transform(N<CallExpr>(N<DotExpr>("float", format("__suffix_{}__", suffix)),
N<StringExpr>(value)));
N<FloatExpr>(value)));
}
ExprPtr SimplifyVisitor::transformFString(std::string value) {

View File

@ -786,8 +786,8 @@ void SimplifyVisitor::visit(ClassStmt *stmt) {
if (bcName.empty() || !in(ctx->cache->classes, bcName))
error(baseClass.get(), "invalid base class");
baseASTs.push_back(ctx->cache->classes[bcName].ast.get());
if (baseASTs.back()->attributes.has(Attr::Tuple) != isRecord)
error("tuples cannot inherit reference classes (and vice versa)");
if (!isRecord && baseASTs.back()->attributes.has(Attr::Tuple))
error("reference classes cannot inherit by-value classes");
if (baseASTs.back()->attributes.has(Attr::Internal))
error("cannot inherit internal types");
int si = 0;

View File

@ -16,7 +16,8 @@ namespace ast {
TypeContext::TypeContext(Cache *cache)
: Context<TypecheckItem>(""), cache(move(cache)), typecheckLevel(0),
allowActivation(true), age(0), realizationDepth(0) {
allowActivation(true), age(0), realizationDepth(0), blockLevel(0),
returnEarly(false) {
stack.push_front(std::vector<std::string>());
bases.push_back({"", nullptr, nullptr});
}

View File

@ -67,6 +67,11 @@ struct TypeContext : public Context<TypecheckItem> {
/// (e.g. class A: def __init__(a: A = A())).
std::set<std::string> defaultCallDepth;
/// Number of nested blocks (0 for toplevel)
int blockLevel;
/// True if an early return is found (anything afterwards won't be typechecked)
bool returnEarly;
public:
explicit TypeContext(Cache *cache);

View File

@ -709,10 +709,20 @@ ExprPtr TypecheckVisitor::transformBinary(BinaryExpr *expr, bool isAtomic,
if (method)
swap(expr->lexpr, expr->rexpr);
}
if (!method)
error("cannot find magic '{}' in {}", magic, lt->toString());
return transform(N<CallExpr>(N<IdExpr>(method->ast->name), expr->lexpr, expr->rexpr));
if (method) {
return transform(
N<CallExpr>(N<IdExpr>(method->ast->name), expr->lexpr, expr->rexpr));
} else if (lt->is("pyobj")) {
return transform(N<CallExpr>(N<CallExpr>(N<DotExpr>(expr->lexpr, "_getattr"),
N<StringExpr>(format("__{}__", magic))),
expr->rexpr));
} else if (rt->is("pyobj")) {
return transform(N<CallExpr>(N<CallExpr>(N<DotExpr>(expr->rexpr, "_getattr"),
N<StringExpr>(format("__r{}__", magic))),
expr->lexpr));
}
error("cannot find magic '{}' in {}", magic, lt->toString());
return nullptr;
}
ExprPtr TypecheckVisitor::transformStaticTupleIndex(ClassType *tuple, ExprPtr &expr,
@ -932,13 +942,13 @@ ExprPtr TypecheckVisitor::transformDot(DotExpr *expr,
// Object access: y.method. Transform y.method to a partial call
// typeof(t).foo(y, ...).
std::vector<ExprPtr> methodArgs{expr->expr};
for (int i = 0; i < std::max(1, (int)bestMethod->args.size() - 2); i++)
methodArgs.push_back(N<EllipsisExpr>());
methodArgs.push_back(N<EllipsisExpr>());
// Handle @property methods.
if (bestMethod->ast->attributes.has(Attr::Property))
methodArgs.pop_back();
ExprPtr e = N<CallExpr>(N<IdExpr>(bestMethod->ast->name), methodArgs);
return transform(e, false, allowVoidExpr);
ExprPtr r = transform(e, false, allowVoidExpr);
return r;
}
}
@ -1046,8 +1056,11 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
// c = t.__new__(); c.__init__(args); c
ExprPtr var = N<IdExpr>(ctx->cache->getTemporaryVar("v"));
return transform(N<StmtExpr>(
N<AssignStmt>(clone(var), N<CallExpr>(N<DotExpr>(expr->expr, "__new__"))),
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__init__"), expr->args)),
N<SuiteStmt>(
N<AssignStmt>(clone(var), N<CallExpr>(N<DotExpr>(expr->expr, "__new__"))),
N<ExprStmt>(N<CallExpr>(N<IdExpr>("std.internal.gc.register_finalizer"),
clone(var))),
N<ExprStmt>(N<CallExpr>(N<DotExpr>(clone(var), "__init__"), expr->args))),
clone(var)));
}
} else if (auto pc = callee->getPartial()) {
@ -1055,8 +1068,13 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
expr->expr = transform(N<StmtExpr>(N<AssignStmt>(clone(var), expr->expr),
N<IdExpr>(pc->func->ast->name)));
calleeFn = expr->expr->type->getFunc();
for (int i = 0, j = 0; i < calleeFn->ast->args.size(); i++)
known.push_back(calleeFn->ast->args[i].generic ? 0 : pc->known[j++]);
for (int i = 0, j = 0; i < pc->known.size(); i++)
if (pc->func->ast->args[i].generic) {
if (pc->known[i])
unify(calleeFn->funcGenerics[j].type, pc->func->funcGenerics[j].type);
j++;
}
known = pc->known;
seqassert(calleeFn, "not a function: {}", expr->expr->type->toString());
} else if (!callee->getFunc()) {
// Case 3: callee is not a named function. Route it through a __call__ method.
@ -1070,6 +1088,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
int typeArgCount = 0;
bool isPartial = false;
int ellipsisStage = -1;
auto newMask = std::vector<char>(calleeFn->ast->args.size(), 1);
if (expr->ordered)
args = expr->args;
else
@ -1085,6 +1104,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
typeArgs.push_back(slots[si].empty() ? nullptr
: expr->args[slots[si][0]].value);
typeArgCount += typeArgs.back() != nullptr;
newMask[si] = slots[si].empty() ? 0 : 1;
} else if (si == starArgIndex && !(partial && slots[si].empty())) {
std::vector<ExprPtr> extra;
for (auto &e : slots[si]) {
@ -1115,6 +1135,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
args.push_back({"", ex});
} else if (partial) {
args.push_back({"", transform(N<EllipsisExpr>())});
newMask[si] = 0;
} else {
auto es = calleeFn->ast->args[si].deflt->toString();
if (in(ctx->defaultCallDepth, es))
@ -1207,7 +1228,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
// Handle default generics (calleeFn.g. foo[S, T=int]) only if all arguments were
// unified.
// TODO: remove once the proper partial handling of overloaded functions land
if (unificationsDone)
if (unificationsDone) {
for (int i = 0, j = 0; i < calleeFn->ast->args.size(); i++)
if (calleeFn->ast->args[i].generic) {
if (calleeFn->ast->args[i].deflt &&
@ -1222,6 +1243,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
}
j++;
}
}
for (int si = 0; si < replacements.size(); si++)
if (replacements[si]) {
if (replacements[si]->getFunc())
@ -1238,13 +1260,7 @@ ExprPtr TypecheckVisitor::transformCall(CallExpr *expr, const types::TypePtr &in
expr->done &= expr->expr->done;
// Emit the final call.
std::vector<char> newMask;
if (isPartial)
newMask = std::vector<char>(calleeFn->args.size() - 1, 1);
for (int si = 0; si < calleeFn->args.size() - 1; si++)
if (args[si].value->getEllipsis() && !args[si].value->getEllipsis()->isPipeArg)
newMask[si] = 0;
if (!newMask.empty()) {
if (isPartial) {
// Case 1: partial call.
// Transform calleeFn(args...) to Partial.N<known>.<calleeFn>(args...).
auto partialTypeName = generatePartialStub(newMask, calleeFn->getFunc().get());
@ -1507,14 +1523,15 @@ std::string TypecheckVisitor::generateFunctionStub(int n) {
std::string TypecheckVisitor::generatePartialStub(const std::vector<char> &mask,
types::FuncType *fn) {
std::string strMask(mask.size(), '1');
int tupleSize = 0;
for (int i = 0; i < mask.size(); i++)
if (!mask[i])
strMask[i] = '0';
auto typeName = format(TYPE_PARTIAL "{}", strMask);
if (!ctx->find(typeName)) {
auto tupleSize = std::count_if(mask.begin(), mask.end(), [](char c) { return c; });
else if (!fn->ast->args[i].generic)
tupleSize++;
auto typeName = format(TYPE_PARTIAL "{}.{}", strMask, fn->ast->name);
if (!ctx->find(typeName))
generateTupleStub(tupleSize, typeName, {}, false);
}
return typeName;
}
@ -1569,7 +1586,14 @@ void TypecheckVisitor::generateFnCall(int n) {
ExprPtr TypecheckVisitor::partializeFunction(ExprPtr expr) {
auto fn = expr->getType()->getFunc();
seqassert(fn, "not a function: {}", expr->getType()->toString());
std::vector<char> mask(fn->args.size() - 1, 0);
std::vector<char> mask(fn->ast->args.size(), 0);
for (int i = 0, j = 0; i < fn->ast->args.size(); i++)
if (fn->ast->args[i].generic) {
// TODO: better detection of user-provided args...?
if (!fn->funcGenerics[j].type->getUnbound())
mask[i] = 1;
j++;
}
auto partialTypeName = generatePartialStub(mask, fn.get());
deactivateUnbounds(fn.get());
std::string var = ctx->cache->getTemporaryVar("partial");

View File

@ -188,7 +188,13 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
StmtPtr realized = nullptr;
if (!isInternal) {
auto oldBlockLevel = ctx->blockLevel;
auto oldReturnEarly = ctx->returnEarly;
ctx->blockLevel = 0;
ctx->returnEarly = false;
realized = inferTypes(ast->suite, false, type->realizedName()).second;
ctx->blockLevel = oldBlockLevel;
ctx->returnEarly = oldReturnEarly;
if (ast->attributes.has(Attr::LLVM)) {
auto s = realized->getSuite();
for (int i = 1; i < s->stmts.size(); i++) {

View File

@ -49,11 +49,16 @@ void TypecheckVisitor::defaultVisit(Stmt *s) {
void TypecheckVisitor::visit(SuiteStmt *stmt) {
std::vector<StmtPtr> stmts;
stmt->done = true;
for (auto &s : stmt->stmts)
ctx->blockLevel += int(stmt->ownBlock);
for (auto &s : stmt->stmts) {
if (ctx->returnEarly)
break;
if (auto t = transform(s)) {
stmts.push_back(t);
stmt->done &= stmts.back()->done;
}
}
ctx->blockLevel -= int(stmt->ownBlock);
stmt->stmts = stmts;
}
@ -205,16 +210,18 @@ void TypecheckVisitor::visit(ReturnStmt *stmt) {
stmt->expr = transform(stmt->expr);
if (stmt->expr) {
auto &base = ctx->bases.back();
wrapExpr(stmt->expr, base.returnType, nullptr);
if (!base.returnType->getUnbound())
wrapExpr(stmt->expr, base.returnType, nullptr);
if (stmt->expr->getType()->getFunc() &&
!(base.returnType->getClass() &&
startswith(base.returnType->getClass()->name, TYPE_FUNCTION)))
stmt->expr = partializeFunction(stmt->expr);
unify(base.returnType, stmt->expr->type);
auto retTyp = stmt->expr->getType()->getClass();
stmt->done = stmt->expr->done;
} else {
if (ctx->blockLevel == 1)
ctx->returnEarly = true;
stmt->done = true;
}
}
@ -307,7 +314,12 @@ void TypecheckVisitor::visit(IfStmt *stmt) {
isTrue = !stmt->cond->staticValue.getString().empty();
else
isTrue = stmt->cond->staticValue.getInt();
resultStmt = transform(isTrue ? stmt->ifSuite : stmt->elseSuite);
resultStmt = isTrue ? stmt->ifSuite : stmt->elseSuite;
bool isOwn = // these blocks will not be a real owning blocks after inlining
resultStmt && resultStmt->getSuite() && resultStmt->getSuite()->ownBlock;
ctx->blockLevel -= isOwn;
resultStmt = transform(resultStmt);
ctx->blockLevel += isOwn;
if (!resultStmt)
resultStmt = transform(N<SuiteStmt>());
return;

View File

@ -811,7 +811,12 @@ You can also add methods to types:
Type extensions
~~~~~~~~~~~~~~~
Suppose you have a class that lacks a method or an operator that might be really useful. Codon provides an ``@extend`` annotation that allows you to add and modify methods of various types at compile time, including built-in types like ``int`` or ``str``. This allows much of the functionality of built-in types to be implemented in Codon as type extensions in the standard library.
Suppose you have a class that lacks a method or an operator that might be really useful.
Codon provides an ``@extend`` annotation that allows programmers to add and modify
methods of various types at compile time, including built-in types like ``int`` or ``str``.
This actually allows much of the functionality of built-in types to be implemented in
Codon as type extensions in the standard library.
.. code:: python

View File

@ -35,7 +35,7 @@ def _partial_insertion_sort[S,T](arr: Array[T], begin: int, end: int, keyf: Call
arr[sift] = arr[sift_1]
sift -= 1
sift_1 -= 1
if sift == begin or keyf(tmp) >= keyf(arr[sift_1]):
if sift == begin or not keyf(tmp) < keyf(arr[sift_1]):
break
arr[sift] = tmp
@ -52,7 +52,7 @@ def _partition_left[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T]
while True:
last -= 1
if keyf(pivot) >= keyf(arr[last]):
if not keyf(pivot) < keyf(arr[last]):
break
if (last + 1 == end):
@ -71,7 +71,7 @@ def _partition_left[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T]
arr[first], arr[last] = arr[last], arr[first]
while True:
last -= 1
if keyf(pivot) >= keyf(arr[last]):
if not keyf(pivot) < keyf(arr[last]):
break
while True:
first += 1
@ -91,7 +91,7 @@ def _partition_right[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T
while True:
first += 1
if keyf(arr[first]) >= keyf(pivot):
if not keyf(arr[first]) < keyf(pivot):
break
if first - 1 == begin:
@ -115,7 +115,7 @@ def _partition_right[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T
while True:
first += 1
if keyf(arr[first]) >= keyf(pivot):
if not keyf(arr[first]) < keyf(pivot):
break
while True:
@ -155,7 +155,7 @@ def _pdq_sort[S,T](arr: Array[T], begin: int, end: int, keyf: Callable[[T], S],
else:
_sort3(arr, begin + size_2, begin, end - 1, keyf)
if not leftmost and keyf(arr[begin - 1]) >= keyf(arr[begin]):
if not leftmost and not keyf(arr[begin - 1]) < keyf(arr[begin]):
begin = _partition_left(arr, begin, end, keyf) + 1
continue

View File

@ -1,6 +1,8 @@
def _med3[S,T](a: int, b: int, c: int, d: Array[T], k: Callable[[T], S]):
return ((b if (k(d[b]) < k(d[c])) else (c if k(d[a]) < k(d[c]) else a))
if (k(d[a]) < k(d[b])) else (b if (k(d[b]) > k(d[c])) else (c if k(d[a]) > k(d[c]) else a)))
if k(d[a]) < k(d[b]):
return b if (k(d[b]) < k(d[c])) else (c if k(d[a]) < k(d[c]) else a)
else:
return b if not (k(d[b]) < k(d[c]) or k(d[b]) == k(d[c])) else (c if not (k(d[a]) < k(d[c]) or k(d[a]) == k(d[c])) else a)
def _swap[T](i: int, j: int, a: Array[T]):
a[i], a[j] = a[j], a[i]
@ -17,7 +19,7 @@ def _qsort[S,T](arr: Array[T], frm: int, cnt: int, keyf: Callable[[T], S]):
i = frm + 1
while i < frm + cnt:
j = i
while j > frm and keyf(arr[j - 1]) > keyf(arr[j]):
while j > frm and not (keyf(arr[j - 1]) < keyf(arr[j]) or keyf(arr[j - 1]) == keyf(arr[j])):
_swap(j, j - 1, arr)
j -= 1
i += 1
@ -41,13 +43,13 @@ def _qsort[S,T](arr: Array[T], frm: int, cnt: int, keyf: Callable[[T], S]):
d = c
while True:
while b <= c and keyf(arr[b]) <= keyf(arr[frm]):
while b <= c and (keyf(arr[b]) < keyf(arr[frm]) or keyf(arr[b]) == keyf(arr[frm])):
if keyf(arr[b]) == keyf(arr[frm]):
_swap(a, b, arr)
a += 1
b += 1
while c >= b and keyf(arr[c]) >= keyf(arr[frm]):
while c >= b and not keyf(arr[c]) < keyf(arr[frm]):
if keyf(arr[c]) == keyf(arr[frm]):
_swap(c, d, arr)
d -= 1

View File

@ -104,6 +104,9 @@ class deque[T]:
return True
return False
def __deepcopy__(self):
return deque(i.__deepcopy__() for i in self)
def __copy__(self):
return deque[T](copy(self._arr), self._head, self._tail, self._maxlen)

View File

@ -148,7 +148,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()):
Equivalent to: sorted(iterable, key=key)[:n]
"""
if n == 1:
v = List[T](1)
v = List(1)
for a in iterable:
if not v:
v.append(a)
@ -165,7 +165,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()):
it = iter(iterable)
# put the range(n) first so that zip() doesn't
# consume one too many elements from the iterator
result = List[Tuple[T,int]](n)
result = List(n)
done = False
for i in range(n):
if it.done():
@ -174,7 +174,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()):
result.append((it.next(), i))
if not result:
it.destroy()
return List[T](0)
return []
_heapify_max(result)
top = result[0][0]
order = n
@ -191,7 +191,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()):
else:
# General case, slowest method
it = iter(iterable)
result = List[Tuple[type(key(T())),int,T]](n)
result = List(n)
done = False
for i in range(n):
if it.done():
@ -201,7 +201,7 @@ def nsmallest[T](n: int, iterable: Generator[T], key = Optional[int]()):
result.append((key(elem), i, elem))
if not result:
it.destroy()
return List[T](0)
return []
_heapify_max(result)
top = result[0][0]
order = n
@ -222,7 +222,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()):
Equivalent to: sorted(iterable, key=key, reverse=True)[:n]
"""
if n == 1:
v = List[T](1)
v = List(1)
for a in iterable:
if not v:
v.append(a)
@ -237,7 +237,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()):
# When key is none, use simpler decoration
if isinstance(key, Optional):
it = iter(iterable)
result = List[Tuple[T,int]](n)
result = List(n)
done = False
for i in range(0, -n, -1):
if it.done():
@ -246,7 +246,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()):
result.append((it.next(), i))
if not result:
it.destroy()
return List[T](0)
return []
heapify(result)
top = result[0][0]
order = -n
@ -263,7 +263,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()):
else:
# General case, slowest method
it = iter(iterable)
result = List[Tuple[type(key(T())),int,T]](n)
result = List(n)
done = False
for i in range(0, -n, -1):
if it.done():
@ -272,7 +272,7 @@ def nlargest[T](n: int, iterable: Generator[T], key = Optional[int]()):
elem = it.next()
result.append((key(elem), i, elem))
if not result:
return List[T](0)
return []
heapify(result)
top = result[0][0]
order = -n

View File

@ -27,7 +27,7 @@ class File:
def __iter__(self):
for a in self._iter():
yield copy(a)
yield a.__ptrcopy__()
def readlines(self):
return [l for l in self]
@ -142,7 +142,7 @@ class gzFile:
def __iter__(self):
for a in self._iter():
yield copy(a)
yield a.__ptrcopy__()
def __enter__(self):
pass

View File

@ -35,9 +35,6 @@ def realloc(p: cobj, sz: int):
def free(p: cobj):
seq_free(p)
def register_finalizer(p: cobj, f: Function[[cobj, cobj], void]):
seq_register_finalizer(p, f.__raw__())
def add_roots(start: cobj, end: cobj):
seq_gc_add_roots(start, end)
@ -49,3 +46,9 @@ def clear_roots():
def exclude_static_roots(start: cobj, end: cobj):
seq_gc_exclude_static_roots(start, end)
def register_finalizer(p):
if hasattr(p, '__del__'):
def f(x: cobj, data: cobj, T: type):
Ptr[T](__ptr__(x).as_byte())[0].__del__()
seq_register_finalizer(p.__raw__(), f(T=type(p), ...).__raw__())

View File

@ -132,12 +132,6 @@ class pyobj:
def _getattr(self, name: str):
return pyobj.exc_wrap(pyobj(PyObject_GetAttrString(self.p, name.c_str())))
def __getitem__(self, t):
return self._getattr("__getitem__")(t)
def __add__(self, t):
return self._getattr("__add__")(t)
def __setitem__(self, name: str, val: pyobj):
return pyobj.exc_wrap(pyobj(PyObject_SetAttrString(self.p, name.c_str(), val.p)))
@ -232,7 +226,7 @@ class pyobj:
ensure_initialized()
PyRun_SimpleString(code.c_str())
def get[T](self) -> T:
def get(self, T: type) -> T:
return T.__from_py__(self)
def none():
@ -244,7 +238,7 @@ def none():
def py(x) -> pyobj:
return x.__to_py__()
def get[T](x: pyobj) -> T:
def get(x: pyobj, T: type) -> T:
return T.__from_py__(x)
@extend

View File

@ -106,12 +106,6 @@ class str:
n += self.len
return str(p, total)
def __copy__(self):
n = len(self)
p = cobj(n)
str.memcpy(p, self.ptr, n)
return str(p, n)
def _cmp(self, other: str):
n = min(self.len, other.len)
i = 0

View File

@ -10,6 +10,13 @@ class Array:
p = Ptr[T](self.len)
str.memcpy(p.as_byte(), self.ptr.as_byte(), self.len * sizeof(T))
return (self.len, p)
def __deepcopy__(self) -> Array[T]:
p = Ptr[T](self.len)
i = 0
while i < self.len:
p[i] = self.ptr[i].__deepcopy__()
i += 1
return (self.len, p)
def __len__(self) -> int:
return self.len
def __bool__(self) -> bool:

View File

@ -10,6 +10,8 @@ class bool:
return "True" if self else "False"
def __copy__(self) -> bool:
return self
def __deepcopy__(self) -> bool:
return self
def __bool__(self) -> bool:
return self
def __hash__(self):

View File

@ -21,6 +21,8 @@ class byte:
ret i8 %0
def __copy__(self) -> byte:
return self
def __deepcopy__(self) -> byte:
return self
@pure
@llvm
def __bool__(self) -> bool:

View File

@ -113,6 +113,9 @@ class Dict[K,V]:
str.memcpy(vals_copy.as_byte(), self._vals.as_byte(), n * gc.sizeof(V))
return Dict[K,V](n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy, vals_copy)
def __deepcopy__(self):
return {k.__deepcopy__(): v.__deepcopy__() for k, v in self.items()}
def __repr__(self):
n = self.__len__()
if n == 0:

View File

@ -173,6 +173,9 @@ class List:
def __copy__(self):
return List[T](self.arr.__copy__(), self.len)
def __deepcopy__(self):
return [l.__deepcopy__() for l in self]
def __iter__(self):
i = 0
N = self.len

View File

@ -119,6 +119,9 @@ class Set[K]:
str.memcpy(keys_copy.as_byte(), self._keys.as_byte(), n * gc.sizeof(K))
return Set[K](n, self._size, self._n_occupied, self._upper_bound, flags_copy, keys_copy)
def __deepcopy__(self):
return {s.__deepcopy__() for s in self}
def __repr__(self):
n = self.__len__()
if n == 0:

View File

@ -266,3 +266,13 @@ class complex:
declare double @llvm.log.f64(double)
%y = call double @llvm.log.f64(double %x)
ret double %y
@extend
class int:
def __suffix_j__(x: int):
return complex(0, x)
@extend
class float:
def __suffix_j__(x: float):
return complex(0, x)

View File

@ -17,6 +17,8 @@ class float:
return s if s != "-nan" else "nan"
def __copy__(self) -> float:
return self
def __deepcopy__(self) -> float:
return self
@pure
@llvm
def __int__(self) -> int:

View File

@ -31,6 +31,8 @@ class int:
return seq_str_int(self)
def __copy__(self) -> int:
return self
def __deepcopy__(self) -> int:
return self
def __hash__(self) -> int:
return self
@pure

View File

@ -50,6 +50,8 @@ class Int:
return int(self)
def __copy__(self) -> Int[N]:
return self
def __deepcopy__(self) -> Int[N]:
return self
def __hash__(self) -> int:
return int(self)
@pure
@ -255,6 +257,8 @@ class UInt:
return int(self)
def __copy__(self) -> UInt[N]:
return self
def __deepcopy__(self) -> UInt[N]:
return self
def __hash__(self) -> int:
return int(self)
@pure

View File

@ -24,6 +24,13 @@ class str:
return self.len != 0
def __copy__(self) -> str:
return self
def __deepcopy__(self) -> str:
return self
def __ptrcopy__(self) -> str:
n = self.len
p = cobj(n)
str.memcpy(p, self.ptr, n)
return str(p, n)
@llvm
def memcpy(dest: Ptr[byte], src: Ptr[byte], len: int) -> void:
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* %dest, i8* %src, i64 %len, i32 %align, i1 %isvolatile)

View File

@ -283,8 +283,9 @@ def starmap(function, iterable):
for args in iterable:
yield function(*args)
# TODO: fix this once Optional[Callable] lands
@inline
def groupby[T](iterable: Generator[T], key: Callable[[T], S] = NoneType, S: type = NoneType):
def groupby(iterable, key = Optional[int]()):
"""
Make an iterator that returns consecutive keys and groups from the iterable.
"""
@ -292,7 +293,7 @@ def groupby[T](iterable: Generator[T], key: Callable[[T], S] = NoneType, S: type
group = []
for currvalue in iterable:
k = currvalue if isinstance(key, NoneType) else key(currvalue)
k = currvalue if isinstance(key, Optional) else key(currvalue)
if currkey is None:
currkey = k
if k != ~currkey:

View File

@ -27,8 +27,8 @@ print -1u7 #: 127
@extend
class int:
def __suffix_test__(s: str):
return 'TEST: ' + s
def __suffix_test__(s):
return 'TEST: ' + str(s)
print 123_456test #: TEST: 123456
#%% int_error,barebones
@ -39,8 +39,13 @@ print 5.15 #: 5.15
print 2e2 #: 200
print 2.e-2 #: 0.02
#%% float_error,barebones
print 5.__str__() #! syntax error
#%% float_suffix,barebones
@extend
class float:
def __suffix_zoo__(x):
return str(x) + '_zoo'
print 1.2e-1zoo #: 0.12_zoo
#%% string,barebones
print 'kthxbai', "kthxbai" #: kthxbai kthxbai

View File

@ -851,6 +851,11 @@ print np.transpose(a)
n = np.array([[1, 2], [3, 4]])
print n[0], n[0][0] + 1 #: [1 2] 2
a = np.array([1,2,3])
print(a + 1) #: [2 3 4]
print(a - 1) #: [0 1 2]
print(1 - a) #: [ 0 -1 -2]
#%% python_import_fn
from python import re.split(str, str) -> List[str] as rs
print rs(r'\W+', 'Words, words, words.') #: ['Words', 'words', 'words', '']
@ -943,10 +948,23 @@ class defdict(Dict[str,float]):
z = defdict()
z[1.1] #! cannot unify float and str
#%% inherit_tuple,barebones
class Foo:
a: int
b: str
def __init__(self, a: int):
self.a, self.b = a, 'yoo'
@tuple
class FooTup(Foo): pass
f = Foo(5)
print f.a, f.b #: 5 yoo
fp = FooTup(6, 's')
print fp #: (a: 6, b: 's')
#%% inherit_class_err_1,barebones
class defdict(Array[int]):
pass #! tuples cannot inherit reference classes (and vice versa)
pass #! reference classes cannot inherit by-value classes
#%% inherit_class_err_2,barebones
@tuple

View File

@ -441,6 +441,7 @@ def foo(x, *args, **kwargs):
print x, args, kwargs
p = foo(...)
p(1, z=5) #: 1 () (z: 5)
p('s', zh=65) #: s () (zh: 65)
q = p(zh=43, ...)
q(1) #: 1 () (zh: 43)
r = q(5, 38, ...)

View File

@ -282,3 +282,33 @@ class int:
yield self
self -= 1
print list((5).run_lola_run()) #: [5, 4, 3, 2, 1]
#%% early_return,barebones
def foo(x):
print x-1
return
print len(x)
foo(5) #: 4
def foo(x):
if isinstance(x, int):
print x+1
return
print len(x)
foo(1) #: 2
foo('s') #: 1
def foo(x, y: Static[int] = 5):
if y < 3:
if y > 1:
if isinstance(x, int):
print x+1
return
if isinstance(x, int):
return
print len(x)
foo(1, 1)
foo(1, 2) #: 2
foo(1)
foo('s') #: 1

View File

@ -765,10 +765,7 @@ def bar[T](x):
foo(bar, 1)
#: 1 int
#: bar[...]
foo(bar(T=int,...), 1)
#: 1 int
#: bar[...]
foo(bar(T=str,...), 's')
foo(bar(...), 's')
#: s str
#: bar[...]
z = bar
@ -780,7 +777,16 @@ z(1, T=str)
zz = bar(T=int,...)
zz(1)
#: 1 int
# zz('s') # TODO: zz = foo[int] is update stmt, messes up everything... :/
#%% forward_error,barebones
def foo(f, x):
f(x, type(x))
print f.__class__
def bar[T](x):
print x, T.__class__
foo(bar(T=int,...), 1)
#! too many arguments for bar[T1,int] (expected maximum 2, got 2)
#! while realizing foo (arguments foo[bar[...],int])
#%% sort_partial
def foo(x, y):
@ -1082,3 +1088,22 @@ def test(v: Optional[int]):
print v.__class__
test(5) #: Optional[int]
test(None) #: Optional[int]
#%% methodcaller,barebones
def foo():
def bar(a, b):
print 'bar', a, b
return bar
foo()(1, 2) #: bar 1 2
def methodcaller(foo: Static[str]):
def caller(foo: Static[str], obj, *args, **kwargs):
if isinstance(getattr(obj, foo)(*args, **kwargs), void):
getattr(obj, foo)(*args, **kwargs)
else:
return getattr(obj, foo)(*args, **kwargs)
return caller(foo=foo, ...)
v = [1]
methodcaller('append')(v, 42)
print v #: [1, 42]
print methodcaller('index')(v, 42) #: 1

View File

@ -3,7 +3,6 @@ import cmath
INF = float('inf')
NAN = float('nan')
j = complex(0, 1)
def float_identical(x, y):
if math.isnan(x) or math.isnan(y):
@ -27,11 +26,11 @@ def complex_identical(x, y):
###########
ZERO_DIVISION = (
(1+1*j, 0+0*j),
(1+1*j, 0.0+0*j),
(1+1*j, 0+0*j),
(1.0+0*j, 0+0*j),
(1+0*j, 0+0*j),
(1+1j, 0+0j),
(1+1j, 0.0+0j),
(1+1j, 0+0j),
(1.0+0j, 0+0j),
(1+0j, 0+0j),
)
def close_abs(x, y, eps=1e-9):
@ -81,14 +80,14 @@ def test_truediv():
# A naive complex division algorithm (such as in 2.0) is very prone to
# nonsense errors for these (overflows and underflows).
assert check_div(complex(1e200, 1e200), 1+0*j)
assert check_div(complex(1e-200, 1e-200), 1+0*j)
assert check_div(complex(1e200, 1e200), 1+0j)
assert check_div(complex(1e-200, 1e-200), 1+0j)
# Just for fun.
for i in range(100):
check_div(complex(random(), random()), complex(random(), random()))
assert close_complex(complex.__truediv__(2+0*j, 1+1*j), 1-1*j)
assert close_complex(complex.__truediv__(2+0j, 1+1j), 1-1j)
for denom_real, denom_imag in [(0., NAN), (NAN, 0.), (NAN, NAN)]:
z = complex(0, 0) / complex(denom_real, denom_imag)
@ -98,51 +97,51 @@ test_truediv()
@test
def test_richcompare():
assert not complex.__eq__(1+1*j, 1<<10000)
assert complex.__eq__(1+1*j, 1+1*j)
assert not complex.__eq__(1+1*j, 2+2*j)
assert not complex.__ne__(1+1*j, 1+1*j)
assert complex.__ne__(1+1*j, 2+2*j), True
assert not complex.__eq__(1+1j, 1<<10000)
assert complex.__eq__(1+1j, 1+1j)
assert not complex.__eq__(1+1j, 2+2j)
assert not complex.__ne__(1+1j, 1+1j)
assert complex.__ne__(1+1j, 2+2j), True
for i in range(1, 100):
f = i / 100.0
assert complex.__eq__(f+0*j, f)
assert not complex.__ne__(f+0*j, f)
assert complex.__eq__(f+0j, f)
assert not complex.__ne__(f+0j, f)
assert not complex.__eq__(complex(f, f), f)
assert complex.__ne__(complex(f, f), f)
import operator
assert operator.eq(1+1*j, 1+1*j) == True
assert operator.eq(1+1*j, 2+2*j) == False
assert operator.ne(1+1*j, 1+1*j) == False
assert operator.ne(1+1*j, 2+2*j) == True
assert operator.eq(1+1j, 1+1j) == True
assert operator.eq(1+1j, 2+2j) == False
assert operator.ne(1+1j, 1+1j) == False
assert operator.ne(1+1j, 2+2j) == True
test_richcompare()
@test
def test_pow():
def pow(a, b): return a ** b
assert close_complex(pow(1+1*j, 0+0*j), 1.0)
assert close_complex(pow(0+0*j, 2+0*j), 0.0)
assert close_complex(pow(1*j, -1), 1/(1*j))
assert close_complex(pow(1*j, 200), 1)
assert close_complex(pow(1+1j, 0+0j), 1.0)
assert close_complex(pow(0+0j, 2+0j), 0.0)
assert close_complex(pow(1j, -1), 1/(1j))
assert close_complex(pow(1j, 200), 1)
a = 3.33+4.43*j
assert a ** (0*j) == 1
assert a ** (0.+0.*j) == 1
a = 3.33+4.43j
assert a ** (0j) == 1
assert a ** (0.+0.j) == 1
assert (3*j) ** (0*j) == 1
assert (3*j) ** 0 == 1
assert (3j) ** (0j) == 1
assert (3j) ** 0 == 1
# The following is used to exercise certain code paths
assert a ** 105 == a ** 105
assert a ** -105 == a ** -105
assert a ** -30 == a ** -30
assert (0.0*j) ** 0 == 1
assert (0.0j) ** 0 == 1
test_pow()
@test
def test_conjugate():
assert close_complex(complex(5.3, 9.8).conjugate(), 5.3-9.8*j)
assert close_complex(complex(5.3, 9.8).conjugate(), 5.3-9.8j)
test_conjugate()
@test
@ -419,8 +418,8 @@ def test_polar():
assert check(0, (0., 0.))
assert check(1, (1., 0.))
assert check(-1, (1., pi))
assert check(1*j, (1., pi / 2))
assert check(-3*j, (3., -pi / 2))
assert check(1j, (1., pi / 2))
assert check(-3j, (3., -pi / 2))
inf = float('inf')
assert check(complex(inf, 0), (inf, 0.))
assert check(complex(-inf, 0), (inf, pi))
@ -446,10 +445,10 @@ def test_phase():
assert almost_equal(phase(0), 0.)
assert almost_equal(phase(1.), 0.)
assert almost_equal(phase(-1.), pi)
assert almost_equal(phase(-1.+1E-300*j), pi)
assert almost_equal(phase(-1.-1E-300*j), -pi)
assert almost_equal(phase(1*j), pi/2)
assert almost_equal(phase(-1*j), -pi/2)
assert almost_equal(phase(-1.+1E-300j), pi)
assert almost_equal(phase(-1.-1E-300j), -pi)
assert almost_equal(phase(1j), pi/2)
assert almost_equal(phase(-1j), -pi/2)
# zeros
assert phase(complex(0.0, 0.0)) == 0.0
@ -539,7 +538,7 @@ test_isfinite()
@test
def test_isnan():
assert not cmath.isnan(1)
assert not cmath.isnan(1*j)
assert not cmath.isnan(1j)
assert not cmath.isnan(INF)
assert cmath.isnan(NAN)
assert cmath.isnan(complex(NAN, 0))
@ -552,7 +551,7 @@ test_isnan()
@test
def test_isinf():
assert not cmath.isinf(1)
assert not cmath.isinf(1*j)
assert not cmath.isinf(1j)
assert not cmath.isinf(NAN)
assert cmath.isinf(INF)
assert cmath.isinf(complex(INF, 0))
@ -583,10 +582,10 @@ test_atanh_sign()
@test
def test_is_close():
# test complex values that are close to within 12 decimal places
complex_examples = [(1.0+1.0*j, 1.000000000001+1.0*j),
(1.0+1.0*j, 1.0+1.000000000001*j),
(-1.0+1.0*j, -1.000000000001+1.0*j),
(1.0-1.0*j, 1.0-0.999999999999*j),
complex_examples = [(1.0+1.0j, 1.000000000001+1.0j),
(1.0+1.0j, 1.0+1.000000000001j),
(-1.0+1.0j, -1.000000000001+1.0j),
(1.0-1.0j, 1.0-0.999999999999j),
]
for a,b in complex_examples:
@ -594,20 +593,20 @@ def test_is_close():
assert not cmath.isclose(a, b, rel_tol=1e-13)
# test values near zero that are near to within three decimal places
near_zero_examples = [(0.001*j, 0),
(0.001 + 0*j, 0),
(0.001+0.001*j, 0),
(-0.001+0.001*j, 0),
(0.001-0.001*j, 0),
(-0.001-0.001*j, 0),
near_zero_examples = [(0.001j, 0),
(0.001 + 0j, 0),
(0.001+0.001j, 0),
(-0.001+0.001j, 0),
(0.001-0.001j, 0),
(-0.001-0.001j, 0),
]
for a,b in near_zero_examples:
assert cmath.isclose(a, b, abs_tol=1.5e-03)
assert not cmath.isclose(a, b, abs_tol=0.5e-03)
assert cmath.isclose(0.001-0.001*j, 0.001+0.001*j, abs_tol=2e-03)
assert not cmath.isclose(0.001-0.001*j, 0.001+0.001*j, abs_tol=1e-03)
assert cmath.isclose(0.001-0.001j, 0.001+0.001j, abs_tol=2e-03)
assert not cmath.isclose(0.001-0.001j, 0.001+0.001j, abs_tol=1e-03)
test_is_close()
@test