mirror of https://github.com/exaloop/codon
More engine updates
parent
910c880666
commit
62e5577a1e
|
@ -120,10 +120,13 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
typedef int MainFunc(int, char **);
|
||||
typedef void InputFunc();
|
||||
|
||||
JIT::JIT(ir::Module *module)
|
||||
: module(module), pm(std::make_unique<ir::transform::PassManager>(/*debug=*/true)),
|
||||
plm(std::make_unique<PluginManager>()),
|
||||
llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true)) {
|
||||
llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true, /*jit=*/true)) {
|
||||
if (auto e = Engine::create()) {
|
||||
engine = std::move(e.get());
|
||||
} else {
|
||||
|
@ -133,5 +136,29 @@ JIT::JIT(ir::Module *module)
|
|||
llvisitor->setPluginManager(plm.get());
|
||||
}
|
||||
|
||||
void JIT::init() {
|
||||
module->accept(*llvisitor);
|
||||
auto module = llvisitor->takeModule();
|
||||
llvm::cantFail(
|
||||
engine->addModule({std::move(module), std::make_unique<llvm::LLVMContext>()}));
|
||||
auto func = llvm::cantFail(engine->lookup("main"));
|
||||
auto *main = (MainFunc *)func.getAddress();
|
||||
(*main)(0, nullptr);
|
||||
}
|
||||
|
||||
void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
|
||||
const std::string name = ir::LLVMVisitor::getNameForFunction(input);
|
||||
llvisitor->registerGlobal(input);
|
||||
for (auto *var : newGlobals)
|
||||
llvisitor->registerGlobal(var);
|
||||
input->accept(*llvisitor);
|
||||
auto module = llvisitor->takeModule();
|
||||
llvm::cantFail(
|
||||
engine->addModule({std::move(module), std::make_unique<llvm::LLVMContext>()}));
|
||||
auto func = llvm::cantFail(engine->lookup(name));
|
||||
auto *repl = (InputFunc *)func.getAddress();
|
||||
(*repl)();
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace codon
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "codon/sir/llvm/llvisitor.h"
|
||||
#include "codon/sir/transform/manager.h"
|
||||
#include "codon/sir/var.h"
|
||||
|
||||
namespace codon {
|
||||
namespace jit {
|
||||
|
@ -21,6 +23,8 @@ private:
|
|||
public:
|
||||
JIT(ir::Module *module);
|
||||
ir::Module *getModule() const { return module; }
|
||||
void init();
|
||||
void run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals = {});
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
|
|
@ -13,39 +13,6 @@
|
|||
|
||||
namespace codon {
|
||||
namespace ir {
|
||||
namespace {
|
||||
std::string getNameForFunction(const Func *x) {
|
||||
if (auto *externalFunc = cast<ExternalFunc>(x)) {
|
||||
return x->getUnmangledName();
|
||||
} else {
|
||||
return x->referenceString();
|
||||
}
|
||||
}
|
||||
|
||||
std::string getDebugNameForVariable(const Var *x) {
|
||||
std::string name = x->getName();
|
||||
auto pos = name.find(".");
|
||||
if (pos != 0 && pos != std::string::npos) {
|
||||
return name.substr(0, pos);
|
||||
} else {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
const SrcInfo *getSrcInfo(const Node *x) {
|
||||
if (auto *srcInfo = x->getAttribute<SrcInfoAttribute>()) {
|
||||
return &srcInfo->info;
|
||||
} else {
|
||||
static SrcInfo defaultSrcInfo("<internal>", 0, 0, 0);
|
||||
return &defaultSrcInfo;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value *getDummyVoidValue(llvm::LLVMContext &context) {
|
||||
return llvm::ConstantTokenNone::get(context);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
|
||||
std::string filename;
|
||||
std::string directory;
|
||||
|
@ -60,10 +27,10 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
|
|||
return builder->createFile(filename, directory);
|
||||
}
|
||||
|
||||
LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags)
|
||||
LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
|
||||
: util::ConstVisitor(), context(), builder(context), module(), func(nullptr),
|
||||
block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
|
||||
db(debug, flags), plugins(nullptr) {
|
||||
db(debug, jit, flags), plugins(nullptr) {
|
||||
llvm::InitializeAllTargets();
|
||||
llvm::InitializeAllTargetMCs();
|
||||
llvm::InitializeAllAsmPrinters();
|
||||
|
@ -107,65 +74,102 @@ LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags)
|
|||
llvm::initializeTypePromotionPass(registry);
|
||||
}
|
||||
|
||||
void LLVMVisitor::registerGlobal(const Var *var) {
|
||||
if (!var->isGlobal())
|
||||
return;
|
||||
|
||||
if (auto *f = cast<Func>(var)) {
|
||||
makeLLVMFunction(f);
|
||||
funcs.insert(f, func);
|
||||
} else {
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
if (llvmType->isVoidTy()) {
|
||||
vars.insert(var, getDummyVoidValue());
|
||||
} else {
|
||||
auto *storage = new llvm::GlobalVariable(
|
||||
*module, llvmType, /*isConstant=*/false, llvm::GlobalVariable::PrivateLinkage,
|
||||
llvm::Constant::getNullValue(llvmType), var->getName());
|
||||
vars.insert(var, storage);
|
||||
|
||||
// debug info
|
||||
auto *srcInfo = getSrcInfo(var);
|
||||
llvm::DIFile *file = db.getFile(srcInfo->file);
|
||||
llvm::DIScope *scope = db.unit;
|
||||
llvm::DIGlobalVariableExpression *debugVar =
|
||||
db.builder->createGlobalVariableExpression(
|
||||
scope, getDebugNameForVariable(var), var->getName(), file, srcInfo->line,
|
||||
getDIType(var->getType()),
|
||||
/*IsLocalToUnit=*/true);
|
||||
storage->addDebugInfo(debugVar);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value *LLVMVisitor::getVar(const Var *var) {
|
||||
llvm::Value *val = vars[var];
|
||||
if (!val)
|
||||
return nullptr;
|
||||
if (db.jit && var->isGlobal()) {
|
||||
if (val) {
|
||||
llvm::Module *m = nullptr;
|
||||
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
|
||||
m = x->getModule();
|
||||
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
|
||||
m = x->getParent();
|
||||
|
||||
llvm::Module *m = nullptr;
|
||||
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
|
||||
m = x->getModule();
|
||||
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
|
||||
m = x->getParent();
|
||||
if (m != module.get()) {
|
||||
// see if it's in the module already
|
||||
auto name = var->getName();
|
||||
if (auto *global = module->getNamedValue(name))
|
||||
return global;
|
||||
|
||||
// the following happens when JIT'ing
|
||||
if (var->isGlobal() && m != module.get()) {
|
||||
// see if it's in the module already
|
||||
auto name = var->getName();
|
||||
if (auto *global = module->getNamedValue(name))
|
||||
return global;
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
auto *storage =
|
||||
new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false,
|
||||
llvm::GlobalVariable::ExternalLinkage,
|
||||
/*Initializer=*/nullptr, name);
|
||||
storage->setExternallyInitialized(true);
|
||||
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
auto *storage = new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false,
|
||||
llvm::GlobalVariable::ExternalLinkage,
|
||||
/*Initializer=*/nullptr, name);
|
||||
storage->setExternallyInitialized(true);
|
||||
|
||||
// debug info
|
||||
auto *srcInfo = getSrcInfo(var);
|
||||
llvm::DIFile *file = db.getFile(srcInfo->file);
|
||||
llvm::DIScope *scope = db.unit;
|
||||
llvm::DIGlobalVariableExpression *debugVar =
|
||||
db.builder->createGlobalVariableExpression(scope, getDebugNameForVariable(var),
|
||||
name, file, srcInfo->line,
|
||||
getDIType(var->getType()),
|
||||
/*IsLocalToUnit=*/true);
|
||||
storage->addDebugInfo(debugVar);
|
||||
return storage;
|
||||
// debug info
|
||||
auto *srcInfo = getSrcInfo(var);
|
||||
llvm::DIFile *file = db.getFile(srcInfo->file);
|
||||
llvm::DIScope *scope = db.unit;
|
||||
llvm::DIGlobalVariableExpression *debugVar =
|
||||
db.builder->createGlobalVariableExpression(
|
||||
scope, getDebugNameForVariable(var), name, file, srcInfo->line,
|
||||
getDIType(var->getType()),
|
||||
/*IsLocalToUnit=*/true);
|
||||
storage->addDebugInfo(debugVar);
|
||||
vars.insert(var, storage);
|
||||
return storage;
|
||||
}
|
||||
} else {
|
||||
registerGlobal(var);
|
||||
return vars[var];
|
||||
}
|
||||
}
|
||||
|
||||
// should never have a non-global val from another module
|
||||
return val;
|
||||
}
|
||||
|
||||
llvm::Function *LLVMVisitor::getFunc(const Func *func) {
|
||||
llvm::Function *f = funcs[func];
|
||||
if (!f)
|
||||
return nullptr;
|
||||
if (db.jit) {
|
||||
if (f) {
|
||||
if (f->getParent() != module.get()) {
|
||||
// see if it's in the module already
|
||||
if (auto *g = module->getFunction(f->getName()))
|
||||
return g;
|
||||
|
||||
// the following happens when JIT'ing
|
||||
if (f->getParent() != module.get()) {
|
||||
// see if it's in the module already
|
||||
if (auto *g = module->getFunction(f->getName()))
|
||||
return g;
|
||||
|
||||
auto *g =
|
||||
llvm::Function::Create(f->getFunctionType(), llvm::Function::ExternalLinkage,
|
||||
f->getName(), module.get());
|
||||
g->copyAttributesFrom(f);
|
||||
return g;
|
||||
auto *g = llvm::Function::Create(f->getFunctionType(),
|
||||
llvm::Function::ExternalLinkage, f->getName(),
|
||||
module.get());
|
||||
g->copyAttributesFrom(f);
|
||||
funcs.insert(func, g);
|
||||
return g;
|
||||
}
|
||||
} else {
|
||||
registerGlobal(func);
|
||||
return funcs[func];
|
||||
}
|
||||
}
|
||||
|
||||
return f;
|
||||
}
|
||||
|
||||
|
@ -511,44 +515,12 @@ void LLVMVisitor::visit(const Module *x) {
|
|||
module = makeModule(getSrcInfo(x));
|
||||
|
||||
// args variable
|
||||
const Var *argVar = x->getArgVar();
|
||||
llvm::Type *argVarType = getLLVMType(argVar->getType());
|
||||
auto *argStorage = new llvm::GlobalVariable(
|
||||
*module, argVarType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
|
||||
llvm::Constant::getNullValue(argVarType), argVar->getName());
|
||||
vars.insert(argVar, argStorage);
|
||||
seqassert(x->getArgVar()->isGlobal(), "arg var is not global");
|
||||
registerGlobal(x->getArgVar());
|
||||
|
||||
// set up global variables and initialize functions
|
||||
for (auto *var : *x) {
|
||||
if (!var->isGlobal())
|
||||
continue;
|
||||
|
||||
if (auto *f = cast<Func>(var)) {
|
||||
makeLLVMFunction(f);
|
||||
funcs.insert(f, func);
|
||||
} else {
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
if (llvmType->isVoidTy()) {
|
||||
vars.insert(var, getDummyVoidValue(context));
|
||||
} else {
|
||||
auto *storage = new llvm::GlobalVariable(
|
||||
*module, llvmType, /*isConstant=*/false,
|
||||
llvm::GlobalVariable::PrivateLinkage,
|
||||
llvm::Constant::getNullValue(llvmType), var->getName());
|
||||
vars.insert(var, storage);
|
||||
|
||||
// debug info
|
||||
auto *srcInfo = getSrcInfo(var);
|
||||
llvm::DIFile *file = db.getFile(srcInfo->file);
|
||||
llvm::DIScope *scope = db.unit;
|
||||
llvm::DIGlobalVariableExpression *debugVar =
|
||||
db.builder->createGlobalVariableExpression(
|
||||
scope, getDebugNameForVariable(var), var->getName(), file,
|
||||
srcInfo->line, getDIType(var->getType()),
|
||||
/*IsLocalToUnit=*/true);
|
||||
storage->addDebugInfo(debugVar);
|
||||
}
|
||||
}
|
||||
registerGlobal(var);
|
||||
}
|
||||
|
||||
// process functions
|
||||
|
@ -632,6 +604,8 @@ void LLVMVisitor::visit(const Module *x) {
|
|||
builder.CreateBr(loopBlock);
|
||||
|
||||
builder.SetInsertPoint(exitBlock);
|
||||
llvm::Value *argStorage = vars[x->getArgVar()];
|
||||
seqassert(argStorage, "argument storage missing");
|
||||
builder.CreateStore(arr, argStorage);
|
||||
builder.CreateCall(initFunc, builder.getInt32(db.debug ? 1 : 0));
|
||||
|
||||
|
@ -739,7 +713,7 @@ void LLVMVisitor::makeYield(llvm::Value *value, bool finalYield) {
|
|||
}
|
||||
|
||||
void LLVMVisitor::visit(const ExternalFunc *x) {
|
||||
func = module->getFunction(getNameForFunction(x)); // inserted during module visit
|
||||
func = module->getFunction(getNameForFunction(x));
|
||||
coro = {};
|
||||
seqassert(func, "{} not inserted", *x);
|
||||
func->setDoesNotThrow();
|
||||
|
@ -775,7 +749,7 @@ bool internalFuncMatches(const std::string &name, const InternalFunc *x) {
|
|||
|
||||
void LLVMVisitor::visit(const InternalFunc *x) {
|
||||
using namespace types;
|
||||
func = module->getFunction(getNameForFunction(x)); // inserted during module visit
|
||||
func = module->getFunction(getNameForFunction(x));
|
||||
coro = {};
|
||||
seqassert(func, "{} not inserted", *x);
|
||||
setDebugInfoForNode(x);
|
||||
|
@ -953,7 +927,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
|
|||
}
|
||||
|
||||
void LLVMVisitor::visit(const BodiedFunc *x) {
|
||||
func = module->getFunction(getNameForFunction(x)); // inserted during module visit
|
||||
func = module->getFunction(getNameForFunction(x));
|
||||
coro = {};
|
||||
seqassert(func, "{} not inserted", *x);
|
||||
setDebugInfoForNode(x);
|
||||
|
@ -1009,7 +983,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
|
|||
for (auto *var : *x) {
|
||||
llvm::Type *llvmType = getLLVMType(var->getType());
|
||||
if (llvmType->isVoidTy()) {
|
||||
vars.insert(var, getDummyVoidValue(context));
|
||||
vars.insert(var, getDummyVoidValue());
|
||||
} else {
|
||||
llvm::Value *storage = builder.CreateAlloca(llvmType);
|
||||
vars.insert(var, storage);
|
||||
|
@ -2031,7 +2005,7 @@ void LLVMVisitor::visit(const AssignInstr *x) {
|
|||
llvm::Value *var = getVar(x->getLhs());
|
||||
seqassert(var, "could not find {} var", *x);
|
||||
process(x->getRhs());
|
||||
if (var != getDummyVoidValue(context)) {
|
||||
if (var != getDummyVoidValue()) {
|
||||
builder.SetInsertPoint(block);
|
||||
builder.CreateStore(value, var);
|
||||
}
|
||||
|
|
|
@ -97,11 +97,13 @@ private:
|
|||
llvm::DICompileUnit *unit;
|
||||
/// Whether we are compiling in debug mode
|
||||
bool debug;
|
||||
/// Whether we are compiling in JIT mode
|
||||
bool jit;
|
||||
/// Program command-line flags
|
||||
std::string flags;
|
||||
|
||||
explicit DebugInfo(bool debug, const std::string &flags)
|
||||
: builder(), unit(nullptr), debug(debug), flags(flags) {}
|
||||
DebugInfo(bool debug, bool jit, const std::string &flags)
|
||||
: builder(), unit(nullptr), debug(debug), jit(jit), flags(flags) {}
|
||||
|
||||
llvm::DIFile *getFile(const std::string &path);
|
||||
};
|
||||
|
@ -178,9 +180,37 @@ private:
|
|||
|
||||
llvm::Value *getVar(const Var *var);
|
||||
llvm::Function *getFunc(const Func *func);
|
||||
llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(context); }
|
||||
|
||||
public:
|
||||
LLVMVisitor(bool debug = false, const std::string &flags = "");
|
||||
static std::string getNameForFunction(const Func *x) {
|
||||
if (auto *externalFunc = cast<ExternalFunc>(x)) {
|
||||
return x->getUnmangledName();
|
||||
} else {
|
||||
return x->referenceString();
|
||||
}
|
||||
}
|
||||
|
||||
static std::string getDebugNameForVariable(const Var *x) {
|
||||
std::string name = x->getName();
|
||||
auto pos = name.find(".");
|
||||
if (pos != 0 && pos != std::string::npos) {
|
||||
return name.substr(0, pos);
|
||||
} else {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
static const SrcInfo *getSrcInfo(const Node *x) {
|
||||
if (auto *srcInfo = x->getAttribute<SrcInfoAttribute>()) {
|
||||
return &srcInfo->info;
|
||||
} else {
|
||||
static SrcInfo defaultSrcInfo("<internal>", 0, 0, 0);
|
||||
return &defaultSrcInfo;
|
||||
}
|
||||
}
|
||||
|
||||
LLVMVisitor(bool debug = false, bool jit = false, const std::string &flags = "");
|
||||
|
||||
llvm::LLVMContext &getContext() { return context; }
|
||||
llvm::IRBuilder<> &getBuilder() { return builder; }
|
||||
|
@ -199,6 +229,11 @@ public:
|
|||
void setBlock(llvm::BasicBlock *b) { block = b; }
|
||||
void setValue(llvm::Value *v) { value = v; }
|
||||
|
||||
/// Registers a new global variable or function with
|
||||
/// this visitor.
|
||||
/// @param var the global variable (or function) to register
|
||||
void registerGlobal(const Var *var);
|
||||
|
||||
/// Returns a new LLVM module initialized for the host
|
||||
/// architecture.
|
||||
/// @param src source information for the new module
|
||||
|
|
|
@ -45,7 +45,7 @@ void GlobalDemotionPass::run(Module *M) {
|
|||
}
|
||||
|
||||
for (auto it : localGlobals) {
|
||||
if (!it.second)
|
||||
if (!it.second || it.first->getId() == M->getArgVar()->getId())
|
||||
continue;
|
||||
seqassert(it.first->isGlobal(), "var was not global");
|
||||
it.first->setGlobal(false);
|
||||
|
|
Loading…
Reference in New Issue