From 963ddb3b60dd2e86803ae2e2d419b08a339907f9 Mon Sep 17 00:00:00 2001
From: "A. R. Shajii"
Date: Tue, 26 Jul 2022 16:06:19 -0400
Subject: [PATCH] Add allocation removal LLVM pass (#39)

* Add allocation removal LLVM pass

* Fix replacement

* Fix alloc demotion

* Remove tail calls as needed

* Fix instruction erasing

* Optimize reallocs

* Set string constants as unnamed_addr in LLVM

* Use peephole instead of loop-opt-late
---
 codon/sir/llvm/llvisitor.cpp |   1 +
 codon/sir/llvm/llvm.h        |   3 +
 codon/sir/llvm/optimize.cpp  | 406 ++++++++++++++++++++++++++++++++++-
 3 files changed, 408 insertions(+), 2 deletions(-)

diff --git a/codon/sir/llvm/llvisitor.cpp b/codon/sir/llvm/llvisitor.cpp
index 01a2e4a5..484545f1 100644
--- a/codon/sir/llvm/llvisitor.cpp
+++ b/codon/sir/llvm/llvisitor.cpp
@@ -1572,6 +1572,7 @@ void LLVMVisitor::visit(const StringConst *x) {
       *M, llvm::ArrayType::get(B->getInt8Ty(), s.length() + 1),
       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage,
       llvm::ConstantDataArray::getString(*context, s), "str_literal");
+  strVar->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
   auto *strType = llvm::StructType::get(B->getInt64Ty(), B->getInt8PtrTy());
   llvm::Value *ptr = B->CreateBitCast(strVar, B->getInt8PtrTy());
   llvm::Value *len = B->getInt64(s.length());
diff --git a/codon/sir/llvm/llvm.h b/codon/sir/llvm/llvm.h
index d3e2f3e5..f8c2b21f 100644
--- a/codon/sir/llvm/llvm.h
+++ b/codon/sir/llvm/llvm.h
@@ -1,9 +1,12 @@
 #pragma once
 
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/CaptureTracking.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/RegionPass.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
diff --git a/codon/sir/llvm/optimize.cpp b/codon/sir/llvm/optimize.cpp
index c1d42630..b908e4b5 100644
--- a/codon/sir/llvm/optimize.cpp
+++ b/codon/sir/llvm/optimize.cpp
@@ -1,5 +1,7 @@
 #include "optimize.h"
 
+#include <algorithm>
+
 #include "codon/sir/llvm/coro/Coroutines.h"
 #include "codon/util/common.h"
 #include "llvm/CodeGen/CommandFlags.h"
@@ -79,6 +81,404 @@ void applyDebugTransformations(llvm::Module *module, bool debug, bool jit) {
   }
 }
 
+/// Lowers allocations of known, small size to alloca when possible.
+/// Also removes unused allocations.
+struct AllocationRemover : public llvm::FunctionPass {
+  std::string alloc;
+  std::string allocAtomic;
+  std::string realloc;
+  std::string free;
+
+  static char ID;
+  AllocationRemover(const std::string &alloc = "seq_alloc",
+                    const std::string &allocAtomic = "seq_alloc_atomic",
+                    const std::string &realloc = "seq_realloc",
+                    const std::string &free = "seq_free")
+      : llvm::FunctionPass(ID), alloc(alloc), allocAtomic(allocAtomic),
+        realloc(realloc), free(free) {}
+
+  static bool sizeOkToDemote(uint64_t size) { return 0 < size && size <= 1024; }
+
+  static const llvm::Function *getCalledFunction(const llvm::Value *value) {
+    // Don't care about intrinsics in this case.
+    if (llvm::isa<llvm::IntrinsicInst>(value))
+      return nullptr;
+
+    const auto *cb = llvm::dyn_cast<llvm::CallBase>(value);
+    if (!cb)
+      return nullptr;
+
+    if (const llvm::Function *callee = cb->getCalledFunction())
+      return callee;
+    return nullptr;
+  }
+
+  bool isAlloc(const llvm::Value *value) {
+    if (auto *func = getCalledFunction(value)) {
+      return func->arg_size() == 1 &&
+             (func->getName() == alloc || func->getName() == allocAtomic);
+    }
+    return false;
+  }
+
+  bool isRealloc(const llvm::Value *value) {
+    if (auto *func = getCalledFunction(value)) {
+      return func->arg_size() == 2 && func->getName() == realloc;
+    }
+    return false;
+  }
+
+  bool isFree(const llvm::Value *value) {
+    if (auto *func = getCalledFunction(value)) {
+      return func->arg_size() == 1 && func->getName() == free;
+    }
+    return false;
+  }
+
+  static bool getFixedArg(llvm::CallBase &cb, uint64_t &size, unsigned idx = 0) {
+    if (cb.arg_empty())
+      return false;
+
+    if (auto *ci = llvm::dyn_cast<llvm::ConstantInt>(cb.getArgOperand(idx))) {
+      size = ci->getZExtValue();
+      return true;
+    }
+
+    return false;
+  }
+
+  bool isNeverEqualToUnescapedAlloc(llvm::Value *value, llvm::Instruction *ai) {
+    using namespace llvm;
+
+    if (isa<ConstantPointerNull>(value))
+      return true;
+    if (auto *li = dyn_cast<LoadInst>(value))
+      return isa<GlobalVariable>(li->getPointerOperand());
+    // Two distinct allocations will never be equal.
+    return isAlloc(value) && value != ai;
+  }
+
+  bool isAllocSiteRemovable(llvm::Instruction *ai,
+                            llvm::SmallVectorImpl<llvm::WeakTrackingVH> &users) {
+    using namespace llvm;
+
+    // Should never be an invoke, so just check right away.
+    if (isa<InvokeInst>(ai))
+      return false;
+
+    SmallVector<Instruction *, 4> worklist;
+    worklist.push_back(ai);
+
+    do {
+      Instruction *pi = worklist.pop_back_val();
+      for (User *u : pi->users()) {
+        Instruction *instr = cast<Instruction>(u);
+        switch (instr->getOpcode()) {
+        default:
+          // Give up the moment we see something we can't handle.
+          return false;
+
+        case Instruction::AddrSpaceCast:
+        case Instruction::BitCast:
+        case Instruction::GetElementPtr:
+          users.emplace_back(instr);
+          worklist.push_back(instr);
+          continue;
+
+        case Instruction::ICmp: {
+          ICmpInst *cmp = cast<ICmpInst>(instr);
+          // We can fold eq/ne comparisons with null to false/true, respectively.
+          // We also fold comparisons in some conditions provided the alloc has
+          // not escaped (see isNeverEqualToUnescapedAlloc).
+          if (!cmp->isEquality())
+            return false;
+          unsigned otherIndex = (cmp->getOperand(0) == pi) ? 1 : 0;
+          if (!isNeverEqualToUnescapedAlloc(cmp->getOperand(otherIndex), ai))
+            return false;
+          users.emplace_back(instr);
+          continue;
+        }
+
+        case Instruction::Call:
+          // Ignore no-op and store intrinsics.
+          if (IntrinsicInst *intrinsic = dyn_cast<IntrinsicInst>(instr)) {
+            switch (intrinsic->getIntrinsicID()) {
+            default:
+              return false;
+
+            case Intrinsic::memmove:
+            case Intrinsic::memcpy:
+            case Intrinsic::memset: {
+              MemIntrinsic *MI = cast<MemIntrinsic>(intrinsic);
+              if (MI->isVolatile() || MI->getRawDest() != pi)
+                return false;
+              LLVM_FALLTHROUGH;
+            }
+            case Intrinsic::assume:
+            case Intrinsic::invariant_start:
+            case Intrinsic::invariant_end:
+            case Intrinsic::lifetime_start:
+            case Intrinsic::lifetime_end:
+              users.emplace_back(instr);
+              continue;
+            case Intrinsic::launder_invariant_group:
+            case Intrinsic::strip_invariant_group:
+              users.emplace_back(instr);
+              worklist.push_back(instr);
+              continue;
+            }
+          }
+
+          if (isFree(instr)) {
+            users.emplace_back(instr);
+            continue;
+          }
+
+          if (isRealloc(instr)) {
+            users.emplace_back(instr);
+            worklist.push_back(instr);
+            continue;
+          }
+
+          return false;
+
+        case Instruction::Store: {
+          StoreInst *si = cast<StoreInst>(instr);
+          if (si->isVolatile() || si->getPointerOperand() != pi)
+            return false;
+          users.emplace_back(instr);
+          continue;
+        }
+        }
+        seqassert(false, "missing a return?");
+      }
+    } while (!worklist.empty());
+    return true;
+  }
+
+  bool isAllocSiteDemotable(llvm::Instruction *ai, uint64_t &size,
+                            llvm::SmallVectorImpl<llvm::WeakTrackingVH> &users) {
+    using namespace llvm;
+
+    // Should never be an invoke, so just check right away.
+    if (isa<InvokeInst>(ai))
+      return false;
+
+    if (!(getFixedArg(*dyn_cast<CallBase>(&*ai), size) && sizeOkToDemote(size)))
+      return false;
+
+    SmallVector<Instruction *, 4> worklist;
+    worklist.push_back(ai);
+
+    do {
+      Instruction *pi = worklist.pop_back_val();
+      for (User *u : pi->users()) {
+        Instruction *instr = cast<Instruction>(u);
+        switch (instr->getOpcode()) {
+        default:
+          // Give up the moment we see something we can't handle.
+          return false;
+
+        case Instruction::AddrSpaceCast:
+        case Instruction::BitCast:
+        case Instruction::GetElementPtr:
+          worklist.push_back(instr);
+          continue;
+
+        case Instruction::ICmp: {
+          ICmpInst *cmp = cast<ICmpInst>(instr);
+          // We can fold eq/ne comparisons with null to false/true, respectively.
+          // We also fold comparisons in some conditions provided the alloc has
+          // not escaped (see isNeverEqualToUnescapedAlloc).
+          if (!cmp->isEquality())
+            return false;
+          unsigned otherIndex = (cmp->getOperand(0) == pi) ? 1 : 0;
+          if (!isNeverEqualToUnescapedAlloc(cmp->getOperand(otherIndex), ai))
+            return false;
+          continue;
+        }
+
+        case Instruction::Call:
+          // Ignore no-op and store intrinsics.
+          if (IntrinsicInst *intrinsic = dyn_cast<IntrinsicInst>(instr)) {
+            switch (intrinsic->getIntrinsicID()) {
+            default:
+              return false;
+
+            case Intrinsic::memmove:
+            case Intrinsic::memcpy:
+            case Intrinsic::memset: {
+              MemIntrinsic *MI = cast<MemIntrinsic>(intrinsic);
+              if (MI->isVolatile())
+                return false;
+              LLVM_FALLTHROUGH;
+            }
+            case Intrinsic::assume:
+            case Intrinsic::invariant_start:
+            case Intrinsic::invariant_end:
+            case Intrinsic::lifetime_start:
+            case Intrinsic::lifetime_end:
+              users.emplace_back(instr);
+              continue;
+            case Intrinsic::launder_invariant_group:
+            case Intrinsic::strip_invariant_group:
+              users.emplace_back(instr);
+              worklist.push_back(instr);
+              continue;
+            }
+          }
+
+          if (isFree(instr)) {
+            users.emplace_back(instr);
+            continue;
+          }
+
+          if (isRealloc(instr)) {
+            // If the realloc also has a constant small size, then we can
+            // just update the assumed size to be the max of the original
+            // alloc's and this realloc's.
+            uint64_t newSize = 0;
+            if (getFixedArg(*dyn_cast<CallBase>(instr), newSize, 1) &&
+                sizeOkToDemote(newSize)) {
+              size = std::max(size, newSize);
+            } else {
+              return false;
+            }
+
+            users.emplace_back(instr);
+            worklist.push_back(instr);
+            continue;
+          }
+
+          return false;
+
+        case Instruction::Store: {
+          StoreInst *si = cast<StoreInst>(instr);
+          if (si->isVolatile() || si->getPointerOperand() != pi)
+            return false;
+          continue;
+        }
+
+        case Instruction::Load: {
+          LoadInst *li = cast<LoadInst>(instr);
+          if (li->isVolatile())
+            return false;
+          continue;
+        }
+        }
+        seqassert(false, "missing a return?");
+      }
+    } while (!worklist.empty());
+    return true;
+  }
+
+  void getErasesAndReplacementsForAlloc(
+      llvm::Instruction &mi, llvm::SmallPtrSetImpl<llvm::Instruction *> &erase,
+      llvm::SmallVectorImpl<std::pair<llvm::Instruction *, llvm::Value *>> &replace,
+      llvm::SmallVectorImpl<llvm::AllocaInst *> &alloca,
+      llvm::SmallVectorImpl<llvm::CallInst *> &untail) {
+    using namespace llvm;
+
+    uint64_t size = 0;
+    SmallVector<WeakTrackingVH, 64> users;
+
+    if (isAllocSiteRemovable(&mi, users)) {
+      for (unsigned i = 0, e = users.size(); i != e; ++i) {
+        if (!users[i])
+          continue;
+
+        Instruction *instr = cast<Instruction>(&*users[i]);
+        if (ICmpInst *cmp = dyn_cast<ICmpInst>(instr)) {
+          replace.emplace_back(cmp,
+                               ConstantInt::get(Type::getInt1Ty(cmp->getContext()),
+                                                cmp->isFalseWhenEqual()));
+        } else if (!isa<StoreInst>(instr)) {
+          // Casts, GEP, or anything else: we're about to delete this
+          // instruction, so it cannot have any valid uses.
+          replace.emplace_back(instr, PoisonValue::get(instr->getType()));
+        }
+        erase.insert(instr);
+      }
+      erase.insert(&mi);
+      return;
+    } else {
+      users.clear();
+    }
+
+    if (isAllocSiteDemotable(&mi, size, users)) {
+      auto *replacement = new AllocaInst(
+          Type::getInt8Ty(mi.getContext()), 0,
+          ConstantInt::get(Type::getInt64Ty(mi.getContext()), size), Align());
+      alloca.push_back(replacement);
+      replace.emplace_back(&mi, replacement);
+      erase.insert(&mi);
+
+      for (unsigned i = 0, e = users.size(); i != e; ++i) {
+        if (!users[i])
+          continue;
+
+        Instruction *instr = cast<Instruction>(&*users[i]);
+        if (isFree(instr)) {
+          erase.insert(instr);
+        } else if (isRealloc(instr)) {
+          replace.emplace_back(instr, replacement);
+          erase.insert(instr);
+        } else if (auto *ci = dyn_cast<CallInst>(&*instr)) {
+          if (ci->isTailCall() || ci->isMustTailCall())
+            untail.push_back(ci);
+        }
+      }
+    }
+  }
+
+  bool runOnFunction(llvm::Function &func) override {
+    using namespace llvm;
+
+    SmallSet<Instruction *, 32> erase;
+    SmallVector<std::pair<Instruction *, llvm::Value *>, 32> replace;
+    SmallVector<AllocaInst *, 32> alloca;
+    SmallVector<CallInst *, 32> untail;
+
+    for (inst_iterator instr = inst_begin(func), end = inst_end(func); instr != end;
+         ++instr) {
+      auto *cb = dyn_cast<CallBase>(&*instr);
+      if (!cb || !isAlloc(cb))
+        continue;
+
+      getErasesAndReplacementsForAlloc(*cb, erase, replace, alloca, untail);
+    }
+
+    for (auto *A : alloca) {
+      A->insertBefore(func.getEntryBlock().getFirstNonPHI());
+    }
+
+    for (auto *C : untail) {
+      C->setTailCall(false);
+    }
+
+    for (auto &P : replace) {
+      P.first->replaceAllUsesWith(P.second);
+    }
+
+    for (auto *I : erase) {
+      I->dropAllReferences();
+    }
+
+    for (auto *I : erase) {
+      I->eraseFromParent();
+    }
+
+    return !erase.empty() || !replace.empty() || !alloca.empty() || !untail.empty();
+  }
+};
+
+void addAllocationRemover(const llvm::PassManagerBuilder &builder,
+                          llvm::legacy::PassManagerBase &pm) {
+  pm.add(new AllocationRemover());
+}
+
+char AllocationRemover::ID = 0;
+llvm::RegisterPass<AllocationRemover> X1("alloc-remove", "Allocation Remover");
+
 /// Sometimes coroutine lowering produces hard-to-analyze loops involving
 /// function pointer comparisons. This pass puts them into a somewhat
 /// easier-to-analyze form.
@@ -161,8 +561,8 @@ void addCoroutineBranchSimplifier(const llvm::PassManagerBuilder &builder,
 }
 
 char CoroBranchSimplifier::ID = 0;
-llvm::RegisterPass<CoroBranchSimplifier> X("coro-br-simpl",
-                                           "Coroutine Branch Simplifier");
+llvm::RegisterPass<CoroBranchSimplifier> X2("coro-br-simpl",
+                                            "Coroutine Branch Simplifier");
 
 void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit,
                                PluginManager *plugins) {
@@ -211,6 +611,8 @@ void runLLVMOptimizationPasses(llvm::Module *module, bool debug, bool jit,
   if (!debug) {
     pmb.addExtension(llvm::PassManagerBuilder::EP_LateLoopOptimizations,
                      addCoroutineBranchSimplifier);
+    pmb.addExtension(llvm::PassManagerBuilder::EP_Peephole,
+                     addAllocationRemover);
   }
 
   if (plugins) {
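For illustration, below is a minimal standalone sketch of the transformation
the new pass performs, driven through LLVM's legacy pass manager (the same
manager the patch hooks into). It assumes AllocationRemover is visible to the
driver (in the patch it is local to optimize.cpp); the @demo function and the
seq_alloc declaration are hypothetical inputs chosen to match what
isAllocSiteDemotable accepts: a small, fixed-size, non-escaping allocation.

    #include <memory>
    #include "llvm/AsmParser/Parser.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/LegacyPassManager.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Support/SourceMgr.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      llvm::LLVMContext ctx;
      llvm::SMDiagnostic err;
      // A 16-byte allocation with a constant size that never escapes:
      // too small to stay on the heap (sizeOkToDemote allows <= 1024).
      std::unique_ptr<llvm::Module> module = llvm::parseAssemblyString(R"(
        declare i8* @seq_alloc(i64)
        define i64 @demo() {
          %p = call i8* @seq_alloc(i64 16)
          %q = bitcast i8* %p to i64*
          store i64 42, i64* %q
          %v = load i64, i64* %q
          ret i64 %v
        }
      )", err, ctx);
      if (!module) {
        err.print("demo", llvm::errs());
        return 1;
      }

      llvm::legacy::FunctionPassManager fpm(module.get());
      fpm.add(new AllocationRemover()); // assumes the pass type is exported
      fpm.doInitialization();
      for (llvm::Function &f : *module)
        fpm.run(f);
      fpm.doFinalization();

      // Expected: the seq_alloc call is replaced by "alloca i8, i64 16" in
      // the entry block, and the load/store now target the stack slot.
      module->print(llvm::outs(), nullptr);
      return 0;
    }

An allocation whose size is not a compile-time constant, exceeds 1024 bytes,
or whose pointer reaches an unhandled instruction (e.g. is returned or passed
to an unknown call) is left untouched, since isAllocSiteDemotable bails out
in its default switch case.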