]> granicus.if.org Git - llvm/commitdiff
[SCEV] Refactor out a useful pattern; NFC
authorSanjoy Das <sanjoy@playingwithpointers.com>
Wed, 9 Nov 2016 18:22:43 +0000 (18:22 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Wed, 9 Nov 2016 18:22:43 +0000 (18:22 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@286386 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolutionExpressions.h
lib/Analysis/ScalarEvolution.cpp

index 9113880ef25e45373d13a6d4f1fb1d048d4e0894..fdcd8be00dde88aab4dc5faebb658d8424b27efd 100644 (file)
@@ -537,6 +537,31 @@ namespace llvm {
     T.visitAll(Root);
   }
 
+  /// Return true if any node in \p Root satisfies the predicate \p Pred.
+  template <typename PredTy>
+  bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
+    struct FindClosure {
+      bool Found = false;
+      PredTy Pred;
+
+      FindClosure(PredTy Pred) : Pred(Pred) {}
+
+      bool follow(const SCEV *S) {
+        if (!Pred(S))
+          return true;
+
+        Found = true;
+        return false;
+      }
+
+      bool isDone() const { return Found; }
+    };
+
+    FindClosure FC(Pred);
+    visitAll(Root, FC);
+    return FC.Found;
+  }
+
   /// This visitor recursively visits a SCEV expression and re-writes it.
   /// The result from each visit is cached, so it will return the same
   /// SCEV for the same input.
index ce9fade782f7df6dbfce0e24f0260d0c2408b3b6..c93301148178c552c0f4fb8152ad72507037f1a9 100644 (file)
@@ -3356,69 +3356,24 @@ const SCEV *ScalarEvolution::getCouldNotCompute() {
   return CouldNotCompute.get();
 }
 
-
 bool ScalarEvolution::checkValidity(const SCEV *S) const {
-  // Helper class working with SCEVTraversal to figure out if a SCEV contains
-  // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne
-  // is set iff if find such SCEVUnknown.
-  //
-  struct FindInvalidSCEVUnknown {
-    bool FindOne;
-    FindInvalidSCEVUnknown() { FindOne = false; }
-    bool follow(const SCEV *S) {
-      switch (static_cast<SCEVTypes>(S->getSCEVType())) {
-      case scConstant:
-        return false;
-      case scUnknown:
-        if (!cast<SCEVUnknown>(S)->getValue())
-          FindOne = true;
-        return false;
-      default:
-        return true;
-      }
-    }
-    bool isDone() const { return FindOne; }
-  };
-
-  FindInvalidSCEVUnknown F;
-  SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
-  ST.visitAll(S);
+  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
+    auto *SU = dyn_cast<SCEVUnknown>(S);
+    return SU && SU->getValue() == nullptr;
+  });
 
-  return !F.FindOne;
+  return !ContainsNulls;
 }
 
 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
-  // Helper class working with SCEVTraversal to figure out if a SCEV contains a
-  // sub SCEV of scAddRecExpr type.  FindInvalidSCEVUnknown::FoundOne is set iff
-  // if such sub scAddRecExpr type SCEV is found.
-  struct FindAddRecurrence {
-    bool FoundOne;
-    FindAddRecurrence() : FoundOne(false) {}
-
-    bool follow(const SCEV *S) {
-      switch (static_cast<SCEVTypes>(S->getSCEVType())) {
-      case scAddRecExpr:
-        FoundOne = true;
-      case scConstant:
-      case scUnknown:
-      case scCouldNotCompute:
-        return false;
-      default:
-        return true;
-      }
-    }
-    bool isDone() const { return FoundOne; }
-  };
-
   HasRecMapType::iterator I = HasRecMap.find(S);
   if (I != HasRecMap.end())
     return I->second;
 
-  FindAddRecurrence F;
-  SCEVTraversal<FindAddRecurrence> ST(F);
-  ST.visitAll(S);
-  HasRecMap.insert({S, F.FoundOne});
-  return F.FoundOne;
+  bool FoundAddRec =
+      SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
+  HasRecMap.insert({S, FoundAddRec});
+  return FoundAddRec;
 }
 
 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
@@ -8993,38 +8948,15 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
   return SE.getCouldNotCompute();
 }
 
-namespace {
-struct FindUndefs {
-  bool Found;
-  FindUndefs() : Found(false) {}
-
-  bool follow(const SCEV *S) {
-    if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) {
-      if (isa<UndefValue>(C->getValue()))
-        Found = true;
-    } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
-      if (isa<UndefValue>(C->getValue()))
-        Found = true;
-    }
-
-    // Keep looking if we haven't found it yet.
-    return !Found;
-  }
-  bool isDone() const {
-    // Stop recursion if we have found an undef.
-    return Found;
-  }
-};
-}
-
 // Return true when S contains at least an undef value.
-static inline bool
-containsUndefs(const SCEV *S) {
-  FindUndefs F;
-  SCEVTraversal<FindUndefs> ST(F);
-  ST.visitAll(S);
-
-  return F.Found;
+static inline bool containsUndefs(const SCEV *S) {
+  return SCEVExprContains(S, [](const SCEV *S) {
+    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
+      return isa<UndefValue>(SU->getValue());
+    else if (const auto *SC = dyn_cast<SCEVConstant>(S))
+      return isa<UndefValue>(SC->getValue());
+    return false;
+  });
 }
 
 namespace {
@@ -9217,40 +9149,11 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
   return true;
 }
 
-// Returns true when S contains at least a SCEVUnknown parameter.
-static inline bool
-containsParameters(const SCEV *S) {
-  struct FindParameter {
-    bool FoundParameter;
-    FindParameter() : FoundParameter(false) {}
-
-    bool follow(const SCEV *S) {
-      if (isa<SCEVUnknown>(S)) {
-        FoundParameter = true;
-        // Stop recursion: we found a parameter.
-        return false;
-      }
-      // Keep looking.
-      return true;
-    }
-    bool isDone() const {
-      // Stop recursion if we have found a parameter.
-      return FoundParameter;
-    }
-  };
-
-  FindParameter F;
-  SCEVTraversal<FindParameter> ST(F);
-  ST.visitAll(S);
-
-  return F.FoundParameter;
-}
 
 // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
-static inline bool
-containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
+static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
   for (const SCEV *T : Terms)
-    if (containsParameters(T))
+    if (SCEVExprContains(T, [](const SCEV *S) { return isa<SCEVUnknown>(S); }))
       return true;
   return false;
 }
@@ -9977,24 +9880,7 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
 }
 
 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
-  // Search for a SCEV expression node within an expression tree.
-  // Implements SCEVTraversal::Visitor.
-  struct SCEVSearch {
-    const SCEV *Node;
-    bool IsFound;
-
-    SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
-
-    bool follow(const SCEV *S) {
-      IsFound |= (S == Node);
-      return !IsFound;
-    }
-    bool isDone() const { return IsFound; }
-  };
-
-  SCEVSearch Search(Op);
-  visitAll(S, Search);
-  return Search.IsFound;
+  return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
 void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {