MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
if (!SumOfShAmts)
return nullptr; // Did not simplify.
+ // In this pattern SumOfShAmts correlates with the number of low bits
+ // that shall remain in the root value (OuterShift).
+
Type *Ty = X->getType();
- unsigned BitWidth = Ty->getScalarSizeInBits();
- // In this pattern SumOfShAmts correlates with the number of low bits that
- // shall remain in the root value (OuterShift). If SumOfShAmts is less than
- // bitwidth, we'll need to also produce a mask to keep SumOfShAmts low bits.
- // So, does *any* channel need a mask?
- if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
- APInt(BitWidth, BitWidth)))) {
- // But for a mask we need to get rid of old masking instruction.
- if (!Masked->hasOneUse())
- return nullptr; // Else we can't perform the fold.
- // The mask must be computed in a type twice as wide to ensure
- // that no bits are lost if the sum-of-shifts is wider than the base type.
- Type *ExtendedTy = Ty->getExtendedType();
- // An extend of an undef value becomes zero because the high bits are
- // never completely unknown. Replace the the `undef` shift amounts with
- // final shift bitwidth to ensure that the value remains undef when
- // creating the subsequent shift op.
- SumOfShAmts = replaceUndefsWith(
- SumOfShAmts,
- ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
- ExtendedTy->getScalarType()->getScalarSizeInBits()));
- auto *ExtendedSumOfShAmts =
- ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
- // And compute the mask as usual: ~(-1 << (SumOfShAmts))
- auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
- auto *ExtendedInvertedMask =
- ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
- auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask);
- NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
- } else
- NewMask = nullptr; // No mask needed.
- // All good, we can do this fold.
+
+ // The mask must be computed in a type twice as wide to ensure
+ // that no bits are lost if the sum-of-shifts is wider than the base type.
+ Type *ExtendedTy = Ty->getExtendedType();
+ // An extend of an undef value becomes zero because the high bits are never
+ // completely unknown. Replace the the `undef` shift amounts with final
+ // shift bitwidth to ensure that the value remains undef when creating the
+ // subsequent shift op.
+ SumOfShAmts = replaceUndefsWith(
+ SumOfShAmts,
+ ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
+ ExtendedTy->getScalarType()->getScalarSizeInBits()));
+ auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+ // And compute the mask as usual: ~(-1 << (SumOfShAmts))
+ auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+ auto *ExtendedInvertedMask =
+ ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
+ auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask);
+ NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
} else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
m_Deferred(MaskShAmt)))) {
if (!ShAmtsDiff)
return nullptr; // Did not simplify.
// In this pattern ShAmtsDiff correlates with the number of high bits that
- // shall be unset in the root value (OuterShift). If ShAmtsDiff is negative,
- // we'll need to also produce a mask to unset ShAmtsDiff high bits.
- // So, does *any* channel need a mask? (is ShiftShAmt u>= MaskShAmt ?)
- if (!match(ShAmtsDiff, m_NonNegative())) {
- // This sub-fold (with mask) is invalid for 'ashr' "masking" instruction.
- if (match(Masked, m_AShr(m_Value(), m_Value())))
- return nullptr;
- // For a mask we need to get rid of old masking instruction.
- if (!Masked->hasOneUse())
- return nullptr; // Else we can't perform the fold.
- Type *Ty = X->getType();
- unsigned BitWidth = Ty->getScalarSizeInBits();
- // The mask must be computed in a type twice as wide to ensure
- // that no bits are lost if the sum-of-shifts is wider than the base type.
- Type *ExtendedTy = Ty->getExtendedType();
- // An extend of an undef value becomes zero because the high bits are
- // never completely unknown. Replace the the `undef` shift amounts with
- // negated shift bitwidth to ensure that the value remains undef when
- // creating the subsequent shift op.
- ShAmtsDiff = replaceUndefsWith(
- ShAmtsDiff,
- ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth));
- auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
- ConstantExpr::getAdd(
- ConstantExpr::getNeg(ShAmtsDiff),
- ConstantInt::get(Ty, BitWidth, /*isSigned=*/false)),
- ExtendedTy);
- // And compute the mask as usual: (-1 l>> (ShAmtsDiff))
- auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
- auto *ExtendedMask =
- ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
- NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
- } else
- NewMask = nullptr; // No mask needed.
- // All good, we can do this fold.
+ // shall be unset in the root value (OuterShift).
+
+ Type *Ty = X->getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ // The mask must be computed in a type twice as wide to ensure
+ // that no bits are lost if the sum-of-shifts is wider than the base type.
+ Type *ExtendedTy = Ty->getExtendedType();
+ // An extend of an undef value becomes zero because the high bits are never
+ // completely unknown. Replace the the `undef` shift amounts with negated
+ // shift bitwidth to ensure that the value remains undef when creating the
+ // subsequent shift op.
+ ShAmtsDiff = replaceUndefsWith(
+ ShAmtsDiff,
+ ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth));
+ auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+ ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), BitWidth,
+ /*isSigned=*/false),
+ ShAmtsDiff),
+ ExtendedTy);
+ // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
+ auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+ auto *ExtendedMask =
+ ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+ NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty);
} else
return nullptr; // Don't know anything about this pattern.
- // No 'NUW'/'NSW'!
- // We no longer know that we won't shift-out non-0 bits.
+ // Does this mask has any unset bits? If not then we can just not apply it.
+ bool NeedMask = !match(NewMask, m_AllOnes());
+
+ // If we need to apply a mask, there are several more restrictions we have.
+ if (NeedMask) {
+ // The old masking instruction must go away.
+ if (!Masked->hasOneUse())
+ return nullptr;
+ // The original "masking" instruction must not have been`ashr`.
+ if (match(Masked, m_AShr(m_Value(), m_Value())))
+ return nullptr;
+ }
+
+ // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
auto *NewShift =
BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt);
- if (!NewMask)
+ if (!NeedMask)
return NewShift;
Builder.Insert(NewShift);