]> granicus.if.org Git - llvm/commitdiff
[SCEV] Return zero from computeConstantDifference(X, X)
authorNikolai Bozhenov <nikolai.bozhenov@intel.com>
Wed, 7 Aug 2019 17:38:38 +0000 (17:38 +0000)
committerNikolai Bozhenov <nikolai.bozhenov@intel.com>
Wed, 7 Aug 2019 17:38:38 +0000 (17:38 +0000)
Without this patch computeConstantDifference returns None for cases like
these:

  computeConstantDifference(%x, %x)
  computeConstantDifference({%x,+,16}, {%x,+,16})

Differential Revision: https://reviews.llvm.org/D65474

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@368193 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
unittests/Analysis/ScalarEvolutionTest.cpp

index 0bd98ef37e7ab04b0a3ee195586c7d19458ca33e..eb32991ec12fa56052a3290ab650e3e6088aea29 100644 (file)
@@ -468,6 +468,8 @@ template <> struct DenseMapInfo<ExitLimitQuery> {
 /// can't do much with the SCEV objects directly, they must ask this class
 /// for services.
 class ScalarEvolution {
+  friend class ScalarEvolutionsTest;
+
 public:
   /// An enum describing the relationship between a SCEV and a loop.
   enum LoopDisposition {
index bc2cfd6fcc42c8cb375c78c07ff528f4b3608989..8552d784c435e05433e57fb7dbab0e7c4bfadc6c 100644 (file)
@@ -9833,6 +9833,10 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
   // We avoid subtracting expressions here because this function is usually
   // fairly deep in the call stack (i.e. is called many times).
 
+  // X - X = 0.
+  if (More == Less)
+    return APInt(getTypeSizeInBits(More->getType()), 0);
+
   if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
     const auto *LAR = cast<SCEVAddRecExpr>(Less);
     const auto *MAR = cast<SCEVAddRecExpr>(More);
index fb8f6689ca7c2061f4e992bcf7914764184a00b2..42c708da6d4fe2aeb2d05d6d173bee44dace7a26 100644 (file)
@@ -26,7 +26,6 @@
 #include "gtest/gtest.h"
 
 namespace llvm {
-namespace {
 
 // We use this fixture to ensure that we clean up ScalarEvolution before
 // deleting the PassManager.
@@ -58,6 +57,12 @@ protected:
     ScalarEvolution SE = buildSE(*F);
     Test(*F, *LI, SE);
   }
+
+  static Optional<APInt> computeConstantDifference(ScalarEvolution &SE,
+                                                   const SCEV *LHS,
+                                                   const SCEV *RHS) {
+    return SE.computeConstantDifference(LHS, RHS);
+  }
 };
 
 TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
@@ -1678,5 +1683,66 @@ TEST_F(ScalarEvolutionsTest, SCEVExpanderShlNSW) {
                "} ");
 }
 
-}  // end anonymous namespace
+TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define void @foo(i32 %sz, i32 %pp) { "
+      "entry: "
+      "  %v0 = add i32 %pp, 0 "
+      "  %v3 = add i32 %pp, 3 "
+      "  br label %loop.body "
+      "loop.body: "
+      "  %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
+      "  %xa = add nsw i32 %iv, %v0 "
+      "  %yy = add nsw i32 %iv, %v3 "
+      "  %xb = sub nsw i32 %yy, 3 "
+      "  %iv.next = add nsw i32 %iv, 1 "
+      "  %cmp = icmp sle i32 %iv.next, %sz "
+      "  br i1 %cmp, label %loop.body, label %exit "
+      "exit: "
+      "  ret void "
+      "} ",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
+    auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
+    auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
+    auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
+    auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
+    auto *ScevXB = SE.getSCEV(getInstructionByName(F, "xb")); // {%pp,+,1}
+    auto *ScevIVNext = SE.getSCEV(getInstructionByName(F, "iv.next")); // {1,+,1}
+
+    auto diff = [&SE](const SCEV *LHS, const SCEV *RHS) -> Optional<int> {
+      auto ConstantDiffOrNone = computeConstantDifference(SE, LHS, RHS);
+      if (!ConstantDiffOrNone)
+        return None;
+
+      auto ExtDiff = ConstantDiffOrNone->getSExtValue();
+      int Diff = ExtDiff;
+      assert(Diff == ExtDiff && "Integer overflow");
+      return Diff;
+    };
+
+    EXPECT_EQ(diff(ScevV3, ScevV0), 3);
+    EXPECT_EQ(diff(ScevV0, ScevV3), -3);
+    EXPECT_EQ(diff(ScevV0, ScevV0), 0);
+    EXPECT_EQ(diff(ScevV3, ScevV3), 0);
+    EXPECT_EQ(diff(ScevIV, ScevIV), 0);
+    EXPECT_EQ(diff(ScevXA, ScevXB), 0);
+    EXPECT_EQ(diff(ScevXA, ScevYY), -3);
+    EXPECT_EQ(diff(ScevYY, ScevXB), 3);
+    EXPECT_EQ(diff(ScevIV, ScevIVNext), -1);
+    EXPECT_EQ(diff(ScevIVNext, ScevIV), 1);
+    EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
+    EXPECT_EQ(diff(ScevV0, ScevIV), None);
+    EXPECT_EQ(diff(ScevIVNext, ScevV3), None);
+    EXPECT_EQ(diff(ScevYY, ScevV3), None);
+  });
+}
+
 }  // end namespace llvm