Bugfixes 2023-08 (#440)

* Fix type argument overload issue; Fix Cython version for CI

* Add __contains__ for kwargs

* Add get() for kwargs

* Add static <<, >> and unary ~

* Fix CI

* Fix OpenMP "ordered" clause

* Fix static ~

* Fix Cython 3 issues

* Fix Python MANIFEST.in

---------

Co-authored-by: A. R. Shajii <ars@ars.me>
pull/454/head
Ibrahim Numanagić 2023-08-12 16:39:45 +02:00 committed by GitHub
parent 7198a0971a
commit 750bb28c9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 82 additions and 17 deletions

View File

@ -151,19 +151,35 @@ std::string TypecheckVisitor::generateTuple(size_t len, const std::string &name,
StmtPtr stmt = N<ClassStmt>(ctx->cache->generateSrcInfo(), typeName, args, nullptr,
std::vector<ExprPtr>{N<IdExpr>("tuple")});
// Add getItem for KwArgs:
// Add helpers for KwArgs:
// `def __getitem__(self, key: Static[str]): return getattr(self, key)`
// `def __contains__(self, key: Static[str]): return hasattr(self, key)`
auto getItem = N<FunctionStmt>(
"__getitem__", nullptr,
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
N<IdExpr>("str"))}},
N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("getattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
auto contains = N<FunctionStmt>(
"__contains__", nullptr,
std::vector<Param>{Param{"self"}, Param{"key", N<IndexExpr>(N<IdExpr>("Static"),
N<IdExpr>("str"))}},
N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<IdExpr>("hasattr"), N<IdExpr>("self"), N<IdExpr>("key")))));
auto getDef = N<FunctionStmt>(
"get", nullptr,
std::vector<Param>{
Param{"self"},
Param{"key", N<IndexExpr>(N<IdExpr>("Static"), N<IdExpr>("str"))},
Param{"default", nullptr, N<CallExpr>(N<IdExpr>("NoneType"))}},
N<SuiteStmt>(N<ReturnStmt>(
N<CallExpr>(N<DotExpr>(N<IdExpr>("__internal__"), "kwargs_get"),
N<IdExpr>("self"), N<IdExpr>("key"), N<IdExpr>("default")))));
if (startswith(typeName, TYPE_KWTUPLE))
stmt->getClass()->suite = getItem;
stmt->getClass()->suite = N<SuiteStmt>(getItem, contains, getDef);
// Add getItem for KwArgs:
// `def __repr__(self,): return __magic__.repr_partial(self)`
// Add repr for KwArgs:
// `def __repr__(self): return __magic__.repr_partial(self)`
auto repr = N<FunctionStmt>(
"__repr__", nullptr, std::vector<Param>{Param{"self"}},
N<SuiteStmt>(N<ReturnStmt>(N<CallExpr>(

View File

@ -22,7 +22,8 @@ void TypecheckVisitor::visit(UnaryExpr *expr) {
transform(expr->expr);
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
staticOps = {{StaticValue::INT, {"-", "+", "!"}}, {StaticValue::STRING, {"@"}}};
staticOps = {{StaticValue::INT, {"-", "+", "!", "~"}},
{StaticValue::STRING, {"@"}}};
// Handle static expressions
if (expr->expr->isStatic() && in(staticOps[expr->expr->staticValue.type], expr->op)) {
resultExpr = evaluateStaticUnary(expr);
@ -62,7 +63,7 @@ void TypecheckVisitor::visit(BinaryExpr *expr) {
static std::unordered_map<StaticValue::Type, std::unordered_set<std::string>>
staticOps = {{StaticValue::INT,
{"<", "<=", ">", ">=", "==", "!=", "&&", "||", "+", "-", "*", "//",
"%", "&", "|", "^"}},
"%", "&", "|", "^", ">>", "<<"}},
{StaticValue::STRING, {"==", "!=", "+"}}};
if (expr->lexpr->isStatic() && expr->rexpr->isStatic() &&
expr->lexpr->staticValue.type == expr->rexpr->staticValue.type &&
@ -370,13 +371,15 @@ ExprPtr TypecheckVisitor::evaluateStaticUnary(UnaryExpr *expr) {
}
// Case: static integers
if (expr->op == "-" || expr->op == "+" || expr->op == "!") {
if (expr->op == "-" || expr->op == "+" || expr->op == "!" || expr->op == "~") {
if (expr->expr->staticValue.evaluated) {
int64_t value = expr->expr->staticValue.getInt();
if (expr->op == "+")
;
else if (expr->op == "-")
value = -value;
else if (expr->op == "~")
value = ~value;
else
value = !bool(value);
LOG_TYPECHECK("[cond::un] {}: {}", getSrcInfo(), value);
@ -484,6 +487,10 @@ ExprPtr TypecheckVisitor::evaluateStaticBinary(BinaryExpr *expr) {
lvalue = lvalue & rvalue;
else if (expr->op == "|")
lvalue = lvalue | rvalue;
else if (expr->op == ">>")
lvalue = lvalue >> rvalue;
else if (expr->op == "<<")
lvalue = lvalue << rvalue;
else if (expr->op == "//")
lvalue = divMod(ctx, lvalue, rvalue).first;
else if (expr->op == "%")

View File

@ -253,7 +253,7 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
auto score = ctx->reorderNamedArgs(
fn.get(), args,
[&](int s, int k, const std::vector<std::vector<int>> &slots, bool _) {
for (int si = 0; si < slots.size(); si++) {
for (int si = 0, gi = 0; si < slots.size(); si++) {
if (fn->ast->args[si].status == Param::Generic) {
if (slots[si].empty()) {
// is this "real" type?
@ -263,8 +263,13 @@ int TypecheckVisitor::canCall(const types::FuncTypePtr &fn,
}
reordered.push_back({nullptr, 0});
} else {
seqassert(gi < fn->funcGenerics.size(), "bad fn");
if (!fn->funcGenerics[gi].type->isStaticType() &&
!args[slots[si][0]].value->isType())
return -1;
reordered.push_back({args[slots[si][0]].value->type, slots[si][0]});
}
gi++;
} else if (si == s || si == k || slots[si].size() != 1) {
// Ignore *args, *kwargs and default arguments
reordered.push_back({nullptr, 0});

View File

@ -0,0 +1 @@
include codon/*.pxd

View File

@ -216,7 +216,7 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
file=sys.stderr,
)
return _jit.run_wrapper(
obj_name, types, f.__module__, pyvars, args, 1 if debug else 0
obj_name, list(types), f.__module__, list(pyvars), args, 1 if debug else 0
)
except JITError:
_reset_jit()

View File

@ -65,7 +65,7 @@ else:
jit_extension = Extension(
"codon.codon_jit",
sources=["codon/jit.pyx", "codon/jit.pxd"],
sources=["codon/jit.pyx"],
libraries=libraries,
language="c++",
extra_compile_args=["-w"],

View File

@ -435,6 +435,12 @@ class __internal__:
e.col = col
return e
def kwargs_get(kw, key: Static[str], default):
if hasattr(kw, key):
return getattr(kw, key)
else:
return default
@extend
class __magic__:

View File

@ -136,8 +136,8 @@ def _master_end(loc_ref: Ptr[Ident], gtid: int):
__kmpc_end_master(loc_ref, i32(gtid))
def _ordered_begin(loc_ref: Ptr[Ident], gtid: int):
from C import __kmpc_ordered(Ptr[Ident], i32) -> i32
return int(__kmpc_ordered(loc_ref, i32(gtid)))
from C import __kmpc_ordered(Ptr[Ident], i32)
__kmpc_ordered(loc_ref, i32(gtid))
def _ordered_end(loc_ref: Ptr[Ident], gtid: int):
from C import __kmpc_end_ordered(Ptr[Ident], i32)
@ -781,7 +781,7 @@ def ordered(func):
def _wrapper(*args, **kwargs):
gtid = get_thread_num()
loc = _default_loc()
if _ordered_begin(loc, gtid) != 0:
_ordered_begin(loc, gtid)
try:
func(*args, **kwargs)
finally:

View File

@ -1231,6 +1231,21 @@ def foo(x):
print foo('hi') #: (3, 2)
print foo('hi', 1) #: (2, 'hi_1')
def fox(a: int, b: int, c: int, dtype: type = int):
print('fox 1:', a, b, c)
@overload
def fox(a: int, b: int, dtype: type = int):
print('fox 2:', a, b, dtype.__class__.__name__)
fox(1, 2, float)
#: fox 2: 1 2 float
fox(1, 2)
#: fox 2: 1 2 int
fox(1, 2, 3)
#: fox 1: 1 2 3
#%% fn_shadow,barebones
def foo(x):
return 1, x

View File

@ -889,6 +889,20 @@ def test_omp_collapse():
assert A6 == B6
@test
def test_omp_ordered(N: int = 1000):
@omp.ordered
def f(A, i):
A.append(i)
A = []
@par(schedule='dynamic', chunk_size=1, num_threads=2, ordered=True)
for i in range(N):
f(A, i)
assert A == list(range(N))
test_omp_api()
test_omp_schedules()
test_omp_ranges()
@ -901,3 +915,4 @@ test_omp_transform(111.1, 222.2, 333.3)
test_omp_nested()
test_omp_corner_cases()
test_omp_collapse()
test_omp_ordered()