]> granicus.if.org Git - llvm/commitdiff
[InstCombine][NFC] dropRedundantMaskingOfLeftShiftInput(): change how we deal with...
authorRoman Lebedev <lebedev.ri@gmail.com>
Mon, 7 Oct 2019 20:53:00 +0000 (20:53 +0000)
committerRoman Lebedev <lebedev.ri@gmail.com>
Mon, 7 Oct 2019 20:53:00 +0000 (20:53 +0000)
Summary:
Currently, we pre-check whether we need to produce a mask or not.
This involves some rather magical constants.
I'd like to extend this fold to also handle the situation
when there's also a `trunc` before outer shift.
That will require another set of magical constants.
It's ugly.

Instead, we can just compute the mask, and check
whether mask is a pass-through (all-ones) or not.
This way we don't need to have any magical numbers.

This change is NFC other than the fact that we now compute
the mask and then check if we need (and can!) apply it.

Reviewers: spatel

Reviewed By: spatel

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D68470

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@373961 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineShifts.cpp

index 6675ab12aeecc90b9eef6c588a8b77c3bbd47a2d..a325b29afa63d7860c1db5babcc4c219e1f5bdf4 100644 (file)
@@ -181,39 +181,29 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
         MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
     if (!SumOfShAmts)
       return nullptr; // Did not simplify.
+    // In this pattern SumOfShAmts correlates with the number of low bits
+    // that shall remain in the root value (OuterShift).
+
     Type *Ty = X->getType();
-    unsigned BitWidth = Ty->getScalarSizeInBits();
-    // In this pattern SumOfShAmts correlates with the number of low bits that
-    // shall remain in the root value (OuterShift). If SumOfShAmts is less than
-    // bitwidth, we'll need to also produce a mask to keep SumOfShAmts low bits.
-    // So, does *any* channel need a mask?
-    if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
-                                               APInt(BitWidth, BitWidth)))) {
-      // But for a mask we need to get rid of old masking instruction.
-      if (!Masked->hasOneUse())
-        return nullptr; // Else we can't perform the fold.
-      // The mask must be computed in a type twice as wide to ensure
-      // that no bits are lost if the sum-of-shifts is wider than the base type.
-      Type *ExtendedTy = Ty->getExtendedType();
-      // An extend of an undef value becomes zero because the high bits are
-      // never completely unknown. Replace the the `undef` shift amounts with
-      // final shift bitwidth to ensure that the value remains undef when
-      // creating the subsequent shift op.
-      SumOfShAmts = replaceUndefsWith(
-          SumOfShAmts,
-          ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
-                           ExtendedTy->getScalarType()->getScalarSizeInBits()));
-      auto *ExtendedSumOfShAmts =
-          ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
-      // And compute the mask as usual: ~(-1 << (SumOfShAmts))
-      auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
-      auto *ExtendedInvertedMask =
-          ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
-      auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask);
-      NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
-    } else
-      NewMask = nullptr; // No mask needed.
-    // All good, we can do this fold.
+
+    // The mask must be computed in a type twice as wide to ensure
+    // that no bits are lost if the sum-of-shifts is wider than the base type.
+    Type *ExtendedTy = Ty->getExtendedType();
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the the `undef` shift amounts with final
+    // shift bitwidth to ensure that the value remains undef when creating the
+    // subsequent shift op.
+    SumOfShAmts = replaceUndefsWith(
+        SumOfShAmts,
+        ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
+                         ExtendedTy->getScalarType()->getScalarSizeInBits()));
+    auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+    // And compute the mask as usual: ~(-1 << (SumOfShAmts))
+    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+    auto *ExtendedInvertedMask =
+        ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
+    auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask);
+    NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
   } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
              match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
                                  m_Deferred(MaskShAmt)))) {
@@ -223,49 +213,51 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
     if (!ShAmtsDiff)
       return nullptr; // Did not simplify.
     // In this pattern ShAmtsDiff correlates with the number of high bits that
-    // shall be unset in the root value (OuterShift). If ShAmtsDiff is negative,
-    // we'll need to also produce a mask to unset ShAmtsDiff high bits.
-    // So, does *any* channel need a mask? (is ShiftShAmt u>= MaskShAmt ?)
-    if (!match(ShAmtsDiff, m_NonNegative())) {
-      // This sub-fold (with mask) is invalid for 'ashr' "masking" instruction.
-      if (match(Masked, m_AShr(m_Value(), m_Value())))
-        return nullptr;
-      // For a mask we need to get rid of old masking instruction.
-      if (!Masked->hasOneUse())
-        return nullptr; // Else we can't perform the fold.
-      Type *Ty = X->getType();
-      unsigned BitWidth = Ty->getScalarSizeInBits();
-      // The mask must be computed in a type twice as wide to ensure
-      // that no bits are lost if the sum-of-shifts is wider than the base type.
-      Type *ExtendedTy = Ty->getExtendedType();
-      // An extend of an undef value becomes zero because the high bits are
-      // never completely unknown. Replace the the `undef` shift amounts with
-      // negated shift bitwidth to ensure that the value remains undef when
-      // creating the subsequent shift op.
-      ShAmtsDiff = replaceUndefsWith(
-          ShAmtsDiff,
-          ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth));
-      auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
-          ConstantExpr::getAdd(
-              ConstantExpr::getNeg(ShAmtsDiff),
-              ConstantInt::get(Ty, BitWidth, /*isSigned=*/false)),
-          ExtendedTy);
-      // And compute the mask as usual: (-1 l>> (ShAmtsDiff))
-      auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
-      auto *ExtendedMask =
-          ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
-      NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
-    } else
-      NewMask = nullptr; // No mask needed.
-    // All good, we can do this fold.
+    // shall be unset in the root value (OuterShift).
+
+    Type *Ty = X->getType();
+    unsigned BitWidth = Ty->getScalarSizeInBits();
+
+    // The mask must be computed in a type twice as wide to ensure
+    // that no bits are lost if the sum-of-shifts is wider than the base type.
+    Type *ExtendedTy = Ty->getExtendedType();
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the the `undef` shift amounts with negated
+    // shift bitwidth to ensure that the value remains undef when creating the
+    // subsequent shift op.
+    ShAmtsDiff = replaceUndefsWith(
+        ShAmtsDiff,
+        ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth));
+    auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+        ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), BitWidth,
+                                              /*isSigned=*/false),
+                             ShAmtsDiff),
+        ExtendedTy);
+    // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
+    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+    auto *ExtendedMask =
+        ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+    NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
   } else
     return nullptr; // Don't know anything about this pattern.
 
-  // No 'NUW'/'NSW'!
-  // We no longer know that we won't shift-out non-0 bits.
+  // Does this mask has any unset bits? If not then we can just not apply it.
+  bool NeedMask = !match(NewMask, m_AllOnes());
+
+  // If we need to apply a mask, there are several more restrictions we have.
+  if (NeedMask) {
+    // The old masking instruction must go away.
+    if (!Masked->hasOneUse())
+      return nullptr;
+    // The original "masking" instruction must not have been`ashr`.
+    if (match(Masked, m_AShr(m_Value(), m_Value())))
+      return nullptr;
+  }
+
+  // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
   auto *NewShift =
       BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt);
-  if (!NewMask)
+  if (!NeedMask)
     return NewShift;
 
   Builder.Insert(NewShift);