#include "gtest/gtest.h"
namespace llvm {
-namespace {
// We use this fixture to ensure that we clean up ScalarEvolution before
// deleting the PassManager.
ScalarEvolution SE = buildSE(*F);
Test(*F, *LI, SE);
}
+
+ static Optional<APInt> computeConstantDifference(ScalarEvolution &SE,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+ return SE.computeConstantDifference(LHS, RHS);
+ }
};
TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
"} ");
}
-} // end anonymous namespace
+TEST_F(ScalarEvolutionsTest, SCEVComputeConstantDifference) {
+ LLVMContext C;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M = parseAssemblyString(
+ "define void @foo(i32 %sz, i32 %pp) { "
+ "entry: "
+ " %v0 = add i32 %pp, 0 "
+ " %v3 = add i32 %pp, 3 "
+ " br label %loop.body "
+ "loop.body: "
+ " %iv = phi i32 [ %iv.next, %loop.body ], [ 0, %entry ] "
+ " %xa = add nsw i32 %iv, %v0 "
+ " %yy = add nsw i32 %iv, %v3 "
+ " %xb = sub nsw i32 %yy, 3 "
+ " %iv.next = add nsw i32 %iv, 1 "
+ " %cmp = icmp sle i32 %iv.next, %sz "
+ " br i1 %cmp, label %loop.body, label %exit "
+ "exit: "
+ " ret void "
+ "} ",
+ Err, C);
+
+ ASSERT_TRUE(M && "Could not parse module?");
+ ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+ runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+ auto *ScevV0 = SE.getSCEV(getInstructionByName(F, "v0")); // %pp
+ auto *ScevV3 = SE.getSCEV(getInstructionByName(F, "v3")); // (3 + %pp)
+ auto *ScevIV = SE.getSCEV(getInstructionByName(F, "iv")); // {0,+,1}
+ auto *ScevXA = SE.getSCEV(getInstructionByName(F, "xa")); // {%pp,+,1}
+ auto *ScevYY = SE.getSCEV(getInstructionByName(F, "yy")); // {(3 + %pp),+,1}
+ auto *ScevXB = SE.getSCEV(getInstructionByName(F, "xb")); // {%pp,+,1}
+ auto *ScevIVNext = SE.getSCEV(getInstructionByName(F, "iv.next")); // {1,+,1}
+
+ auto diff = [&SE](const SCEV *LHS, const SCEV *RHS) -> Optional<int> {
+ auto ConstantDiffOrNone = computeConstantDifference(SE, LHS, RHS);
+ if (!ConstantDiffOrNone)
+ return None;
+
+ auto ExtDiff = ConstantDiffOrNone->getSExtValue();
+ int Diff = ExtDiff;
+ assert(Diff == ExtDiff && "Integer overflow");
+ return Diff;
+ };
+
+ EXPECT_EQ(diff(ScevV3, ScevV0), 3);
+ EXPECT_EQ(diff(ScevV0, ScevV3), -3);
+ EXPECT_EQ(diff(ScevV0, ScevV0), 0);
+ EXPECT_EQ(diff(ScevV3, ScevV3), 0);
+ EXPECT_EQ(diff(ScevIV, ScevIV), 0);
+ EXPECT_EQ(diff(ScevXA, ScevXB), 0);
+ EXPECT_EQ(diff(ScevXA, ScevYY), -3);
+ EXPECT_EQ(diff(ScevYY, ScevXB), 3);
+ EXPECT_EQ(diff(ScevIV, ScevIVNext), -1);
+ EXPECT_EQ(diff(ScevIVNext, ScevIV), 1);
+ EXPECT_EQ(diff(ScevIVNext, ScevIVNext), 0);
+ EXPECT_EQ(diff(ScevV0, ScevIV), None);
+ EXPECT_EQ(diff(ScevIVNext, ScevV3), None);
+ EXPECT_EQ(diff(ScevYY, ScevV3), None);
+ });
+}
+
} // end namespace llvm