]> granicus.if.org Git - llvm/commitdiff
[SCEV] Add a local cache for getZeroExtendExpr and getSignExtendExpr to prevent
authorWei Mi <wmi@google.com>
Mon, 17 Apr 2017 20:40:05 +0000 (20:40 +0000)
committerWei Mi <wmi@google.com>
Mon, 17 Apr 2017 20:40:05 +0000 (20:40 +0000)
the exponential behavior.

The patch is to fix PR32043. Functions getZeroExtendExpr and getSignExtendExpr
may call themselves recursively more than once. This is potentially a 2^N
complexity behavior. The exponential behavior was not commonly exposed before
because of existing global cache mechnism like UniqueSCEVs or some early return
mechanism when flags FlagNSW or FlagNUW are seen. However, we still have case
which can expose the exponential behavior, like the case in PR32043, so we add
a local cache in getZeroExtendExpr and getSignExtendExpr. If the input of the
functions -- SCEV and type pair have been seen before, we can find the extended
expression directly in the local cache.

Differential Revision: https://reviews.llvm.org/D30350

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@300494 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
unittests/Analysis/ScalarEvolutionTest.cpp

index 9a50de540f2b1b4d6e205844a714e19b231c1d4d..91aeae0f728f928d3293f08f302dbaef09f59c27 100644 (file)
@@ -1159,8 +1159,20 @@ public:
   const SCEV *getConstant(const APInt &Val);
   const SCEV *getConstant(Type *Ty, uint64_t V, bool isSigned = false);
   const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty);
+
+  typedef SmallDenseMap<std::pair<const SCEV *, Type *>, const SCEV *, 8>
+      ExtendCacheTy;
   const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty);
+  const SCEV *getZeroExtendExprCached(const SCEV *Op, Type *Ty,
+                                      ExtendCacheTy &Cache);
+  const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
+                                    ExtendCacheTy &Cache);
+
   const SCEV *getSignExtendExpr(const SCEV *Op, Type *Ty);
+  const SCEV *getSignExtendExprCached(const SCEV *Op, Type *Ty,
+                                      ExtendCacheTy &Cache);
+  const SCEV *getSignExtendExprImpl(const SCEV *Op, Type *Ty,
+                                    ExtendCacheTy &Cache);
   const SCEV *getAnyExtendExpr(const SCEV *Op, Type *Ty);
   const SCEV *getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                          SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap,
index ca32cf3c7c34292d9d4b7429e832b64f862ff83b..82b91560ccbb5b0f380944d8491501ff61b75f11 100644 (file)
@@ -1276,7 +1276,8 @@ static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
 namespace {
 
 struct ExtendOpTraitsBase {
-  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *);
+  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(
+      const SCEV *, Type *, ScalarEvolution::ExtendCacheTy &Cache);
 };
 
 // Used to make code generic over signed and unsigned overflow.
@@ -1305,8 +1306,9 @@ struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
   }
 };
 
-const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
-    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
+const ExtendOpTraitsBase::GetExtendExprTy
+    ExtendOpTraits<SCEVSignExtendExpr>::GetExtendExpr =
+        &ScalarEvolution::getSignExtendExprCached;
 
 template <>
 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
@@ -1321,8 +1323,9 @@ struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
   }
 };
 
-const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
-    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
+const ExtendOpTraitsBase::GetExtendExprTy
+    ExtendOpTraits<SCEVZeroExtendExpr>::GetExtendExpr =
+        &ScalarEvolution::getZeroExtendExprCached;
 }
 
 // The recurrence AR has been shown to have no signed/unsigned wrap or something
@@ -1334,7 +1337,8 @@ const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
 // "sext/zext(PostIncAR)"
 template <typename ExtendOpTy>
 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
-                                        ScalarEvolution *SE) {
+                                        ScalarEvolution *SE,
+                                        ScalarEvolution::ExtendCacheTy &Cache) {
   auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
 
@@ -1381,9 +1385,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
   const SCEV *OperandExtendedStart =
-      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy),
-                     (SE->*GetExtendExpr)(Step, WideTy));
-  if ((SE->*GetExtendExpr)(Start, WideTy) == OperandExtendedStart) {
+      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Cache),
+                     (SE->*GetExtendExpr)(Step, WideTy, Cache));
+  if ((SE->*GetExtendExpr)(Start, WideTy, Cache) == OperandExtendedStart) {
     if (PreAR && AR->getNoWrapFlags(WrapType)) {
       // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
       // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
@@ -1408,15 +1412,17 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
 // Get the normalized zero or sign extended expression for this AddRec's Start.
 template <typename ExtendOpTy>
 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
-                                        ScalarEvolution *SE) {
+                                        ScalarEvolution *SE,
+                                        ScalarEvolution::ExtendCacheTy &Cache) {
   auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
 
-  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE);
+  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Cache);
   if (!PreStart)
-    return (SE->*GetExtendExpr)(AR->getStart(), Ty);
+    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Cache);
 
-  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty),
-                        (SE->*GetExtendExpr)(PreStart, Ty));
+  return SE->getAddExpr(
+      (SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, Cache),
+      (SE->*GetExtendExpr)(PreStart, Ty, Cache));
 }
 
 // Try to prove away overflow by looking at "nearby" add recurrences.  A
@@ -1496,8 +1502,30 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
   return false;
 }
 
-const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
-                                               Type *Ty) {
+const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty) {
+  // Use the local cache to prevent exponential behavior of
+  // getZeroExtendExprImpl.
+  ExtendCacheTy Cache;
+  return getZeroExtendExprCached(Op, Ty, Cache);
+}
+
+/// Query \p Cache before calling getZeroExtendExprImpl. If there is no
+/// related entry in the \p Cache, call getZeroExtendExprImpl and save
+/// the result in the \p Cache.
+const SCEV *ScalarEvolution::getZeroExtendExprCached(const SCEV *Op, Type *Ty,
+                                                     ExtendCacheTy &Cache) {
+  auto It = Cache.find({Op, Ty});
+  if (It != Cache.end())
+    return It->second;
+  const SCEV *ZExt = getZeroExtendExprImpl(Op, Ty, Cache);
+  auto InsertResult = Cache.insert({{Op, Ty}, ZExt});
+  assert(InsertResult.second && "Expect the key was not in the cache");
+  return ZExt;
+}
+
+/// The real implementation of getZeroExtendExpr.
+const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
+                                                   ExtendCacheTy &Cache) {
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
@@ -1507,11 +1535,11 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
   // Fold if the operand is constant.
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
     return getConstant(
-      cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
+        cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
 
   // zext(zext(x)) --> zext(x)
   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
-    return getZeroExtendExpr(SZ->getOperand(), Ty);
+    return getZeroExtendExprCached(SZ->getOperand(), Ty, Cache);
 
   // Before doing any expensive analysis, check to see if we've already
   // computed a SCEV for this Op and Ty.
@@ -1555,8 +1583,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
       // we don't need to do any further analysis.
       if (AR->hasNoUnsignedWrap())
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-            getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+            getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags());
 
       // Check whether the backedge-taken count is SCEVCouldNotCompute.
       // Note that this serves two purposes: It filters out loops that are
@@ -1581,21 +1609,22 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
           // Check whether Start+Step*MaxBECount has no unsigned overflow.
           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
-          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul), WideTy);
-          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy);
+          const SCEV *ZAdd =
+              getZeroExtendExprCached(getAddExpr(Start, ZMul), WideTy, Cache);
+          const SCEV *WideStart = getZeroExtendExprCached(Start, WideTy, Cache);
           const SCEV *WideMaxBECount =
-            getZeroExtendExpr(CastedMaxBECount, WideTy);
-          const SCEV *OperandExtendedAdd =
-            getAddExpr(WideStart,
-                       getMulExpr(WideMaxBECount,
-                                  getZeroExtendExpr(Step, WideTy)));
+              getZeroExtendExprCached(CastedMaxBECount, WideTy, Cache);
+          const SCEV *OperandExtendedAdd = getAddExpr(
+              WideStart, getMulExpr(WideMaxBECount, getZeroExtendExprCached(
+                                                        Step, WideTy, Cache)));
           if (ZAdd == OperandExtendedAdd) {
             // Cache knowledge of AR NUW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-                getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+                getZeroExtendExprCached(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
           }
           // Similar to above, only this time treat the step value as signed.
           // This covers loops that count down.
@@ -1609,7 +1638,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
           }
         }
@@ -1641,8 +1670,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-                getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+                getZeroExtendExprCached(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
           }
         } else if (isKnownNegative(Step)) {
           const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
@@ -1657,7 +1687,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
+                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
                 getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
           }
         }
@@ -1666,8 +1696,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
       if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this),
-            getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Cache),
+            getZeroExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags());
       }
     }
 
@@ -1678,7 +1708,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
       // commute the zero extension with the addition operation.
       SmallVector<const SCEV *, 4> Ops;
       for (const auto *Op : SA->operands())
-        Ops.push_back(getZeroExtendExpr(Op, Ty));
+        Ops.push_back(getZeroExtendExprCached(Op, Ty, Cache));
       return getAddExpr(Ops, SCEV::FlagNUW);
     }
   }
@@ -1692,8 +1722,30 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
   return S;
 }
 
-const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
-                                               Type *Ty) {
+const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty) {
+  // Use the local cache to prevent exponential behavior of
+  // getSignExtendExprImpl.
+  ExtendCacheTy Cache;
+  return getSignExtendExprCached(Op, Ty, Cache);
+}
+
+/// Query \p Cache before calling getSignExtendExprImpl. If there is no
+/// related entry in the \p Cache, call getSignExtendExprImpl and save
+/// the result in the \p Cache.
+const SCEV *ScalarEvolution::getSignExtendExprCached(const SCEV *Op, Type *Ty,
+                                                     ExtendCacheTy &Cache) {
+  auto It = Cache.find({Op, Ty});
+  if (It != Cache.end())
+    return It->second;
+  const SCEV *SExt = getSignExtendExprImpl(Op, Ty, Cache);
+  auto InsertResult = Cache.insert({{Op, Ty}, SExt});
+  assert(InsertResult.second && "Expect the key was not in the cache");
+  return SExt;
+}
+
+/// The real implementation of getSignExtendExpr.
+const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
+                                                   ExtendCacheTy &Cache) {
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
          "This is not an extending conversion!");
   assert(isSCEVable(Ty) &&
@@ -1703,11 +1755,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
   // Fold if the operand is constant.
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
     return getConstant(
-      cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
+        cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
 
   // sext(sext(x)) --> sext(x)
   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
-    return getSignExtendExpr(SS->getOperand(), Ty);
+    return getSignExtendExprCached(SS->getOperand(), Ty, Cache);
 
   // sext(zext(x)) --> zext(x)
   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
@@ -1746,8 +1798,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
           const APInt &C2 = SC2->getAPInt();
           if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
               C2.ugt(C1) && C2.isPowerOf2())
-            return getAddExpr(getSignExtendExpr(SC1, Ty),
-                              getSignExtendExpr(SMul, Ty));
+            return getAddExpr(getSignExtendExprCached(SC1, Ty, Cache),
+                              getSignExtendExprCached(SMul, Ty, Cache));
         }
       }
     }
@@ -1758,7 +1810,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
       // commute the sign extension with the addition operation.
       SmallVector<const SCEV *, 4> Ops;
       for (const auto *Op : SA->operands())
-        Ops.push_back(getSignExtendExpr(Op, Ty));
+        Ops.push_back(getSignExtendExprCached(Op, Ty, Cache));
       return getAddExpr(Ops, SCEV::FlagNSW);
     }
   }
@@ -1782,8 +1834,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
       // we don't need to do any further analysis.
       if (AR->hasNoSignedWrap())
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-            getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW);
+            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+            getSignExtendExprCached(Step, Ty, Cache), L, SCEV::FlagNSW);
 
       // Check whether the backedge-taken count is SCEVCouldNotCompute.
       // Note that this serves two purposes: It filters out loops that are
@@ -1808,21 +1860,22 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
           // Check whether Start+Step*MaxBECount has no signed overflow.
           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
-          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul), WideTy);
-          const SCEV *WideStart = getSignExtendExpr(Start, WideTy);
+          const SCEV *SAdd =
+              getSignExtendExprCached(getAddExpr(Start, SMul), WideTy, Cache);
+          const SCEV *WideStart = getSignExtendExprCached(Start, WideTy, Cache);
           const SCEV *WideMaxBECount =
-            getZeroExtendExpr(CastedMaxBECount, WideTy);
-          const SCEV *OperandExtendedAdd =
-            getAddExpr(WideStart,
-                       getMulExpr(WideMaxBECount,
-                                  getSignExtendExpr(Step, WideTy)));
+              getZeroExtendExpr(CastedMaxBECount, WideTy);
+          const SCEV *OperandExtendedAdd = getAddExpr(
+              WideStart, getMulExpr(WideMaxBECount, getSignExtendExprCached(
+                                                        Step, WideTy, Cache)));
           if (SAdd == OperandExtendedAdd) {
             // Cache knowledge of AR NSW, which is propagated to this AddRec.
             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-                getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+                getSignExtendExprCached(Step, Ty, Cache), L,
+                AR->getNoWrapFlags());
           }
           // Similar to above, only this time treat the step value as unsigned.
           // This covers loops that count up with an unsigned step.
@@ -1843,7 +1896,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
 
             // Return the expression with the addrec on the outside.
             return getAddRecExpr(
-                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
+                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
                 getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
           }
         }
@@ -1875,8 +1928,9 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
           return getAddRecExpr(
-              getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-              getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+              getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+              getSignExtendExprCached(Step, Ty, Cache), L,
+              AR->getNoWrapFlags());
         }
       }
 
@@ -1890,18 +1944,18 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
         const APInt &C2 = SC2->getAPInt();
         if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
             C2.isPowerOf2()) {
-          Start = getSignExtendExpr(Start, Ty);
+          Start = getSignExtendExprCached(Start, Ty, Cache);
           const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
                                             AR->getNoWrapFlags());
-          return getAddExpr(Start, getSignExtendExpr(NewAR, Ty));
+          return getAddExpr(Start, getSignExtendExprCached(NewAR, Ty, Cache));
         }
       }
 
       if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
         const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
         return getAddRecExpr(
-            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this),
-            getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags());
+            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Cache),
+            getSignExtendExprCached(Step, Ty, Cache), L, AR->getNoWrapFlags());
       }
     }
 
index df9fd4b5ec3307a7a7a0d413a33383b2bb1f6eb7..57369573595787a0158fc78dd8549f5793cd965f 100644 (file)
@@ -666,5 +666,95 @@ TEST_F(ScalarEvolutionsTest, SCEVNormalization) {
   });
 }
 
+// Expect the call of getZeroExtendExpr will not cost exponential time.
+TEST_F(ScalarEvolutionsTest, SCEVZeroExtendExpr) {
+  LLVMContext C;
+  SMDiagnostic Err;
+
+  // Generate a function like below:
+  // define void @foo() {
+  // entry:
+  //   br label %for.cond
+  //
+  // for.cond:
+  //   %0 = phi i64 [ 100, %entry ], [ %dec, %for.inc ]
+  //   %cmp = icmp sgt i64 %0, 90
+  //   br i1 %cmp, label %for.inc, label %for.cond1
+  //
+  // for.inc:
+  //   %dec = add nsw i64 %0, -1
+  //   br label %for.cond
+  //
+  // for.cond1:
+  //   %1 = phi i64 [ 100, %for.cond ], [ %dec5, %for.inc2 ]
+  //   %cmp3 = icmp sgt i64 %1, 90
+  //   br i1 %cmp3, label %for.inc2, label %for.cond4
+  //
+  // for.inc2:
+  //   %dec5 = add nsw i64 %1, -1
+  //   br label %for.cond1
+  //
+  // ......
+  //
+  // for.cond89:
+  //   %19 = phi i64 [ 100, %for.cond84 ], [ %dec94, %for.inc92 ]
+  //   %cmp93 = icmp sgt i64 %19, 90
+  //   br i1 %cmp93, label %for.inc92, label %for.end
+  //
+  // for.inc92:
+  //   %dec94 = add nsw i64 %19, -1
+  //   br label %for.cond89
+  //
+  // for.end:
+  //   %gep = getelementptr i8, i8* null, i64 %dec
+  //   %gep6 = getelementptr i8, i8* %gep, i64 %dec5
+  //   ......
+  //   %gep95 = getelementptr i8, i8* %gep91, i64 %dec94
+  //   ret void
+  // }
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), {}, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("foo", FTy));
+
+  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
+  BasicBlock *CondBB = BasicBlock::Create(Context, "for.cond", F);
+  BasicBlock *EndBB = BasicBlock::Create(Context, "for.end", F);
+  BranchInst::Create(CondBB, EntryBB);
+  BasicBlock *PrevBB = EntryBB;
+
+  Type *I64Ty = Type::getInt64Ty(Context);
+  Type *I8Ty = Type::getInt8Ty(Context);
+  Type *I8PtrTy = Type::getInt8PtrTy(Context);
+  Value *Accum = Constant::getNullValue(I8PtrTy);
+  int Iters = 20;
+  for (int i = 0; i < Iters; i++) {
+    BasicBlock *IncBB = BasicBlock::Create(Context, "for.inc", F, EndBB);
+    auto *PN = PHINode::Create(I64Ty, 2, "", CondBB);
+    PN->addIncoming(ConstantInt::get(Context, APInt(64, 100)), PrevBB);
+    auto *Cmp = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, PN,
+                                ConstantInt::get(Context, APInt(64, 90)), "cmp",
+                                CondBB);
+    BasicBlock *NextBB;
+    if (i != Iters - 1)
+      NextBB = BasicBlock::Create(Context, "for.cond", F, EndBB);
+    else
+      NextBB = EndBB;
+    BranchInst::Create(IncBB, NextBB, Cmp, CondBB);
+    auto *Dec = BinaryOperator::CreateNSWAdd(
+        PN, ConstantInt::get(Context, APInt(64, -1)), "dec", IncBB);
+    PN->addIncoming(Dec, IncBB);
+    BranchInst::Create(CondBB, IncBB);
+
+    Accum = GetElementPtrInst::Create(I8Ty, Accum, Dec, "gep", EndBB);
+
+    PrevBB = CondBB;
+    CondBB = NextBB;
+  }
+  ReturnInst::Create(Context, nullptr, EndBB);
+  ScalarEvolution SE = buildSE(*F);
+  const SCEV *S = SE.getSCEV(Accum);
+  Type *I128Ty = Type::getInt128Ty(Context);
+  SE.getZeroExtendExpr(S, I128Ty);
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm