Add JIT engine

pull/6/head
A. R. Shajii 2021-10-22 11:01:32 -04:00
parent 681df2069b
commit 910c880666
5 changed files with 285 additions and 27 deletions

View File

@ -80,6 +80,7 @@ add_definitions(${LLVM_DEFINITIONS})
set(CODON_HPPFILES
codon/dsl/dsl.h
codon/dsl/plugins.h
codon/jit/engine.h
codon/parser/ast.h
codon/parser/ast/expr.h
codon/parser/ast/stmt.h
@ -201,6 +202,7 @@ set(CODON_HPPFILES
codon/util/toml++/toml_utf8_streams.h)
set(CODON_CPPFILES
codon/dsl/plugins.cpp
codon/jit/engine.cpp
codon/parser/ast/expr.cpp
codon/parser/ast/stmt.cpp
codon/parser/ast/types.cpp

137
codon/jit/engine.cpp Normal file
View File

@ -0,0 +1,137 @@
#include "engine.h"
#include "codon/sir/llvm/llvm.h"
#include "codon/sir/llvm/memory_manager.h"
#include "codon/sir/llvm/optimize.h"
#include "llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/Orc/TPCIndirectionUtils.h"
#include "llvm/ExecutionEngine/Orc/TargetProcessControl.h"
namespace codon {
namespace jit {
class Engine {
private:
std::unique_ptr<llvm::orc::TargetProcessControl> tpc;
std::unique_ptr<llvm::orc::ExecutionSession> sess;
std::unique_ptr<llvm::orc::TPCIndirectionUtils> tpciu;
llvm::DataLayout layout;
llvm::orc::MangleAndInterner mangle;
llvm::orc::RTDyldObjectLinkingLayer objectLayer;
llvm::orc::IRCompileLayer compileLayer;
llvm::orc::IRTransformLayer optimizeLayer;
llvm::orc::CompileOnDemandLayer codLayer;
llvm::orc::JITDylib &mainJD;
static void handleLazyCallThroughError() {
llvm::errs() << "LazyCallThrough error: Could not find function body";
exit(1);
}
static llvm::Expected<llvm::orc::ThreadSafeModule>
optimizeModule(llvm::orc::ThreadSafeModule module,
const llvm::orc::MaterializationResponsibility &R) {
module.withModuleDo(
[](llvm::Module &module) { ir::optimize(&module, /*debug=*/true); });
return std::move(module);
}
public:
Engine(std::unique_ptr<llvm::orc::TargetProcessControl> tpc,
std::unique_ptr<llvm::orc::ExecutionSession> sess,
std::unique_ptr<llvm::orc::TPCIndirectionUtils> tpciu,
llvm::orc::JITTargetMachineBuilder jtmb, llvm::DataLayout layout)
: tpc(std::move(tpc)), sess(std::move(sess)), tpciu(std::move(tpciu)),
layout(std::move(layout)), mangle(*this->sess, this->layout),
objectLayer(*this->sess,
[]() { return std::make_unique<ir::BoehmGCMemoryManager>(); }),
compileLayer(
*this->sess, objectLayer,
std::make_unique<llvm::orc::ConcurrentIRCompiler>(std::move(jtmb))),
optimizeLayer(*this->sess, compileLayer, optimizeModule),
codLayer(*this->sess, optimizeLayer, this->tpciu->getLazyCallThroughManager(),
[this] { return this->tpciu->createIndirectStubsManager(); }),
mainJD(this->sess->createBareJITDylib("<main>")) {
mainJD.addGenerator(
llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
layout.getGlobalPrefix())));
}
~Engine() {
if (auto err = sess->endSession())
sess->reportError(std::move(err));
if (auto err = tpciu->cleanup())
sess->reportError(std::move(err));
}
static llvm::Expected<std::unique_ptr<Engine>> create() {
auto ssp = std::make_shared<llvm::orc::SymbolStringPool>();
auto tpc = llvm::orc::SelfTargetProcessControl::Create(ssp);
if (!tpc)
return tpc.takeError();
auto sess = std::make_unique<llvm::orc::ExecutionSession>(std::move(ssp));
auto tpciu = llvm::orc::TPCIndirectionUtils::Create(**tpc);
if (!tpciu)
return tpciu.takeError();
(*tpciu)->createLazyCallThroughManager(
*sess, llvm::pointerToJITTargetAddress(&handleLazyCallThroughError));
if (auto err = llvm::orc::setUpInProcessLCTMReentryViaTPCIU(**tpciu))
return std::move(err);
llvm::orc::JITTargetMachineBuilder jtmb((*tpc)->getTargetTriple());
auto layout = jtmb.getDefaultDataLayoutForTarget();
if (!layout)
return layout.takeError();
return std::make_unique<Engine>(std::move(*tpc), std::move(sess), std::move(*tpciu),
std::move(jtmb), std::move(*layout));
}
const llvm::DataLayout &getDataLayout() const { return layout; }
llvm::orc::JITDylib &getMainJITDylib() { return mainJD; }
llvm::Error addModule(llvm::orc::ThreadSafeModule module,
llvm::orc::ResourceTrackerSP rt = nullptr) {
if (!rt)
rt = mainJD.getDefaultResourceTracker();
return optimizeLayer.add(rt, std::move(module));
}
llvm::Expected<llvm::JITEvaluatedSymbol> lookup(llvm::StringRef name) {
return sess->lookup({&mainJD}, mangle(name.str()));
}
};
JIT::JIT(ir::Module *module)
: module(module), pm(std::make_unique<ir::transform::PassManager>(/*debug=*/true)),
plm(std::make_unique<PluginManager>()),
llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true)) {
if (auto e = Engine::create()) {
engine = std::move(e.get());
} else {
engine = {};
seqassert(false, "JIT engine creation error");
}
llvisitor->setPluginManager(plm.get());
}
} // namespace jit
} // namespace codon

27
codon/jit/engine.h Normal file
View File

@ -0,0 +1,27 @@
#pragma once
#include <memory>
#include "codon/sir/llvm/llvisitor.h"
#include "codon/sir/transform/manager.h"
namespace codon {
namespace jit {
class Engine;
class JIT {
private:
ir::Module *module;
std::unique_ptr<ir::transform::PassManager> pm;
std::unique_ptr<PluginManager> plm;
std::unique_ptr<ir::LLVMVisitor> llvisitor;
std::unique_ptr<Engine> engine;
public:
JIT(ir::Module *module);
ir::Module *getModule() const { return module; }
};
} // namespace jit
} // namespace codon

View File

@ -107,6 +107,100 @@ LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags)
llvm::initializeTypePromotionPass(registry);
}
llvm::Value *LLVMVisitor::getVar(const Var *var) {
llvm::Value *val = vars[var];
if (!val)
return nullptr;
llvm::Module *m = nullptr;
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
m = x->getModule();
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
m = x->getParent();
// the following happens when JIT'ing
if (var->isGlobal() && m != module.get()) {
// see if it's in the module already
auto name = var->getName();
if (auto *global = module->getNamedValue(name))
return global;
llvm::Type *llvmType = getLLVMType(var->getType());
auto *storage = new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false,
llvm::GlobalVariable::ExternalLinkage,
/*Initializer=*/nullptr, name);
storage->setExternallyInitialized(true);
// debug info
auto *srcInfo = getSrcInfo(var);
llvm::DIFile *file = db.getFile(srcInfo->file);
llvm::DIScope *scope = db.unit;
llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(scope, getDebugNameForVariable(var),
name, file, srcInfo->line,
getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
return storage;
}
// should never have a non-global val from another module
return val;
}
llvm::Function *LLVMVisitor::getFunc(const Func *func) {
llvm::Function *f = funcs[func];
if (!f)
return nullptr;
// the following happens when JIT'ing
if (f->getParent() != module.get()) {
// see if it's in the module already
if (auto *g = module->getFunction(f->getName()))
return g;
auto *g =
llvm::Function::Create(f->getFunctionType(), llvm::Function::ExternalLinkage,
f->getName(), module.get());
g->copyAttributesFrom(f);
return g;
}
return f;
}
std::unique_ptr<llvm::Module> LLVMVisitor::makeModule(const SrcInfo *src) {
auto module = std::make_unique<llvm::Module>("codon", context);
module->setTargetTriple(
llvm::EngineBuilder().selectTarget()->getTargetTriple().str());
module->setDataLayout(llvm::EngineBuilder().selectTarget()->createDataLayout());
if (src) {
module->setSourceFileName(src->file);
// debug info setup
db.builder = std::make_unique<llvm::DIBuilder>(*module);
llvm::DIFile *file = db.getFile(src->file);
db.unit = db.builder->createCompileUnit(llvm::dwarf::DW_LANG_C, file,
("codon version " CODON_VERSION), !db.debug,
db.flags,
/*RV=*/0);
}
module->addModuleFlag(llvm::Module::Warning, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);
// darwin only supports dwarf2
if (llvm::Triple(module->getTargetTriple()).isOSDarwin()) {
module->addModuleFlag(llvm::Module::Warning, "Dwarf Version", 2);
}
return module;
}
std::unique_ptr<llvm::Module> LLVMVisitor::takeModule(const SrcInfo *src) {
auto currentModule = std::move(module);
module = makeModule(src);
return currentModule;
}
void LLVMVisitor::setDebugInfoForNode(const Node *x) {
if (x && func) {
auto *srcInfo = getSrcInfo(x);
@ -414,26 +508,7 @@ LLVMVisitor::TryCatchData *LLVMVisitor::getInnermostTryCatchBeforeLoop() {
*/
void LLVMVisitor::visit(const Module *x) {
module = std::make_unique<llvm::Module>("codon", context);
module->setTargetTriple(
llvm::EngineBuilder().selectTarget()->getTargetTriple().str());
module->setDataLayout(llvm::EngineBuilder().selectTarget()->createDataLayout());
auto *srcInfo = getSrcInfo(x->getMainFunc());
module->setSourceFileName(srcInfo->file);
// debug info setup
db.builder = std::make_unique<llvm::DIBuilder>(*module);
llvm::DIFile *file = db.getFile(srcInfo->file);
db.unit = db.builder->createCompileUnit(llvm::dwarf::DW_LANG_C, file,
("codon version " CODON_VERSION), !db.debug,
db.flags,
/*RV=*/0);
module->addModuleFlag(llvm::Module::Warning, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);
// darwin only supports dwarf2
if (llvm::Triple(module->getTargetTriple()).isOSDarwin()) {
module->addModuleFlag(llvm::Module::Warning, "Dwarf Version", 2);
}
module = makeModule(getSrcInfo(x));
// args variable
const Var *argVar = x->getArgVar();
@ -1057,10 +1132,10 @@ void LLVMVisitor::visit(const Var *x) { seqassert(0, "cannot visit var"); }
void LLVMVisitor::visit(const VarValue *x) {
if (auto *f = cast<Func>(x->getVar())) {
value = funcs[f];
value = getFunc(f);
seqassert(value, "{} value not found", *x);
} else {
llvm::Value *varPtr = vars[x->getVar()];
llvm::Value *varPtr = getVar(x->getVar());
seqassert(varPtr, "{} value not found", *x);
builder.SetInsertPoint(block);
value = builder.CreateLoad(varPtr);
@ -1068,7 +1143,7 @@ void LLVMVisitor::visit(const VarValue *x) {
}
void LLVMVisitor::visit(const PointerValue *x) {
llvm::Value *var = vars[x->getVar()];
llvm::Value *var = getVar(x->getVar());
seqassert(var, "{} variable not found", *x);
value = var; // note: we don't load the pointer
}
@ -1415,7 +1490,7 @@ void LLVMVisitor::visit(const WhileFlow *x) {
void LLVMVisitor::visit(const ForFlow *x) {
seqassert(!x->isParallel(), "parallel for-loop not lowered");
llvm::Type *loopVarType = getLLVMType(x->getVar()->getType());
llvm::Value *loopVar = vars[x->getVar()];
llvm::Value *loopVar = getVar(x->getVar());
seqassert(loopVar, "{} loop variable not found", *x);
auto *condBlock = llvm::BasicBlock::Create(context, "for.cond", func);
@ -1473,7 +1548,7 @@ void LLVMVisitor::visit(const ForFlow *x) {
void LLVMVisitor::visit(const ImperativeForFlow *x) {
seqassert(!x->isParallel(), "parallel for-loop not lowered");
llvm::Value *loopVar = vars[x->getVar()];
llvm::Value *loopVar = getVar(x->getVar());
seqassert(loopVar, "{} loop variable not found", *x);
seqassert(x->getStep() != 0, "step cannot be 0");
@ -1828,7 +1903,7 @@ void LLVMVisitor::visit(const TryCatchFlow *x) {
if (var) {
llvm::Value *obj =
builder.CreateBitCast(objPtr, getLLVMType(catches[i]->getType()));
llvm::Value *varPtr = vars[var];
llvm::Value *varPtr = getVar(var);
seqassert(varPtr, "could not get catch var");
builder.CreateStore(obj, varPtr);
}
@ -1953,7 +2028,7 @@ void LLVMVisitor::visit(const dsl::CustomFlow *x) {
*/
void LLVMVisitor::visit(const AssignInstr *x) {
llvm::Value *var = vars[x->getLhs()];
llvm::Value *var = getVar(x->getLhs());
seqassert(var, "could not find {} var", *x);
process(x->getRhs());
if (var != getDummyVoidValue(context)) {

View File

@ -3,6 +3,7 @@
#include "codon/dsl/plugins.h"
#include "codon/sir/llvm/llvm.h"
#include "codon/sir/sir.h"
#include "codon/util/common.h"
#include <string>
#include <unordered_map>
@ -175,6 +176,9 @@ private:
// LLVM passes
void runLLVMPipeline();
llvm::Value *getVar(const Var *var);
llvm::Function *getFunc(const Func *func);
public:
LLVMVisitor(bool debug = false, const std::string &flags = "");
@ -195,6 +199,19 @@ public:
void setBlock(llvm::BasicBlock *b) { block = b; }
void setValue(llvm::Value *v) { value = v; }
/// Returns a new LLVM module initialized for the host
/// architecture.
/// @param src source information for the new module
/// @return a new module
std::unique_ptr<llvm::Module> makeModule(const SrcInfo *src = nullptr);
/// Returns the current LLVM module and replaces it with a
/// new, fresh one. References to variables or functions
/// from the old module will be included as "external".
/// @param src source information for the new module
/// @return the current module, replaced internally
std::unique_ptr<llvm::Module> takeModule(const SrcInfo *src = nullptr);
/// Sets current debug info based on a given node.
/// @param node the node whose debug info to use
void setDebugInfoForNode(const Node *node);