diff --git a/codon/parser/visitors/translate/translate.cpp b/codon/parser/visitors/translate/translate.cpp index 13a89dfa..492211d6 100644 --- a/codon/parser/visitors/translate/translate.cpp +++ b/codon/parser/visitors/translate/translate.cpp @@ -411,26 +411,18 @@ void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *as ir::Func *func) { std::vector names; std::vector indices; - std::vector srcInfos; - std::vector types; for (int i = 0, j = 1; i < ast->args.size(); i++) if (!ast->args[i].generic) { if (!type->args[j]->getFunc()) { - types.push_back(getType(type->args[j])); names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]); indices.push_back(i); } j++; } if (ast->hasAttr(Attr::CVarArg)) { - types.pop_back(); names.pop_back(); indices.pop_back(); } - auto irType = ctx->getModule()->unsafeGetFuncType( - type->realizedName(), getType(type->args[0]), types, ast->hasAttr(Attr::CVarArg)); - irType->setAstType(type->getFunc()); - func->realize(irType, names); // TODO: refactor IR attribute API std::map attr; attr[".module"] = ast->attributes.module; @@ -461,20 +453,14 @@ void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *as void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt *ast, ir::Func *func) { std::vector names; - std::vector types; std::vector indices; for (int i = 0, j = 1; i < ast->args.size(); i++) if (!ast->args[i].generic) { - types.push_back(getType(type->args[j])); names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]); indices.push_back(i); j++; } - auto irType = ctx->getModule()->unsafeGetFuncType(type->realizedName(), - getType(type->args[0]), types); - irType->setAstType(type->getFunc()); auto f = cast(func); - f->realize(irType, names); // TODO: refactor IR attribute API std::map attr; attr[".module"] = ast->attributes.module; diff --git a/codon/parser/visitors/typecheck/typecheck_infer.cpp b/codon/parser/visitors/typecheck/typecheck_infer.cpp index 7347c4a0..329b51e2 100644 --- a/codon/parser/visitors/typecheck/typecheck_infer.cpp +++ b/codon/parser/visitors/typecheck/typecheck_infer.cpp @@ -232,6 +232,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { r->ir->setParentType(getLLVMType(parent->getClass().get())); } r->ir->setGlobal(); + ctx->cache->pendingRealizations.insert({type->ast->name, type->realizedName()}); seqassert(!type || ast->args.size() == @@ -245,6 +246,29 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) { } r->ast = Nx(ast, type->realizedName(), nullptr, args, realized, ast->attributes); + + // Set up IR node + std::vector names; + std::vector types; + for (int i = 0, j = 1; i < r->ast->args.size(); i++) + if (!r->ast->args[i].generic) { + if (!type->args[j]->getFunc()) { + types.push_back(getLLVMType(type->args[j]->getClass().get())); + names.push_back( + ctx->cache->reverseIdentifierLookup[r->ast->args[i].name]); + } + j++; + } + if (r->ast->hasAttr(Attr::CVarArg)) { + types.pop_back(); + names.pop_back(); + } + auto irType = ctx->cache->module->unsafeGetFuncType( + type->realizedName(), getLLVMType(type->args[0]->getClass().get()), types, + r->ast->hasAttr(Attr::CVarArg)); + irType->setAstType(type->getFunc()); + r->ir->realize(irType, names); + ctx->cache->functions[type->ast->name].realizations[type->realizedName()] = r; } else { ctx->cache->functions[type->ast->name].realizations[oldKey] = diff --git a/test/parser/simplify_expr.codon b/test/parser/simplify_expr.codon index 1086055a..6be91e16 100644 --- a/test/parser/simplify_expr.codon +++ b/test/parser/simplify_expr.codon @@ -371,7 +371,7 @@ if (x := foo(4)) and False: print x #: 16 a = [y := foo(1), y+1, y+2] -print a, y #: [1, 2, 3] 1 +print a #: [1, 2, 3] print {y: b for y in [1,2,3] if (b := (y - 1))} #: {2: 1, 3: 2} print list(b for y in [1,2,3] if (b := (y // 3))) #: [1]