More engine updates

pull/6/head
A. R. Shajii 2021-10-22 17:57:23 -04:00
parent 910c880666
commit 62e5577a1e
5 changed files with 166 additions and 126 deletions

View File

@ -120,10 +120,13 @@ public:
}
};
typedef int MainFunc(int, char **);
typedef void InputFunc();
JIT::JIT(ir::Module *module)
: module(module), pm(std::make_unique<ir::transform::PassManager>(/*debug=*/true)),
plm(std::make_unique<PluginManager>()),
llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true)) {
llvisitor(std::make_unique<ir::LLVMVisitor>(/*debug=*/true, /*jit=*/true)) {
if (auto e = Engine::create()) {
engine = std::move(e.get());
} else {
@ -133,5 +136,29 @@ JIT::JIT(ir::Module *module)
llvisitor->setPluginManager(plm.get());
}
void JIT::init() {
module->accept(*llvisitor);
auto module = llvisitor->takeModule();
llvm::cantFail(
engine->addModule({std::move(module), std::make_unique<llvm::LLVMContext>()}));
auto func = llvm::cantFail(engine->lookup("main"));
auto *main = (MainFunc *)func.getAddress();
(*main)(0, nullptr);
}
void JIT::run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals) {
const std::string name = ir::LLVMVisitor::getNameForFunction(input);
llvisitor->registerGlobal(input);
for (auto *var : newGlobals)
llvisitor->registerGlobal(var);
input->accept(*llvisitor);
auto module = llvisitor->takeModule();
llvm::cantFail(
engine->addModule({std::move(module), std::make_unique<llvm::LLVMContext>()}));
auto func = llvm::cantFail(engine->lookup(name));
auto *repl = (InputFunc *)func.getAddress();
(*repl)();
}
} // namespace jit
} // namespace codon

View File

@ -1,9 +1,11 @@
#pragma once
#include <memory>
#include <vector>
#include "codon/sir/llvm/llvisitor.h"
#include "codon/sir/transform/manager.h"
#include "codon/sir/var.h"
namespace codon {
namespace jit {
@ -21,6 +23,8 @@ private:
public:
JIT(ir::Module *module);
ir::Module *getModule() const { return module; }
void init();
void run(const ir::Func *input, const std::vector<ir::Var *> &newGlobals = {});
};
} // namespace jit

View File

@ -13,39 +13,6 @@
namespace codon {
namespace ir {
namespace {
std::string getNameForFunction(const Func *x) {
if (auto *externalFunc = cast<ExternalFunc>(x)) {
return x->getUnmangledName();
} else {
return x->referenceString();
}
}
std::string getDebugNameForVariable(const Var *x) {
std::string name = x->getName();
auto pos = name.find(".");
if (pos != 0 && pos != std::string::npos) {
return name.substr(0, pos);
} else {
return name;
}
}
const SrcInfo *getSrcInfo(const Node *x) {
if (auto *srcInfo = x->getAttribute<SrcInfoAttribute>()) {
return &srcInfo->info;
} else {
static SrcInfo defaultSrcInfo("<internal>", 0, 0, 0);
return &defaultSrcInfo;
}
}
llvm::Value *getDummyVoidValue(llvm::LLVMContext &context) {
return llvm::ConstantTokenNone::get(context);
}
} // namespace
llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
std::string filename;
std::string directory;
@ -60,10 +27,10 @@ llvm::DIFile *LLVMVisitor::DebugInfo::getFile(const std::string &path) {
return builder->createFile(filename, directory);
}
LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags)
LLVMVisitor::LLVMVisitor(bool debug, bool jit, const std::string &flags)
: util::ConstVisitor(), context(), builder(context), module(), func(nullptr),
block(nullptr), value(nullptr), vars(), funcs(), coro(), loops(), trycatch(),
db(debug, flags), plugins(nullptr) {
db(debug, jit, flags), plugins(nullptr) {
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmPrinters();
@ -107,65 +74,102 @@ LLVMVisitor::LLVMVisitor(bool debug, const std::string &flags)
llvm::initializeTypePromotionPass(registry);
}
void LLVMVisitor::registerGlobal(const Var *var) {
if (!var->isGlobal())
return;
if (auto *f = cast<Func>(var)) {
makeLLVMFunction(f);
funcs.insert(f, func);
} else {
llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue());
} else {
auto *storage = new llvm::GlobalVariable(
*module, llvmType, /*isConstant=*/false, llvm::GlobalVariable::PrivateLinkage,
llvm::Constant::getNullValue(llvmType), var->getName());
vars.insert(var, storage);
// debug info
auto *srcInfo = getSrcInfo(var);
llvm::DIFile *file = db.getFile(srcInfo->file);
llvm::DIScope *scope = db.unit;
llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(
scope, getDebugNameForVariable(var), var->getName(), file, srcInfo->line,
getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
}
}
}
llvm::Value *LLVMVisitor::getVar(const Var *var) {
llvm::Value *val = vars[var];
if (!val)
return nullptr;
if (db.jit && var->isGlobal()) {
if (val) {
llvm::Module *m = nullptr;
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
m = x->getModule();
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
m = x->getParent();
llvm::Module *m = nullptr;
if (auto *x = llvm::dyn_cast<llvm::Instruction>(val))
m = x->getModule();
else if (auto *x = llvm::dyn_cast<llvm::GlobalValue>(val))
m = x->getParent();
if (m != module.get()) {
// see if it's in the module already
auto name = var->getName();
if (auto *global = module->getNamedValue(name))
return global;
// the following happens when JIT'ing
if (var->isGlobal() && m != module.get()) {
// see if it's in the module already
auto name = var->getName();
if (auto *global = module->getNamedValue(name))
return global;
llvm::Type *llvmType = getLLVMType(var->getType());
auto *storage =
new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false,
llvm::GlobalVariable::ExternalLinkage,
/*Initializer=*/nullptr, name);
storage->setExternallyInitialized(true);
llvm::Type *llvmType = getLLVMType(var->getType());
auto *storage = new llvm::GlobalVariable(*module, llvmType, /*isConstant=*/false,
llvm::GlobalVariable::ExternalLinkage,
/*Initializer=*/nullptr, name);
storage->setExternallyInitialized(true);
// debug info
auto *srcInfo = getSrcInfo(var);
llvm::DIFile *file = db.getFile(srcInfo->file);
llvm::DIScope *scope = db.unit;
llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(scope, getDebugNameForVariable(var),
name, file, srcInfo->line,
getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
return storage;
// debug info
auto *srcInfo = getSrcInfo(var);
llvm::DIFile *file = db.getFile(srcInfo->file);
llvm::DIScope *scope = db.unit;
llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(
scope, getDebugNameForVariable(var), name, file, srcInfo->line,
getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
vars.insert(var, storage);
return storage;
}
} else {
registerGlobal(var);
return vars[var];
}
}
// should never have a non-global val from another module
return val;
}
llvm::Function *LLVMVisitor::getFunc(const Func *func) {
llvm::Function *f = funcs[func];
if (!f)
return nullptr;
if (db.jit) {
if (f) {
if (f->getParent() != module.get()) {
// see if it's in the module already
if (auto *g = module->getFunction(f->getName()))
return g;
// the following happens when JIT'ing
if (f->getParent() != module.get()) {
// see if it's in the module already
if (auto *g = module->getFunction(f->getName()))
return g;
auto *g =
llvm::Function::Create(f->getFunctionType(), llvm::Function::ExternalLinkage,
f->getName(), module.get());
g->copyAttributesFrom(f);
return g;
auto *g = llvm::Function::Create(f->getFunctionType(),
llvm::Function::ExternalLinkage, f->getName(),
module.get());
g->copyAttributesFrom(f);
funcs.insert(func, g);
return g;
}
} else {
registerGlobal(func);
return funcs[func];
}
}
return f;
}
@ -511,44 +515,12 @@ void LLVMVisitor::visit(const Module *x) {
module = makeModule(getSrcInfo(x));
// args variable
const Var *argVar = x->getArgVar();
llvm::Type *argVarType = getLLVMType(argVar->getType());
auto *argStorage = new llvm::GlobalVariable(
*module, argVarType, /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
llvm::Constant::getNullValue(argVarType), argVar->getName());
vars.insert(argVar, argStorage);
seqassert(x->getArgVar()->isGlobal(), "arg var is not global");
registerGlobal(x->getArgVar());
// set up global variables and initialize functions
for (auto *var : *x) {
if (!var->isGlobal())
continue;
if (auto *f = cast<Func>(var)) {
makeLLVMFunction(f);
funcs.insert(f, func);
} else {
llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue(context));
} else {
auto *storage = new llvm::GlobalVariable(
*module, llvmType, /*isConstant=*/false,
llvm::GlobalVariable::PrivateLinkage,
llvm::Constant::getNullValue(llvmType), var->getName());
vars.insert(var, storage);
// debug info
auto *srcInfo = getSrcInfo(var);
llvm::DIFile *file = db.getFile(srcInfo->file);
llvm::DIScope *scope = db.unit;
llvm::DIGlobalVariableExpression *debugVar =
db.builder->createGlobalVariableExpression(
scope, getDebugNameForVariable(var), var->getName(), file,
srcInfo->line, getDIType(var->getType()),
/*IsLocalToUnit=*/true);
storage->addDebugInfo(debugVar);
}
}
registerGlobal(var);
}
// process functions
@ -632,6 +604,8 @@ void LLVMVisitor::visit(const Module *x) {
builder.CreateBr(loopBlock);
builder.SetInsertPoint(exitBlock);
llvm::Value *argStorage = vars[x->getArgVar()];
seqassert(argStorage, "argument storage missing");
builder.CreateStore(arr, argStorage);
builder.CreateCall(initFunc, builder.getInt32(db.debug ? 1 : 0));
@ -739,7 +713,7 @@ void LLVMVisitor::makeYield(llvm::Value *value, bool finalYield) {
}
void LLVMVisitor::visit(const ExternalFunc *x) {
func = module->getFunction(getNameForFunction(x)); // inserted during module visit
func = module->getFunction(getNameForFunction(x));
coro = {};
seqassert(func, "{} not inserted", *x);
func->setDoesNotThrow();
@ -775,7 +749,7 @@ bool internalFuncMatches(const std::string &name, const InternalFunc *x) {
void LLVMVisitor::visit(const InternalFunc *x) {
using namespace types;
func = module->getFunction(getNameForFunction(x)); // inserted during module visit
func = module->getFunction(getNameForFunction(x));
coro = {};
seqassert(func, "{} not inserted", *x);
setDebugInfoForNode(x);
@ -953,7 +927,7 @@ void LLVMVisitor::visit(const LLVMFunc *x) {
}
void LLVMVisitor::visit(const BodiedFunc *x) {
func = module->getFunction(getNameForFunction(x)); // inserted during module visit
func = module->getFunction(getNameForFunction(x));
coro = {};
seqassert(func, "{} not inserted", *x);
setDebugInfoForNode(x);
@ -1009,7 +983,7 @@ void LLVMVisitor::visit(const BodiedFunc *x) {
for (auto *var : *x) {
llvm::Type *llvmType = getLLVMType(var->getType());
if (llvmType->isVoidTy()) {
vars.insert(var, getDummyVoidValue(context));
vars.insert(var, getDummyVoidValue());
} else {
llvm::Value *storage = builder.CreateAlloca(llvmType);
vars.insert(var, storage);
@ -2031,7 +2005,7 @@ void LLVMVisitor::visit(const AssignInstr *x) {
llvm::Value *var = getVar(x->getLhs());
seqassert(var, "could not find {} var", *x);
process(x->getRhs());
if (var != getDummyVoidValue(context)) {
if (var != getDummyVoidValue()) {
builder.SetInsertPoint(block);
builder.CreateStore(value, var);
}

View File

@ -97,11 +97,13 @@ private:
llvm::DICompileUnit *unit;
/// Whether we are compiling in debug mode
bool debug;
/// Whether we are compiling in JIT mode
bool jit;
/// Program command-line flags
std::string flags;
explicit DebugInfo(bool debug, const std::string &flags)
: builder(), unit(nullptr), debug(debug), flags(flags) {}
DebugInfo(bool debug, bool jit, const std::string &flags)
: builder(), unit(nullptr), debug(debug), jit(jit), flags(flags) {}
llvm::DIFile *getFile(const std::string &path);
};
@ -178,9 +180,37 @@ private:
llvm::Value *getVar(const Var *var);
llvm::Function *getFunc(const Func *func);
llvm::Value *getDummyVoidValue() { return llvm::ConstantTokenNone::get(context); }
public:
LLVMVisitor(bool debug = false, const std::string &flags = "");
static std::string getNameForFunction(const Func *x) {
if (auto *externalFunc = cast<ExternalFunc>(x)) {
return x->getUnmangledName();
} else {
return x->referenceString();
}
}
static std::string getDebugNameForVariable(const Var *x) {
std::string name = x->getName();
auto pos = name.find(".");
if (pos != 0 && pos != std::string::npos) {
return name.substr(0, pos);
} else {
return name;
}
}
static const SrcInfo *getSrcInfo(const Node *x) {
if (auto *srcInfo = x->getAttribute<SrcInfoAttribute>()) {
return &srcInfo->info;
} else {
static SrcInfo defaultSrcInfo("<internal>", 0, 0, 0);
return &defaultSrcInfo;
}
}
LLVMVisitor(bool debug = false, bool jit = false, const std::string &flags = "");
llvm::LLVMContext &getContext() { return context; }
llvm::IRBuilder<> &getBuilder() { return builder; }
@ -199,6 +229,11 @@ public:
void setBlock(llvm::BasicBlock *b) { block = b; }
void setValue(llvm::Value *v) { value = v; }
/// Registers a new global variable or function with
/// this visitor.
/// @param var the global variable (or function) to register
void registerGlobal(const Var *var);
/// Returns a new LLVM module initialized for the host
/// architecture.
/// @param src source information for the new module

View File

@ -45,7 +45,7 @@ void GlobalDemotionPass::run(Module *M) {
}
for (auto it : localGlobals) {
if (!it.second)
if (!it.second || it.first->getId() == M->getArgVar()->getId())
continue;
seqassert(it.first->isGlobal(), "var was not global");
it.first->setGlobal(false);