From: Dorit Nuzman Date: Sun, 10 Dec 2017 11:13:35 +0000 (+0000) Subject: [SCEV] Fix wrong Equal predicate created in getAddRecForPhiWithCasts X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=6ae5d1f67fd8f26fd1a007667b8615035b89a524;p=llvm [SCEV] Fix wrong Equal predicate created in getAddRecForPhiWithCasts CreateAddRecFromPHIWithCastsImpl() adds an IncrementNUSW overflow predicate which allows the PSCEV rewriter to rewrite this scev expression: (zext i8 {0, + , (trunc i32 step to i8)} to i32) into {0, +, (sext i8 (trunc i32 step to i8) to i32)} But then it adds the wrong Equal predicate: %step == (zext i8 (trunc i32 %step to i8) to i32). instead of: %step == (sext i8 (trunc i32 %step to i8) to i32) This is fixed here. Differential Revision: https://reviews.llvm.org/D40641 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@320298 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 3ed7dde47bb..960bd64830c 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -4636,18 +4636,19 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy) // for each of StartVal and Accum - auto GetExtendedExpr = [&](const SCEV *Expr) -> const SCEV * { + auto getExtendedExpr = [&](const SCEV *Expr, + bool CreateSignExtend) -> const SCEV * { assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy); const SCEV *ExtendedExpr = - Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType()) - : getZeroExtendExpr(TruncatedExpr, Expr->getType()); + CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType()) + : getZeroExtendExpr(TruncatedExpr, Expr->getType()); return ExtendedExpr; }; // Given: // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy - // = GetExtendedExpr(Expr) + // = getExtendedExpr(Expr) // Determine whether the predicate P: Expr == ExtendedExpr // is known to be false at compile time auto PredIsKnownFalse = [&](const SCEV *Expr, @@ -4656,13 +4657,15 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr); }; - const SCEV *StartExtended = GetExtendedExpr(StartVal); + const SCEV *StartExtended = getExtendedExpr(StartVal, Signed); if (PredIsKnownFalse(StartVal, StartExtended)) { DEBUG(dbgs() << "P2 is compile-time false\n";); return None; } - const SCEV *AccumExtended = GetExtendedExpr(Accum); + // The Step is always Signed (because the overflow checks are either + // NSSW or NUSW) + const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true); if (PredIsKnownFalse(Accum, AccumExtended)) { DEBUG(dbgs() << "P3 is compile-time false\n";); return None; diff --git a/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll b/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll index 40af8f3adf0..d9c9632be04 100644 --- a/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll +++ b/test/Transforms/LoopVectorize/pr30654-phiscev-sext-trunc.ll @@ -74,7 +74,7 @@ for.end: ; Same as above, but for checking the SCEV "zext(trunc(%p.09)) + %step". ; Here we expect the following two predicates to be added for runtime checking: ; 1) {0,+,(trunc i32 %step to i8)}<%for.body> Added Flags: -; 2) Equal predicate: %step == (zext i8 (trunc i32 %step to i8) to i32) +; 2) Equal predicate: %step == (sext i8 (trunc i32 %step to i8) to i32) ; ; int a[N]; ; void doit2(int n, int step) { @@ -93,7 +93,8 @@ for.end: ; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}}) ; CHECK: %[[TEST:[0-9]+]] = or i1 {{.*}}, %mul.overflow ; CHECK: %[[NTEST:[0-9]+]] = or i1 false, %[[TEST]] -; CHECK: %ident.check = icmp ne i32 {{.*}}, %{{.*}} +; CHECK: %[[EXT:[0-9]+]] = sext i8 {{.*}} to i32 +; CHECK: %ident.check = icmp ne i32 {{.*}}, %[[EXT]] ; CHECK: %{{.*}} = or i1 %[[NTEST]], %ident.check ; CHECK-NOT: %mul = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 {{.*}}, i8 {{.*}}) ; CHECK: vector.body: