From: Sergey Dmitriev Date: Fri, 8 Feb 2019 06:55:18 +0000 (+0000) Subject: [CodeExtractor] Update function's assumption cache after extracting blocks from it X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=9a22bfd50bf3694bbc4624f3f784108f23dce08f;p=llvm [CodeExtractor] Update function's assumption cache after extracting blocks from it Summary: Assumption cache's self-updating mechanism does not correctly handle the case when blocks are extracted from the function by the CodeExtractor. As a result function's assumption cache may have stale references to the llvm.assume calls that were moved to the outlined function. This patch fixes this problem by removing extracted llvm.assume calls from the function’s assumption cache. Reviewers: hfinkel, vsk, fhahn, davidxl, sanjoy Reviewed By: hfinkel, vsk Subscribers: llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D57215 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@353500 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/Analysis/AssumptionCache.h b/include/llvm/Analysis/AssumptionCache.h index 93dbae8de79..b42846472f2 100644 --- a/include/llvm/Analysis/AssumptionCache.h +++ b/include/llvm/Analysis/AssumptionCache.h @@ -103,6 +103,10 @@ public: /// not already be in the cache. void registerAssumption(CallInst *CI); + /// Remove an \@llvm.assume intrinsic from this function's cache if it has + /// been added to the cache earlier. + void unregisterAssumption(CallInst *CI); + /// Update the cache of values being affected by this assumption (i.e. /// the values about which this assumption provides information). void updateAffectedValues(CallInst *CI); @@ -208,6 +212,10 @@ public: /// existing cache will be returned. AssumptionCache &getAssumptionCache(Function &F); + /// Return the cached assumptions for a function if it has already been + /// scanned. Otherwise return nullptr. + AssumptionCache *lookupAssumptionCache(Function &F); + AssumptionCacheTracker(); ~AssumptionCacheTracker() override; diff --git a/include/llvm/Transforms/Utils/CodeExtractor.h b/include/llvm/Transforms/Utils/CodeExtractor.h index e13940bb0a5..becbf0ea62f 100644 --- a/include/llvm/Transforms/Utils/CodeExtractor.h +++ b/include/llvm/Transforms/Utils/CodeExtractor.h @@ -26,6 +26,7 @@ class BasicBlock; class BlockFrequency; class BlockFrequencyInfo; class BranchProbabilityInfo; +class AssumptionCache; class CallInst; class DominatorTree; class Function; @@ -56,6 +57,7 @@ class Value; const bool AggregateArgs; BlockFrequencyInfo *BFI; BranchProbabilityInfo *BPI; + AssumptionCache *AC; // If true, varargs functions can be extracted. bool AllowVarArgs; @@ -84,6 +86,7 @@ class Value; CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr, bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, BranchProbabilityInfo *BPI = nullptr, + AssumptionCache *AC = nullptr, bool AllowVarArgs = false, bool AllowAlloca = false, std::string Suffix = ""); @@ -94,6 +97,7 @@ class Value; CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, BranchProbabilityInfo *BPI = nullptr, + AssumptionCache *AC = nullptr, std::string Suffix = ""); /// Perform the extraction, returning the new function. diff --git a/lib/Analysis/AssumptionCache.cpp b/lib/Analysis/AssumptionCache.cpp index bc7b0938f9a..cf2f845dee0 100644 --- a/lib/Analysis/AssumptionCache.cpp +++ b/lib/Analysis/AssumptionCache.cpp @@ -53,11 +53,11 @@ AssumptionCache::getOrInsertAffectedValues(Value *V) { return AVIP.first->second; } -void AssumptionCache::updateAffectedValues(CallInst *CI) { +static void findAffectedValues(CallInst *CI, + SmallVectorImpl &Affected) { // Note: This code must be kept in-sync with the code in // computeKnownBitsFromAssume in ValueTracking. - SmallVector Affected; auto AddAffected = [&Affected](Value *V) { if (isa(V)) { Affected.push_back(V); @@ -108,6 +108,11 @@ void AssumptionCache::updateAffectedValues(CallInst *CI) { AddAffectedFromEq(B); } } +} + +void AssumptionCache::updateAffectedValues(CallInst *CI) { + SmallVector Affected; + findAffectedValues(CI, Affected); for (auto &AV : Affected) { auto &AVV = getOrInsertAffectedValues(AV); @@ -116,6 +121,18 @@ void AssumptionCache::updateAffectedValues(CallInst *CI) { } } +void AssumptionCache::unregisterAssumption(CallInst *CI) { + SmallVector Affected; + findAffectedValues(CI, Affected); + + for (auto &AV : Affected) { + auto AVI = AffectedValues.find_as(AV); + if (AVI != AffectedValues.end()) + AffectedValues.erase(AVI); + } + remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; }); +} + void AssumptionCache::AffectedValueCallbackVH::deleted() { auto AVI = AC->AffectedValues.find(getValPtr()); if (AVI != AC->AffectedValues.end()) @@ -240,6 +257,13 @@ AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) { return *IP.first->second; } +AssumptionCache *AssumptionCacheTracker::lookupAssumptionCache(Function &F) { + auto I = AssumptionCaches.find_as(&F); + if (I != AssumptionCaches.end()) + return I->second.get(); + return nullptr; +} + void AssumptionCacheTracker::verifyAnalysis() const { // FIXME: In the long term the verifier should not be controllable with a // flag. We should either fix all passes to correctly update the assumption diff --git a/lib/Transforms/IPO/HotColdSplitting.cpp b/lib/Transforms/IPO/HotColdSplitting.cpp index 648e2aed758..b8def7ad3ea 100644 --- a/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/lib/Transforms/IPO/HotColdSplitting.cpp @@ -173,8 +173,9 @@ public: HotColdSplitting(ProfileSummaryInfo *ProfSI, function_ref GBFI, function_ref GTTI, - std::function *GORE) - : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE) {} + std::function *GORE, + function_ref LAC) + : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {} bool run(Module &M); private: @@ -183,11 +184,13 @@ private: bool outlineColdRegions(Function &F, bool HasProfileSummary); Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, - OptimizationRemarkEmitter &ORE, unsigned Count); + OptimizationRemarkEmitter &ORE, + AssumptionCache *AC, unsigned Count); ProfileSummaryInfo *PSI; function_ref GetBFI; function_ref GetTTI; std::function *GetORE; + function_ref LookupAC; }; class HotColdSplittingLegacyPass : public ModulePass { @@ -198,10 +201,10 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); + AU.addUsedIfAvailable(); } bool runOnModule(Module &M) override; @@ -316,12 +319,13 @@ Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region, BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, OptimizationRemarkEmitter &ORE, + AssumptionCache *AC, unsigned Count) { assert(!Region.empty()); // TODO: Pass BFI and BPI to update profile information. CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr, - /* BPI */ nullptr, /* AllowVarArgs */ false, + /* BPI */ nullptr, AC, /* AllowVarArgs */ false, /* AllowAlloca */ false, /* Suffix */ "cold." + std::to_string(Count)); @@ -577,6 +581,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { TargetTransformInfo &TTI = GetTTI(F); OptimizationRemarkEmitter &ORE = (*GetORE)(F); + AssumptionCache *AC = LookupAC(F); // Find all cold regions. for (BasicBlock *BB : RPOT) { @@ -638,8 +643,8 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { BB->dump(); }); - Function *Outlined = - extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, OutlinedFunctionID); + Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC, + OutlinedFunctionID); if (Outlined) { ++OutlinedFunctionID; Changed = true; @@ -698,17 +703,21 @@ bool HotColdSplittingLegacyPass::runOnModule(Module &M) { ORE.reset(new OptimizationRemarkEmitter(&F)); return *ORE.get(); }; + auto LookupAC = [this](Function &F) -> AssumptionCache * { + if (auto *ACT = getAnalysisIfAvailable()) + return ACT->lookupAssumptionCache(F); + return nullptr; + }; - return HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M); + return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M); } PreservedAnalyses HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { auto &FAM = AM.getResult(M).getManager(); - std::function GetAssumptionCache = - [&FAM](Function &F) -> AssumptionCache & { - return FAM.getResult(F); + auto LookupAC = [&FAM](Function &F) -> AssumptionCache * { + return FAM.getCachedResult(F); }; auto GBFI = [&FAM](Function &F) { @@ -729,7 +738,7 @@ HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { ProfileSummaryInfo *PSI = &AM.getResult(M); - if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M)) + if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index 6e7e59a8537..91c7b5f5f13 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -14,6 +14,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -50,6 +51,7 @@ namespace { AU.addRequiredID(LoopSimplifyID); AU.addRequired(); AU.addRequired(); + AU.addUsedIfAvailable(); } }; } @@ -138,7 +140,10 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { if (ShouldExtractLoop) { if (NumLoops == 0) return Changed; --NumLoops; - CodeExtractor Extractor(DT, *L); + AssumptionCache *AC = nullptr; + if (auto *ACT = getAnalysisIfAvailable()) + AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent()); + CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC); if (Extractor.extractCodeRegion() != nullptr) { Changed = true; // After extraction, the loop is replaced by a function call, so diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index f971cee3ab5..8339eb456da 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -199,10 +199,12 @@ struct PartialInlinerImpl { PartialInlinerImpl( std::function *GetAC, + function_ref LookupAC, std::function *GTTI, Optional> GBFI, ProfileSummaryInfo *ProfSI) - : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} + : GetAssumptionCache(GetAC), LookupAssumptionCache(LookupAC), + GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} bool run(Module &M); // Main part of the transformation that calls helper functions to find @@ -222,9 +224,11 @@ struct PartialInlinerImpl { // Two constructors, one for single region outlining, the other for // multi-region outlining. FunctionCloner(Function *F, FunctionOutliningInfo *OI, - OptimizationRemarkEmitter &ORE); + OptimizationRemarkEmitter &ORE, + function_ref LookupAC); FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI, - OptimizationRemarkEmitter &ORE); + OptimizationRemarkEmitter &ORE, + function_ref LookupAC); ~FunctionCloner(); // Prepare for function outlining: making sure there is only @@ -260,11 +264,13 @@ struct PartialInlinerImpl { std::unique_ptr ClonedOMRI = nullptr; std::unique_ptr ClonedFuncBFI = nullptr; OptimizationRemarkEmitter &ORE; + function_ref LookupAC; }; private: int NumPartialInlining = 0; std::function *GetAssumptionCache; + function_ref LookupAssumptionCache; std::function *GetTTI; Optional> GetBFI; ProfileSummaryInfo *PSI; @@ -365,12 +371,17 @@ struct PartialInlinerLegacyPass : public ModulePass { return ACT->getAssumptionCache(F); }; + auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * { + return ACT->lookupAssumptionCache(F); + }; + std::function GetTTI = [&TTIWP](Function &F) -> TargetTransformInfo & { return TTIWP->getTTI(F); }; - return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI) + return PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, + &GetTTI, NoneType::None, PSI) .run(M); } }; @@ -948,8 +959,9 @@ void PartialInlinerImpl::computeCallsiteToProfCountMap( } PartialInlinerImpl::FunctionCloner::FunctionCloner( - Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE) - : OrigFunc(F), ORE(ORE) { + Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE, + function_ref LookupAC) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { ClonedOI = llvm::make_unique(); // Clone the function, so that we can hack away on it. @@ -972,8 +984,9 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner( PartialInlinerImpl::FunctionCloner::FunctionCloner( Function *F, FunctionOutliningMultiRegionInfo *OI, - OptimizationRemarkEmitter &ORE) - : OrigFunc(F), ORE(ORE) { + OptimizationRemarkEmitter &ORE, + function_ref LookupAC) + : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) { ClonedOMRI = llvm::make_unique(); // Clone the function, so that we can hack away on it. @@ -1111,7 +1124,9 @@ bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() { int CurrentOutlinedRegionCost = ComputeRegionCost(RegionInfo.Region); CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI, /* AllowVarargs */ false); + ClonedFuncBFI.get(), &BPI, + LookupAC(*RegionInfo.EntryBlock->getParent()), + /* AllowVarargs */ false); CE.findInputsOutputs(Inputs, Outputs, Sinks); @@ -1193,7 +1208,7 @@ PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { // Extract the body of the if. Function *OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI, + ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc), /* AllowVarargs */ true) .extractCodeRegion(); @@ -1257,7 +1272,7 @@ std::pair PartialInlinerImpl::unswitchFunction(Function *F) { std::unique_ptr OMRI = computeOutliningColdRegionsInfo(F, ORE); if (OMRI) { - FunctionCloner Cloner(F, OMRI.get(), ORE); + FunctionCloner Cloner(F, OMRI.get(), ORE, LookupAssumptionCache); #ifndef NDEBUG if (TracePartialInlining) { @@ -1290,7 +1305,7 @@ std::pair PartialInlinerImpl::unswitchFunction(Function *F) { if (!OI) return {false, nullptr}; - FunctionCloner Cloner(F, OI.get(), ORE); + FunctionCloner Cloner(F, OI.get(), ORE, LookupAssumptionCache); Cloner.NormalizeReturnBlock(); Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining(); @@ -1484,6 +1499,10 @@ PreservedAnalyses PartialInlinerPass::run(Module &M, return FAM.getResult(F); }; + auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * { + return FAM.getCachedResult(F); + }; + std::function GetBFI = [&FAM](Function &F) -> BlockFrequencyInfo & { return FAM.getResult(F); @@ -1496,7 +1515,8 @@ PreservedAnalyses PartialInlinerPass::run(Module &M, ProfileSummaryInfo *PSI = &AM.getResult(M); - if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI) + if (PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, &GetTTI, + {GetBFI}, PSI) .run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index 9148c11ba15..e941de29c48 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/Analysis/BranchProbabilityInfo.h" @@ -43,6 +44,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -66,6 +68,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; using ProfileCount = Function::ProfileCount; #define DEBUG_TYPE "code-extractor" @@ -235,18 +238,20 @@ buildExtractionBlockSet(ArrayRef BBs, DominatorTree *DT, CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, bool AllowVarArgs, - bool AllowAlloca, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + bool AllowVarArgs, bool AllowAlloca, + std::string Suffix) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(AllowVarArgs), + BPI(BPI), AC(AC), AllowVarArgs(AllowVarArgs), Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), Suffix(Suffix) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI, std::string Suffix) + BranchProbabilityInfo *BPI, AssumptionCache *AC, + std::string Suffix) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), AllowVarArgs(false), + BPI(BPI), AC(AC), AllowVarArgs(false), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, /* AllowVarArgs */ false, /* AllowAlloca */ false)), @@ -1216,6 +1221,13 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { // Insert this basic block into the new function newBlocks.push_back(Block); + + // Remove @llvm.assume calls that were moved to the new function from the + // old function's assumption cache. + if (AC) + for (auto &I : *Block) + if (match(&I, m_Intrinsic())) + AC->unregisterAssumption(cast(&I)); } } diff --git a/test/Transforms/CodeExtractor/extract-assume.ll b/test/Transforms/CodeExtractor/extract-assume.ll new file mode 100644 index 00000000000..b79c6a69137 --- /dev/null +++ b/test/Transforms/CodeExtractor/extract-assume.ll @@ -0,0 +1,29 @@ +; RUN: opt -passes="function(slp-vectorizer),module(hotcoldsplit),function(slp-vectorizer,print)" -disable-output %s 2>&1 | FileCheck %s +; +; Make sure this compiles. Check that function assumption cache is refreshed +; after extracting blocks with assume calls from the function. + +; CHECK: Cached assumptions for function: fun +; CHECK-NEXT: Cached assumptions for function: fun.cold +; CHECK-NEXT: %cmp = icmp uge i32 %x, 64 + +declare void @fun2(i32) #0 + +define void @fun(i32 %x) { +entry: + br i1 undef, label %if.then, label %if.else + +if.then: + ret void + +if.else: + %cmp = icmp uge i32 %x, 64 + call void @llvm.assume(i1 %cmp) + call void @fun2(i32 %x) + unreachable +} + +declare void @llvm.assume(i1) #1 + +attributes #0 = { alwaysinline } +attributes #1 = { nounwind }