From 62e5577a1e2304a24d156c3ad79c75aa95db09cc Mon Sep 17 00:00:00 2001 From: "A. R. Shajii" Date: Fri, 22 Oct 2021 17:57:23 -0400 Subject: [PATCH] More engine updates --- codon/jit/engine.cpp | 29 ++- codon/jit/engine.h | 4 + codon/sir/llvm/llvisitor.cpp | 216 ++++++++---------- codon/sir/llvm/llvisitor.h | 41 +++- codon/sir/transform/cleanup/global_demote.cpp | 2 +- 5 files changed, 166 insertions(+), 126 deletions(-) diff --git a/codon/jit/engine.cpp b/codon/jit/engine.cpp index 2f800233..613e4935 100644 --- a/codon/jit/engine.cpp +++ b/codon/jit/engine.cpp @@ -120,10 +120,13 @@ public: } }; +typedef int MainFunc(int, char **); +typedef void InputFunc(); + JIT::JIT(ir::Module *module) : module(module), pm(std::make_unique(/*debug=*/true)), plm(std::make_unique()), - llvisitor(std::make_unique(/*debug=*/true)) { + llvisitor(std::make_unique(/*debug=*/true, /*jit=*/true)) { if (auto e = Engine::create()) { engine = std::move(e.get()); } else { @@ -133,5 +136,29 @@ JIT::JIT(ir::Module *module) llvisitor->setPluginManager(plm.get()); } +void JIT::init() { + module->accept(*llvisitor); + auto module = llvisitor->takeModule(); + llvm::cantFail( + engine->addModule({std::move(module), std::make_unique()})); + auto func = llvm::cantFail(engine->lookup("main")); + auto *main = (MainFunc *)func.getAddress(); + (*main)(0, nullptr); +} + +void JIT::run(const ir::Func *input, const std::vector &newGlobals) { + const std::string name = ir::LLVMVisitor::getNameForFunction(input); + llvisitor->registerGlobal(input); + for (auto *var : newGlobals) + llvisitor->registerGlobal(var); + input->accept(*llvisitor); + auto module = llvisitor->takeModule(); + llvm::cantFail( + engine->addModule({std::move(module), std::make_unique()})); + auto func = llvm::cantFail(engine->lookup(name)); + auto *repl = (InputFunc *)func.getAddress(); + (*repl)(); +} + } // namespace jit } // namespace codon diff --git a/codon/jit/engine.h b/codon/jit/engine.h index bda58e93..f180ce7b 100644 --- a/codon/jit/engine.h +++ b/codon/jit/engine.h @@ -1,9 +1,11 @@ #pragma once #include +#include #include "codon/sir/llvm/llvisitor.h" #include "codon/sir/transform/manager.h" +#include "codon/sir/var.h" namespace codon { namespace jit { @@ -21,6 +23,8 @@ private: public: JIT(ir::Module *module); ir::Module *getModule() const { return module; } + void init(); + void run(const ir::Func *input, const std::vector &newGlobals = {}); }; } // namespace jit diff --git a/codon/sir/llvm/llvisitor.cpp b/codon/sir/llvm/llvisitor.cpp index a358bc66..ab442123 100644 --- a/codon/sir/llvm/llvisitor.cpp +++ b/codon/sir/llvm/llvisitor.cpp @@ -13,39 +13,6 @@ namespace codon { namespace ir { -namespace { -std::string getNameForFunction(const Func *x) { - if (auto *externalFunc = cast(x)) { - return x->getUnmangledName(); - } else { - return x->referenceString(); - } -} - -std::string getDebugNameForVariable(const Var *x) { - std::string name = x->getName(); - auto pos = name.find("."); - if (pos != 0 && pos != std::string::npos) { - return name.substr(0, pos); - } else { - return name; - } -} - -const SrcInfo *getSrcInfo(const Node *x) { - if (auto *srcInfo = x->getAttribute()) { - return &srcInfo->info; - } else { - static SrcInfo defaultSrcInfo("", 0, 0, 0); - return &defaultSrcInfo; - } -} - -llvm::Value *getDummyVoidValue(llvm::LLVMContext &context) { - return llvm::ConstantTokenNone::get(context); -} -} // namespace - llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { std::string filename; std::string directory; @@ -60,10 +27,10 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) { return builder->createFile(filename, directory); } -LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags) +LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags) : util::ConstVisitor(), context(), builder(context), module(), func(nullptr), block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(), - db(debug, flags), plugins(nullptr) { + db(debug, jit, flags), plugins(nullptr) { llvm::InitializeAllTargets(); llvm::InitializeAllTargetMCs(); llvm::InitializeAllAsmPrinters(); @@ -107,65 +74,102 @@ LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags) llvm::initializeTypePromotionPass(registry); } +void LLVMVisitor::registerGlobal(const Var *var) { + if (!var->isGlobal()) + return; + + if (auto *f = cast(var)) { + makeLLVMFunction(f); + funcs.insert(f, func); + } else { + llvm::Type *llvmType = getLLVMType(var->getType()); + if (llvmType->isVoidTy()) { + vars.insert(var, getDummyVoidValue()); + } else { + auto *storage = new llvm::GlobalVariable( + *module, llvmType, /*isConstant=*/false, llvm::GlobalVariable::PrivateLinkage, + llvm::Constant::getNullValue(llvmType), var->getName()); + vars.insert(var, storage); + + // debug info + auto *srcInfo = getSrcInfo(var); + llvm::DIFile *file = db.getFile(srcInfo->file); + llvm::DIScope *scope = db.unit; + llvm::DIGlobalVariableExpression *debugVar = + db.builder->createGlobalVariableExpression( + scope, getDebugNameForVariable(var), var->getName(), file, srcInfo->line, + getDIType(var->getType()), + /*IsLocalToUnit=*/true); + storage->addDebugInfo(debugVar); + } + } +} + llvm::Value *LLVMVisitor::getVar(const Var *var) { llvm::Value *val = vars[var]; - if (!val) - return nullptr; + if (db.jit && var->isGlobal()) { + if (val) { + llvm::Module *m = nullptr; + if (auto *x = llvm::dyn_cast(val)) + m = x->getModule(); + else if (auto *x = llvm::dyn_cast(val)) + m = x->getParent(); - llvm::Module *m = nullptr; - if (auto *x = llvm::dyn_cast(val)) - m = x->getModule(); - else if (auto *x = llvm::dyn_cast(val)) - m = x->getParent(); + if (m != module.get()) { + // see if it's in the module already + auto name = var->getName(); + if (auto *global = module->getNamedValue(name)) + return global; - // the following happens when JIT'ing - if (var->isGlobal() && m != module.get()) { - // see if it's in the module already - auto name = var->getName(); - if (auto *global = module->getNamedValue(name)) - return global; + llvm::Type *llvmType = getLLVMType(var->getType()); + auto *storage = + new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false, + llvm::GlobalVariable::ExternalLinkage, + /*Initializer=*/nullptr, name); + storage->setExternallyInitialized(true); - llvm::Type *llvmType = getLLVMType(var->getType()); - auto *storage = new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false, - llvm::GlobalVariable::ExternalLinkage, - /*Initializer=*/nullptr, name); - storage->setExternallyInitialized(true); - - // debug info - auto *srcInfo = getSrcInfo(var); - llvm::DIFile *file = db.getFile(srcInfo->file); - llvm::DIScope *scope = db.unit; - llvm::DIGlobalVariableExpression *debugVar = - db.builder->createGlobalVariableExpression(scope, getDebugNameForVariable(var), - name, file, srcInfo->line, - getDIType(var->getType()), - /*IsLocalToUnit=*/true); - storage->addDebugInfo(debugVar); - return storage; + // debug info + auto *srcInfo = getSrcInfo(var); + llvm::DIFile *file = db.getFile(srcInfo->file); + llvm::DIScope *scope = db.unit; + llvm::DIGlobalVariableExpression *debugVar = + db.builder->createGlobalVariableExpression( + scope, getDebugNameForVariable(var), name, file, srcInfo->line, + getDIType(var->getType()), + /*IsLocalToUnit=*/true); + storage->addDebugInfo(debugVar); + vars.insert(var, storage); + return storage; + } + } else { + registerGlobal(var); + return vars[var]; + } } - - // should never have a non-global val from another module return val; } llvm::Function *LLVMVisitor::getFunc(const Func *func) { llvm::Function *f = funcs[func]; - if (!f) - return nullptr; + if (db.jit) { + if (f) { + if (f->getParent() != module.get()) { + // see if it's in the module already + if (auto *g = module->getFunction(f->getName())) + return g; - // the following happens when JIT'ing - if (f->getParent() != module.get()) { - // see if it's in the module already - if (auto *g = module->getFunction(f->getName())) - return g; - - auto *g = - llvm::Function::Create(f->getFunctionType(), llvm::Function::ExternalLinkage, - f->getName(), module.get()); - g->copyAttributesFrom(f); - return g; + auto *g = llvm::Function::Create(f->getFunctionType(), + llvm::Function::ExternalLinkage, f->getName(), + module.get()); + g->copyAttributesFrom(f); + funcs.insert(func, g); + return g; + } + } else { + registerGlobal(func); + return funcs[func]; + } } - return f; } @@ -511,44 +515,12 @@ void LLVMVisitor::visit(const Module *x) { module = makeModule(getSrcInfo(x)); // args variable - const Var *argVar = x->getArgVar(); - llvm::Type *argVarType = getLLVMType(argVar->getType()); - auto *argStorage = new llvm::GlobalVariable( - *module, argVarType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage, - llvm::Constant::getNullValue(argVarType), argVar->getName()); - vars.insert(argVar, argStorage); + seqassert(x->getArgVar()->isGlobal(), "arg var is not global"); + registerGlobal(x->getArgVar()); // set up global variables and initialize functions for (auto *var : *x) { - if (!var->isGlobal()) - continue; - - if (auto *f = cast(var)) { - makeLLVMFunction(f); - funcs.insert(f, func); - } else { - llvm::Type *llvmType = getLLVMType(var->getType()); - if (llvmType->isVoidTy()) { - vars.insert(var, getDummyVoidValue(context)); - } else { - auto *storage = new llvm::GlobalVariable( - *module, llvmType, /*isConstant=*/false, - llvm::GlobalVariable::PrivateLinkage, - llvm::Constant::getNullValue(llvmType), var->getName()); - vars.insert(var, storage); - - // debug info - auto *srcInfo = getSrcInfo(var); - llvm::DIFile *file = db.getFile(srcInfo->file); - llvm::DIScope *scope = db.unit; - llvm::DIGlobalVariableExpression *debugVar = - db.builder->createGlobalVariableExpression( - scope, getDebugNameForVariable(var), var->getName(), file, - srcInfo->line, getDIType(var->getType()), - /*IsLocalToUnit=*/true); - storage->addDebugInfo(debugVar); - } - } + registerGlobal(var); } // process functions @@ -632,6 +604,8 @@ void LLVMVisitor::visit(const Module *x) { builder.CreateBr(loopBlock); builder.SetInsertPoint(exitBlock); + llvm::Value *argStorage = vars[x->getArgVar()]; + seqassert(argStorage, "argument storage missing"); builder.CreateStore(arr, argStorage); builder.CreateCall(initFunc, builder.getInt32(db.debug ? 1 : 0)); @@ -739,7 +713,7 @@ void LLVMVisitor::makeYield(llvm::Value *value, bool finalYield) { } void LLVMVisitor::visit(const ExternalFunc *x) { - func = module->getFunction(getNameForFunction(x)); // inserted during module visit + func = module->getFunction(getNameForFunction(x)); coro = {}; seqassert(func, "{} not inserted", *x); func->setDoesNotThrow(); @@ -775,7 +749,7 @@ bool internalFuncMatches(const std::string &name, const InternalFunc *x) { void LLVMVisitor::visit(const InternalFunc *x) { using namespace types; - func = module->getFunction(getNameForFunction(x)); // inserted during module visit + func = module->getFunction(getNameForFunction(x)); coro = {}; seqassert(func, "{} not inserted", *x); setDebugInfoForNode(x); @@ -953,7 +927,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) { } void LLVMVisitor::visit(const BodiedFunc *x) { - func = module->getFunction(getNameForFunction(x)); // inserted during module visit + func = module->getFunction(getNameForFunction(x)); coro = {}; seqassert(func, "{} not inserted", *x); setDebugInfoForNode(x); @@ -1009,7 +983,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) { for (auto *var : *x) { llvm::Type *llvmType = getLLVMType(var->getType()); if (llvmType->isVoidTy()) { - vars.insert(var, getDummyVoidValue(context)); + vars.insert(var, getDummyVoidValue()); } else { llvm::Value *storage = builder.CreateAlloca(llvmType); vars.insert(var, storage); @@ -2031,7 +2005,7 @@ void LLVMVisitor::visit(const AssignInstr *x) { llvm::Value *var = getVar(x->getLhs()); seqassert(var, "could not find {} var", *x); process(x->getRhs()); - if (var != getDummyVoidValue(context)) { + if (var != getDummyVoidValue()) { builder.SetInsertPoint(block); builder.CreateStore(value, var); } diff --git a/codon/sir/llvm/llvisitor.h b/codon/sir/llvm/llvisitor.h index 82fcdfd4..74f59a87 100644 --- a/codon/sir/llvm/llvisitor.h +++ b/codon/sir/llvm/llvisitor.h @@ -97,11 +97,13 @@ private: llvm::DICompileUnit *unit; /// Whether we are compiling in debug mode bool debug; + /// Whether we are compiling in JIT mode + bool jit; /// Program command-line flags std::string flags; - explicit DebugInfo(bool debug, const std::string &flags) - : builder(), unit(nullptr), debug(debug), flags(flags) {} + DebugInfo(bool debug, bool jit, const std::string &flags) + : builder(), unit(nullptr), debug(debug), jit(jit), flags(flags) {} llvm::DIFile *getFile(const std::string &path); }; @@ -178,9 +180,37 @@ private: llvm::Value *getVar(const Var *var); llvm::Function *getFunc(const Func *func); + llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(context); } public: - LLVMVisitor(bool debug = false, const std::string &flags = ""); + static std::string getNameForFunction(const Func *x) { + if (auto *externalFunc = cast(x)) { + return x->getUnmangledName(); + } else { + return x->referenceString(); + } + } + + static std::string getDebugNameForVariable(const Var *x) { + std::string name = x->getName(); + auto pos = name.find("."); + if (pos != 0 && pos != std::string::npos) { + return name.substr(0, pos); + } else { + return name; + } + } + + static const SrcInfo *getSrcInfo(const Node *x) { + if (auto *srcInfo = x->getAttribute()) { + return &srcInfo->info; + } else { + static SrcInfo defaultSrcInfo("", 0, 0, 0); + return &defaultSrcInfo; + } + } + + LLVMVisitor(bool debug = false, bool jit = false, const std::string &flags = ""); llvm::LLVMContext &getContext() { return context; } llvm::IRBuilder<> &getBuilder() { return builder; } @@ -199,6 +229,11 @@ public: void setBlock(llvm::BasicBlock *b) { block = b; } void setValue(llvm::Value *v) { value = v; } + /// Registers a new global variable or function with + /// this visitor. + /// @param var the global variable (or function) to register + void registerGlobal(const Var *var); + /// Returns a new LLVM module initialized for the host /// architecture. /// @param src source information for the new module diff --git a/codon/sir/transform/cleanup/global_demote.cpp b/codon/sir/transform/cleanup/global_demote.cpp index f8270e0f..90fb44b2 100644 --- a/codon/sir/transform/cleanup/global_demote.cpp +++ b/codon/sir/transform/cleanup/global_demote.cpp @@ -45,7 +45,7 @@ void GlobalDemotionPass::run(Module *M) { } for (auto it : localGlobals) { - if (!it.second) + if (!it.second || it.first->getId() == M->getArgVar()->getId()) continue; seqassert(it.first->isGlobal(), "var was not global"); it.first->setGlobal(false);