]> granicus.if.org Git - llvm/commitdiff
[SCEV] Maintain loop use lists, and use them in forgetLoop
authorSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 13 Oct 2017 05:50:52 +0000 (05:50 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 13 Oct 2017 05:50:52 +0000 (05:50 +0000)
Summary:
Currently we do not correctly invalidate memoized results for add recurrences
that were created directly (i.e. they were not created from a `Value`).  This
change fixes this by keeping loop use lists and using the loop use lists to
determine which SCEV expressions to invalidate.

Here are some statistics on the number of uses of in the use lists of all loops
on a clang bootstrap (config: release, no asserts):

     Count: 731310
       Min: 1
      Mean: 8.555150
50th %time: 4
95th %tile: 25
99th %tile: 53
       Max: 433

Reviewers: atrick, sunfish, mkazantsev

Subscribers: mcrosier, llvm-commits

Differential Revision: https://reviews.llvm.org/D38434

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@315672 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
unittests/Analysis/ScalarEvolutionTest.cpp

index 5409949c6fb1df32a354e919cfc9046bfa3eb878..247d835ead14ccd23f76a6ae7e71082b3afdad5e 100644 (file)
@@ -1761,10 +1761,18 @@ private:
   const SCEV *getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops,
                                  SCEV::NoWrapFlags Flags);
 
+  /// Find all of the loops transitively used in \p S, and update \c LoopUsers
+  /// accordingly.
+  void addToLoopUseLists(const SCEV *S);
+
   FoldingSet<SCEV> UniqueSCEVs;
   FoldingSet<SCEVPredicate> UniquePreds;
   BumpPtrAllocator SCEVAllocator;
 
+  /// This maps loops to a list of SCEV expressions that (transitively) use said
+  /// loop.
+  DenseMap<const Loop *, SmallVector<const SCEV *, 4>> LoopUsers;
+
   /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression
   /// they can be rewritten into under certain predicates.
   DenseMap<std::pair<const SCEVUnknown *, const Loop *>,
index c294f61234257a6924f44072228931d5d219082e..6d9ab911eded01b91f0b716e39b36e5feacb2fbe 100644 (file)
@@ -1290,6 +1290,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                  Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -1580,6 +1581,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
     SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
     return S;
   }
 
@@ -1766,6 +1768,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -1803,6 +1806,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
     SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
     return S;
   }
 
@@ -2014,6 +2018,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -2662,6 +2667,7 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
     S = new (SCEVAllocator)
         SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
   }
   S->setNoWrapFlags(Flags);
   return S;
@@ -2683,6 +2689,7 @@ ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops,
     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                         O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
   }
   S->setNoWrapFlags(Flags);
   return S;
@@ -3135,6 +3142,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                              LHS, RHS);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -3315,6 +3323,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
     S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
                                            O, Operands.size(), L);
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
   }
   S->setNoWrapFlags(Flags);
   return S;
@@ -3470,6 +3479,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
   SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
                                              O, Ops.size());
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -3571,6 +3581,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
   SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
                                              O, Ops.size());
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -6392,6 +6403,13 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
         ++I;
     }
 
+    auto LoopUsersItr = LoopUsers.find(CurrL);
+    if (LoopUsersItr != LoopUsers.end()) {
+      for (auto *S : LoopUsersItr->second)
+        forgetMemoizedResults(S);
+      LoopUsers.erase(LoopUsersItr);
+    }
+
     // Drop information about expressions based on loop-header PHIs.
     PushLoopPHIs(CurrL, Worklist);
 
@@ -10574,6 +10592,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
       UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
+      LoopUsers(std::move(Arg.LoopUsers)),
       PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
@@ -11016,6 +11035,25 @@ ScalarEvolution::forgetMemoizedResults(const SCEV *S, bool EraseExitLimit) {
         ExitLimits.erase(I);
 }
 
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+  struct FindUsedLoops {
+    SmallPtrSet<const Loop *, 8> LoopsUsed;
+    bool follow(const SCEV *S) {
+      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
+        LoopsUsed.insert(AR->getLoop());
+      return true;
+    }
+
+    bool isDone() const { return false; }
+  };
+
+  FindUsedLoops F;
+  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+
+  for (auto *L : F.LoopsUsed)
+    LoopUsers[L].push_back(S);
+}
+
 void ScalarEvolution::verify() const {
   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
   ScalarEvolution SE2(F, TLI, AC, DT, LI);
index f4ef32bc6f89c3dc0f2e3dc12cba29aa7243f6da..1f51c1c91a56634951de90c81d078be1a5b032c6 100644 (file)
@@ -856,6 +856,17 @@ TEST_F(ScalarEvolutionsTest, SCEVExitLimitForgetLoop) {
   EXPECT_TRUE(isa<SCEVConstant>(EC));
   EXPECT_EQ(cast<SCEVConstant>(EC)->getAPInt().getLimitedValue(), 999u);
 
+  // The add recurrence {5,+,1} does not correspond to any PHI in the IR, and
+  // that is relevant to this test.
+  auto *Five = SE.getConstant(APInt(/*numBits=*/64, 5));
+  auto *AR =
+      SE.getAddRecExpr(Five, SE.getOne(T_int64), Loop, SCEV::FlagAnyWrap);
+  const SCEV *ARAtLoopExit = SE.getSCEVAtScope(AR, nullptr);
+  EXPECT_FALSE(isa<SCEVCouldNotCompute>(ARAtLoopExit));
+  EXPECT_TRUE(isa<SCEVConstant>(ARAtLoopExit));
+  EXPECT_EQ(cast<SCEVConstant>(ARAtLoopExit)->getAPInt().getLimitedValue(),
+            1004u);
+
   SE.forgetLoop(Loop);
   Br->eraseFromParent();
   Cond->eraseFromParent();
@@ -868,6 +879,11 @@ TEST_F(ScalarEvolutionsTest, SCEVExitLimitForgetLoop) {
   EXPECT_FALSE(isa<SCEVCouldNotCompute>(NewEC));
   EXPECT_TRUE(isa<SCEVConstant>(NewEC));
   EXPECT_EQ(cast<SCEVConstant>(NewEC)->getAPInt().getLimitedValue(), 1999u);
+  const SCEV *NewARAtLoopExit = SE.getSCEVAtScope(AR, nullptr);
+  EXPECT_FALSE(isa<SCEVCouldNotCompute>(NewARAtLoopExit));
+  EXPECT_TRUE(isa<SCEVConstant>(NewARAtLoopExit));
+  EXPECT_EQ(cast<SCEVConstant>(NewARAtLoopExit)->getAPInt().getLimitedValue(),
+            2004u);
 }
 
 // Make sure that SCEV invalidates exit limits after invalidating the values it