Fix pipeline JIT bug

pull/7/head
Ibrahim Numanagić 2021-12-10 10:46:48 -08:00
parent 2c72eda249
commit 5834228f3d
3 changed files with 25 additions and 15 deletions

View File

@ -411,26 +411,18 @@ void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *as
ir::Func *func) {
std::vector<std::string> names;
std::vector<int> indices;
std::vector<SrcInfo> srcInfos;
std::vector<codon::ir::types::Type *> types;
for (int i = 0, j = 1; i < ast->args.size(); i++)
if (!ast->args[i].generic) {
if (!type->args[j]->getFunc()) {
types.push_back(getType(type->args[j]));
names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]);
indices.push_back(i);
}
j++;
}
if (ast->hasAttr(Attr::CVarArg)) {
types.pop_back();
names.pop_back();
indices.pop_back();
}
auto irType = ctx->getModule()->unsafeGetFuncType(
type->realizedName(), getType(type->args[0]), types, ast->hasAttr(Attr::CVarArg));
irType->setAstType(type->getFunc());
func->realize(irType, names);
// TODO: refactor IR attribute API
std::map<std::string, std::string> attr;
attr[".module"] = ast->attributes.module;
@ -461,20 +453,14 @@ void TranslateVisitor::transformFunction(types::FuncType *type, FunctionStmt *as
void TranslateVisitor::transformLLVMFunction(types::FuncType *type, FunctionStmt *ast,
ir::Func *func) {
std::vector<std::string> names;
std::vector<codon::ir::types::Type *> types;
std::vector<int> indices;
for (int i = 0, j = 1; i < ast->args.size(); i++)
if (!ast->args[i].generic) {
types.push_back(getType(type->args[j]));
names.push_back(ctx->cache->reverseIdentifierLookup[ast->args[i].name]);
indices.push_back(i);
j++;
}
auto irType = ctx->getModule()->unsafeGetFuncType(type->realizedName(),
getType(type->args[0]), types);
irType->setAstType(type->getFunc());
auto f = cast<ir::LLVMFunc>(func);
f->realize(irType, names);
// TODO: refactor IR attribute API
std::map<std::string, std::string> attr;
attr[".module"] = ast->attributes.module;

View File

@ -232,6 +232,7 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
r->ir->setParentType(getLLVMType(parent->getClass().get()));
}
r->ir->setGlobal();
ctx->cache->pendingRealizations.insert({type->ast->name, type->realizedName()});
seqassert(!type || ast->args.size() ==
@ -245,6 +246,29 @@ types::TypePtr TypecheckVisitor::realizeFunc(types::FuncType *type) {
}
r->ast = Nx<FunctionStmt>(ast, type->realizedName(), nullptr, args, realized,
ast->attributes);
// Set up IR node
std::vector<std::string> names;
std::vector<codon::ir::types::Type *> types;
for (int i = 0, j = 1; i < r->ast->args.size(); i++)
if (!r->ast->args[i].generic) {
if (!type->args[j]->getFunc()) {
types.push_back(getLLVMType(type->args[j]->getClass().get()));
names.push_back(
ctx->cache->reverseIdentifierLookup[r->ast->args[i].name]);
}
j++;
}
if (r->ast->hasAttr(Attr::CVarArg)) {
types.pop_back();
names.pop_back();
}
auto irType = ctx->cache->module->unsafeGetFuncType(
type->realizedName(), getLLVMType(type->args[0]->getClass().get()), types,
r->ast->hasAttr(Attr::CVarArg));
irType->setAstType(type->getFunc());
r->ir->realize(irType, names);
ctx->cache->functions[type->ast->name].realizations[type->realizedName()] = r;
} else {
ctx->cache->functions[type->ast->name].realizations[oldKey] =

View File

@ -371,7 +371,7 @@ if (x := foo(4)) and False:
print x #: 16
a = [y := foo(1), y+1, y+2]
print a, y #: [1, 2, 3] 1
print a #: [1, 2, 3]
print {y: b for y in [1,2,3] if (b := (y - 1))} #: {2: 1, 3: 2}
print list(b for y in [1,2,3] if (b := (y // 3))) #: [1]