Add allocation removal LLVM pass (#39)

* Add allocation removal LLVM pass

* Fix replacement

* Fix alloc demotion

* Remove tail calls as needed

* Fix instruction erasing

* Optimize reallocs

* Set string constants as unnamed_addr in LLVM

* Use peephole instead of loop-opt-late
A. R. Shajii 2022-07-26 16:06:19 -04:00 committed by GitHub
parent cb945f569c
commit 963ddb3b60
3 changed files with 408 additions and 2 deletions


@@ -1572,6 +1572,7 @@ void LLVMVisitor::visit(const StringConst *x) {
      *M, llvm::ArrayType::get(B->getInt8Ty(), s.length() + 1),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
      llvm::ConstantDataArray::getString(*context, s), "str_literal");
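  // unnamed_addr marks the literal's address as insignificant, allowing LLVM
  // to merge identical string constants across the module.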
  strVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
  auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getInt8PtrTy());
  llvm::Value *ptr = B->CreateBitCast(strVar, B->getInt8PtrTy());
  llvm::Value *len = B->getInt64(s.length());


@@ -1,9 +1,12 @@
#pragma once
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/RegionPass.h"
#include "llvm/Analysis/TargetLibraryInfo.h"


@@ -1,5 +1,7 @@
#include "optimize.h"
#include <algorithm>
#include "codon/sir/llvm/coro/Coroutines.h"
#include "codon/util/common.h"
#include "llvm/CodeGen/CommandFlags.h"
@@ -79,6 +81,404 @@ void applyDebugTransformations(llvm::Module *module, bool debug, bool jit) {
  }
}

/// Lowers allocations of known, small size to alloca when possible.
/// Also removes unused allocations.
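/// For example (illustrative IR, assuming the seq_alloc runtime allocator):
///   %p = call i8* @seq_alloc(i64 16)
/// becomes, when %p provably does not escape the function:
///   %p = alloca i8, i64 16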
struct AllocationRemover : public llvm::FunctionPass {
  std::string alloc;
  std::string allocAtomic;
  std::string realloc;
  std::string free;

  static char ID;

  AllocationRemover(const std::string &alloc = "seq_alloc",
                    const std::string &allocAtomic = "seq_alloc_atomic",
                    const std::string &realloc = "seq_realloc",
                    const std::string &free = "seq_free")
      : llvm::FunctionPass(ID), alloc(alloc), allocAtomic(allocAtomic),
        realloc(realloc), free(free) {}
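
  // Only allocations of known constant size in (0, 1024] bytes are
  // considered for demotion to stack; zero-size and larger allocations are
  // left on the heap.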
  static bool sizeOkToDemote(uint64_t size) { return 0 < size && size <= 1024; }

  static const llvm::Function *getCalledFunction(const llvm::Value *value) {
    // Don't care about intrinsics in this case.
    if (llvm::isa<llvm::IntrinsicInst>(value))
      return nullptr;

    const auto *cb = llvm::dyn_cast<llvm::CallBase>(value);
    if (!cb)
      return nullptr;

    if (const llvm::Function *callee = cb->getCalledFunction())
      return callee;

    return nullptr;
  }

  bool isAlloc(const llvm::Value *value) {
    if (auto *func = getCalledFunction(value)) {
      return func->arg_size() == 1 &&
             (func->getName() == alloc || func->getName() == allocAtomic);
    }
    return false;
  }

  bool isRealloc(const llvm::Value *value) {
    if (auto *func = getCalledFunction(value)) {
      return func->arg_size() == 2 && func->getName() == realloc;
    }
    return false;
  }

  bool isFree(const llvm::Value *value) {
    if (auto *func = getCalledFunction(value)) {
      return func->arg_size() == 1 && func->getName() == free;
    }
    return false;
  }

  static bool getFixedArg(llvm::CallBase &cb, uint64_t &size, unsigned idx = 0) {
    if (cb.arg_empty())
      return false;

    if (auto *ci = llvm::dyn_cast<llvm::ConstantInt>(cb.getArgOperand(idx))) {
      size = ci->getZExtValue();
      return true;
    }

    return false;
  }

  bool isNeverEqualToUnescapedAlloc(llvm::Value *value, llvm::Instruction *ai) {
    using namespace llvm;

    if (isa<ConstantPointerNull>(value))
      return true;

    if (auto *li = dyn_cast<LoadInst>(value))
      return isa<GlobalVariable>(li->getPointerOperand());

    // Two distinct allocations will never be equal.
    return isAlloc(value) && value != ai;
  }

  bool isAllocSiteRemovable(llvm::Instruction *ai,
                            llvm::SmallVectorImpl<llvm::WeakTrackingVH> &users) {
    using namespace llvm;

    // Should never be an invoke, so just check right away.
    if (isa<InvokeInst>(ai))
      return false;

    SmallVector<Instruction *, 4> worklist;
    worklist.push_back(ai);

    do {
      Instruction *pi = worklist.pop_back_val();
      for (User *u : pi->users()) {
        Instruction *instr = cast<Instruction>(u);
        switch (instr->getOpcode()) {
        default:
          // Give up the moment we see something we can't handle.
          return false;

        case Instruction::AddrSpaceCast:
        case Instruction::BitCast:
        case Instruction::GetElementPtr:
          users.emplace_back(instr);
          worklist.push_back(instr);
          continue;

        case Instruction::ICmp: {
          ICmpInst *cmp = cast<ICmpInst>(instr);
          // We can fold eq/ne comparisons with null to false/true, respectively.
          // We also fold comparisons in some conditions provided the alloc has
          // not escaped (see isNeverEqualToUnescapedAlloc).
          if (!cmp->isEquality())
            return false;

          unsigned otherIndex = (cmp->getOperand(0) == pi) ? 1 : 0;
          if (!isNeverEqualToUnescapedAlloc(cmp->getOperand(otherIndex), ai))
            return false;

          users.emplace_back(instr);
          continue;
        }

        case Instruction::Call:
          // Ignore no-op and store intrinsics.
          if (IntrinsicInst *intrinsic = dyn_cast<IntrinsicInst>(instr)) {
            switch (intrinsic->getIntrinsicID()) {
            default:
              return false;

            case Intrinsic::memmove:
            case Intrinsic::memcpy:
            case Intrinsic::memset: {
              MemIntrinsic *MI = cast<MemIntrinsic>(intrinsic);
              if (MI->isVolatile() || MI->getRawDest() != pi)
                return false;
              LLVM_FALLTHROUGH;
            }
            case Intrinsic::assume:
            case Intrinsic::invariant_start:
            case Intrinsic::invariant_end:
            case Intrinsic::lifetime_start:
            case Intrinsic::lifetime_end:
              users.emplace_back(instr);
              continue;

            case Intrinsic::launder_invariant_group:
            case Intrinsic::strip_invariant_group:
              users.emplace_back(instr);
              worklist.push_back(instr);
              continue;
            }
          }

          if (isFree(instr)) {
            users.emplace_back(instr);
            continue;
          }

          if (isRealloc(instr)) {
            users.emplace_back(instr);
            worklist.push_back(instr);
            continue;
          }

          return false;

        case Instruction::Store: {
          StoreInst *si = cast<StoreInst>(instr);
          if (si->isVolatile() || si->getPointerOperand() != pi)
            return false;

          users.emplace_back(instr);
          continue;
        }
        }
        seqassert(false, "missing a return?");
      }
    } while (!worklist.empty());

    return true;
  }
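
  // For example (illustrative IR), an allocation whose only uses are stores
  // into it and an equality test against null never escapes, so the whole
  // site can be erased and the comparison folded to false:
  //   %p = call i8* @seq_alloc(i64 8)
  //   store i8 1, i8* %p
  //   %c = icmp eq i8* %p, null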

  bool isAllocSiteDemotable(llvm::Instruction *ai, uint64_t &size,
                            llvm::SmallVectorImpl<llvm::WeakTrackingVH> &users) {
    using namespace llvm;

    // Should never be an invoke, so just check right away.
    if (isa<InvokeInst>(ai))
      return false;

    if (!(getFixedArg(*dyn_cast<CallBase>(&*ai), size) && sizeOkToDemote(size)))
      return false;

    SmallVector<Instruction *, 4> worklist;
    worklist.push_back(ai);

    do {
      Instruction *pi = worklist.pop_back_val();
      for (User *u : pi->users()) {
        Instruction *instr = cast<Instruction>(u);
        switch (instr->getOpcode()) {
        default:
          // Give up the moment we see something we can't handle.
          return false;

        case Instruction::AddrSpaceCast:
        case Instruction::BitCast:
        case Instruction::GetElementPtr:
          worklist.push_back(instr);
          continue;

        case Instruction::ICmp: {
          ICmpInst *cmp = cast<ICmpInst>(instr);
          // We can fold eq/ne comparisons with null to false/true, respectively.
          // We also fold comparisons in some conditions provided the alloc has
          // not escaped (see isNeverEqualToUnescapedAlloc).
          if (!cmp->isEquality())
            return false;

          unsigned otherIndex = (cmp->getOperand(0) == pi) ? 1 : 0;
          if (!isNeverEqualToUnescapedAlloc(cmp->getOperand(otherIndex), ai))
            return false;

          continue;
        }

        case Instruction::Call:
          // Ignore no-op and store intrinsics.
          if (IntrinsicInst *intrinsic = dyn_cast<IntrinsicInst>(instr)) {
            switch (intrinsic->getIntrinsicID()) {
            default:
              return false;

            case Intrinsic::memmove:
            case Intrinsic::memcpy:
            case Intrinsic::memset: {
              MemIntrinsic *MI = cast<MemIntrinsic>(intrinsic);
              if (MI->isVolatile())
                return false;
              LLVM_FALLTHROUGH;
            }
            case Intrinsic::assume:
            case Intrinsic::invariant_start:
            case Intrinsic::invariant_end:
            case Intrinsic::lifetime_start:
            case Intrinsic::lifetime_end:
              users.emplace_back(instr);
              continue;

            case Intrinsic::launder_invariant_group:
            case Intrinsic::strip_invariant_group:
              users.emplace_back(instr);
              worklist.push_back(instr);
              continue;
            }
          }

          if (isFree(instr)) {
            users.emplace_back(instr);
            continue;
          }

          if (isRealloc(instr)) {
            // If the realloc also has a constant small size, we can just update
            // the assumed size to the max of the original allocation's size and
            // this realloc's size.
            uint64_t newSize = 0;
            if (getFixedArg(*dyn_cast<CallBase>(instr), newSize, 1) &&
                sizeOkToDemote(newSize)) {
              size = std::max(size, newSize);
            } else {
              return false;
            }

            users.emplace_back(instr);
            worklist.push_back(instr);
            continue;
          }

          return false;

        case Instruction::Store: {
          StoreInst *si = cast<StoreInst>(instr);
          if (si->isVolatile() || si->getPointerOperand() != pi)
            return false;

          continue;
        }

        case Instruction::Load: {
          LoadInst *li = cast<LoadInst>(instr);
          if (li->isVolatile())
            return false;

          continue;
        }
        }
        seqassert(false, "missing a return?");
      }
    } while (!worklist.empty());

    return true;
  }
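
  // Illustrative demotable pattern: an allocation grown by a constant-size
  // realloc, with both sizes under the threshold, becomes a single alloca
  // sized max(16, 64) = 64 bytes:
  //   %p = call i8* @seq_alloc(i64 16)
  //   %q = call i8* @seq_realloc(i8* %p, i64 64)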

  void getErasesAndReplacementsForAlloc(
      llvm::Instruction &mi, llvm::SmallPtrSetImpl<llvm::Instruction *> &erase,
      llvm::SmallVectorImpl<std::pair<llvm::Instruction *, llvm::Value *>> &replace,
      llvm::SmallVectorImpl<llvm::AllocaInst *> &alloca,
      llvm::SmallVectorImpl<llvm::CallInst *> &untail) {
    using namespace llvm;

    uint64_t size = 0;
    SmallVector<WeakTrackingVH, 64> users;

    if (isAllocSiteRemovable(&mi, users)) {
      for (unsigned i = 0, e = users.size(); i != e; ++i) {
        if (!users[i])
          continue;

        Instruction *instr = cast<Instruction>(&*users[i]);
        if (ICmpInst *cmp = dyn_cast<ICmpInst>(instr)) {
          replace.emplace_back(cmp, ConstantInt::get(Type::getInt1Ty(cmp->getContext()),
                                                     cmp->isFalseWhenEqual()));
        } else if (!isa<StoreInst>(instr)) {
          // Casts, GEP, or anything else: we're about to delete this
          // instruction, so it cannot have any valid uses.
          replace.emplace_back(instr, PoisonValue::get(instr->getType()));
        }
        erase.insert(instr);
      }
      erase.insert(&mi);
      return;
    } else {
      users.clear();
    }

    if (isAllocSiteDemotable(&mi, size, users)) {
      auto *replacement = new AllocaInst(
          Type::getInt8Ty(mi.getContext()), 0,
          ConstantInt::get(Type::getInt64Ty(mi.getContext()), size), Align());
      alloca.push_back(replacement);
      replace.emplace_back(&mi, replacement);
      erase.insert(&mi);

      for (unsigned i = 0, e = users.size(); i != e; ++i) {
        if (!users[i])
          continue;

        Instruction *instr = cast<Instruction>(&*users[i]);
        if (isFree(instr)) {
          erase.insert(instr);
        } else if (isRealloc(instr)) {
          replace.emplace_back(instr, replacement);
          erase.insert(instr);
        } else if (auto *ci = dyn_cast<CallInst>(&*instr)) {
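          // A 'tail' marker promises the callee does not access the caller's
          // stack; that no longer holds once the allocation is demoted to an
          // alloca, so the marker must be dropped.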
          if (ci->isTailCall() || ci->isMustTailCall())
            untail.push_back(ci);
        }
      }
    }
  }

  bool runOnFunction(llvm::Function &func) override {
    using namespace llvm;

    SmallSet<Instruction *, 32> erase;
    SmallVector<std::pair<Instruction *, llvm::Value *>, 32> replace;
    SmallVector<AllocaInst *, 32> alloca;
    SmallVector<CallInst *, 32> untail;

    for (inst_iterator instr = inst_begin(func), end = inst_end(func); instr != end;
         ++instr) {
      auto *cb = dyn_cast<CallBase>(&*instr);
      if (!cb || !isAlloc(cb))
        continue;

      getErasesAndReplacementsForAlloc(*cb, erase, replace, alloca, untail);
    }

    for (auto *A : alloca) {
      A->insertBefore(func.getEntryBlock().getFirstNonPHI());
    }

    for (auto *C : untail) {
      C->setTailCall(false);
    }

    for (auto &P : replace) {
      P.first->replaceAllUsesWith(P.second);
    }

    for (auto *I : erase) {
      I->dropAllReferences();
    }

    for (auto *I : erase) {
      I->eraseFromParent();
    }

    return !erase.empty() || !replace.empty() || !alloca.empty() || !untail.empty();
  }
};

void addAllocationRemover(const llvm::PassManagerBuilder &builder,
                          llvm::legacy::PassManagerBase &pm) {
  pm.add(new AllocationRemover());
}

char AllocationRemover::ID = 0;
llvm::RegisterPass<AllocationRemover> X1("alloc-remove", "Allocation Remover");
/// Sometimes coroutine lowering produces hard-to-analyze loops involving
/// function pointer comparisons. This pass puts them into a somewhat
/// easier-to-analyze form.
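/// (For example, a loop whose exit test compares a coroutine frame's
/// resume-function pointer against null to check whether it is done.)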
@@ -161,8 +561,8 @@ void addCoroutineBranchSimplifier(const llvm::PassManagerBuilder &builder,
}
char CoroBranchSimplifier::ID = 0;
llvm::RegisterPass<CoroBranchSimplifier> X("coro-br-simpl",
                                           "Coroutine Branch Simplifier");
llvm::RegisterPass<CoroBranchSimplifier> X2("coro-br-simpl",
                                            "Coroutine Branch Simplifier");

void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit,
                               PluginManager *plugins) {
@@ -211,6 +611,8 @@ void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit,
  if (!debug) {
    pmb.addExtension(llvm::PassManagerBuilder::EP_LateLoopOptimizations,
                     addCoroutineBranchSimplifier);
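    // EP_Peephole fires after each instruction-combiner run, so the
    // allocation remover sees calls already simplified by earlier passes;
    // the EP_LateLoopOptimizations hook used previously only ran within
    // loop passes.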
    pmb.addExtension(llvm::PassManagerBuilder::EP_Peephole,
                     addAllocationRemover);
  }

  if (plugins) {