From 17520ed1e33177b9dacf0bbfac84378b223e0a14 Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Mon, 7 Oct 2019 20:53:00 +0000 Subject: [PATCH] [InstCombine][NFC] dropRedundantMaskingOfLeftShiftInput(): change how we deal with mask Summary: Currently, we pre-check whether we need to produce a mask or not. This involves some rather magical constants. I'd like to extend this fold to also handle the situation when there's also a `trunc` before outer shift. That will require another set of magical constants. It's ugly. Instead, we can just compute the mask, and check whether mask is a pass-through (all-ones) or not. This way we don't need to have any magical numbers. This change is NFC other than the fact that we now compute the mask and then check if we need (and can!) apply it. Reviewers: spatel Reviewed By: spatel Subscribers: hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D68470 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@373961 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineShifts.cpp | 132 ++++++++---------- 1 file changed, 62 insertions(+), 70 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 6675ab12aee..a325b29afa6 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -181,39 +181,29 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); if (!SumOfShAmts) return nullptr; // Did not simplify. + // In this pattern SumOfShAmts correlates with the number of low bits + // that shall remain in the root value (OuterShift). + Type *Ty = X->getType(); - unsigned BitWidth = Ty->getScalarSizeInBits(); - // In this pattern SumOfShAmts correlates with the number of low bits that - // shall remain in the root value (OuterShift). If SumOfShAmts is less than - // bitwidth, we'll need to also produce a mask to keep SumOfShAmts low bits. - // So, does *any* channel need a mask? - if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE, - APInt(BitWidth, BitWidth)))) { - // But for a mask we need to get rid of old masking instruction. - if (!Masked->hasOneUse()) - return nullptr; // Else we can't perform the fold. - // The mask must be computed in a type twice as wide to ensure - // that no bits are lost if the sum-of-shifts is wider than the base type. - Type *ExtendedTy = Ty->getExtendedType(); - // An extend of an undef value becomes zero because the high bits are - // never completely unknown. Replace the the `undef` shift amounts with - // final shift bitwidth to ensure that the value remains undef when - // creating the subsequent shift op. - SumOfShAmts = replaceUndefsWith( - SumOfShAmts, - ConstantInt::get(SumOfShAmts->getType()->getScalarType(), - ExtendedTy->getScalarType()->getScalarSizeInBits())); - auto *ExtendedSumOfShAmts = - ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); - // And compute the mask as usual: ~(-1 << (SumOfShAmts)) - auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); - auto *ExtendedInvertedMask = - ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts); - auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask); - NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty); - } else - NewMask = nullptr; // No mask needed. - // All good, we can do this fold. + + // The mask must be computed in a type twice as wide to ensure + // that no bits are lost if the sum-of-shifts is wider than the base type. + Type *ExtendedTy = Ty->getExtendedType(); + // An extend of an undef value becomes zero because the high bits are never + // completely unknown. Replace the the `undef` shift amounts with final + // shift bitwidth to ensure that the value remains undef when creating the + // subsequent shift op. + SumOfShAmts = replaceUndefsWith( + SumOfShAmts, + ConstantInt::get(SumOfShAmts->getType()->getScalarType(), + ExtendedTy->getScalarType()->getScalarSizeInBits())); + auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); + // And compute the mask as usual: ~(-1 << (SumOfShAmts)) + auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); + auto *ExtendedInvertedMask = + ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts); + auto *ExtendedMask = ConstantExpr::getNot(ExtendedInvertedMask); + NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty); } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) || match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)))) { @@ -223,49 +213,51 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, if (!ShAmtsDiff) return nullptr; // Did not simplify. // In this pattern ShAmtsDiff correlates with the number of high bits that - // shall be unset in the root value (OuterShift). If ShAmtsDiff is negative, - // we'll need to also produce a mask to unset ShAmtsDiff high bits. - // So, does *any* channel need a mask? (is ShiftShAmt u>= MaskShAmt ?) - if (!match(ShAmtsDiff, m_NonNegative())) { - // This sub-fold (with mask) is invalid for 'ashr' "masking" instruction. - if (match(Masked, m_AShr(m_Value(), m_Value()))) - return nullptr; - // For a mask we need to get rid of old masking instruction. - if (!Masked->hasOneUse()) - return nullptr; // Else we can't perform the fold. - Type *Ty = X->getType(); - unsigned BitWidth = Ty->getScalarSizeInBits(); - // The mask must be computed in a type twice as wide to ensure - // that no bits are lost if the sum-of-shifts is wider than the base type. - Type *ExtendedTy = Ty->getExtendedType(); - // An extend of an undef value becomes zero because the high bits are - // never completely unknown. Replace the the `undef` shift amounts with - // negated shift bitwidth to ensure that the value remains undef when - // creating the subsequent shift op. - ShAmtsDiff = replaceUndefsWith( - ShAmtsDiff, - ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth)); - auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( - ConstantExpr::getAdd( - ConstantExpr::getNeg(ShAmtsDiff), - ConstantInt::get(Ty, BitWidth, /*isSigned=*/false)), - ExtendedTy); - // And compute the mask as usual: (-1 l>> (ShAmtsDiff)) - auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); - auto *ExtendedMask = - ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear); - NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty); - } else - NewMask = nullptr; // No mask needed. - // All good, we can do this fold. + // shall be unset in the root value (OuterShift). + + Type *Ty = X->getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + + // The mask must be computed in a type twice as wide to ensure + // that no bits are lost if the sum-of-shifts is wider than the base type. + Type *ExtendedTy = Ty->getExtendedType(); + // An extend of an undef value becomes zero because the high bits are never + // completely unknown. Replace the the `undef` shift amounts with negated + // shift bitwidth to ensure that the value remains undef when creating the + // subsequent shift op. + ShAmtsDiff = replaceUndefsWith( + ShAmtsDiff, + ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -BitWidth)); + auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( + ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), BitWidth, + /*isSigned=*/false), + ShAmtsDiff), + ExtendedTy); + // And compute the mask as usual: (-1 l>> (NumHighBitsToClear)) + auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); + auto *ExtendedMask = + ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear); + NewMask = ConstantExpr::getTrunc(ExtendedMask, Ty); } else return nullptr; // Don't know anything about this pattern. - // No 'NUW'/'NSW'! - // We no longer know that we won't shift-out non-0 bits. + // Does this mask has any unset bits? If not then we can just not apply it. + bool NeedMask = !match(NewMask, m_AllOnes()); + + // If we need to apply a mask, there are several more restrictions we have. + if (NeedMask) { + // The old masking instruction must go away. + if (!Masked->hasOneUse()) + return nullptr; + // The original "masking" instruction must not have been`ashr`. + if (match(Masked, m_AShr(m_Value(), m_Value()))) + return nullptr; + } + + // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits. auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt); - if (!NewMask) + if (!NeedMask) return NewShift; Builder.Insert(NewShift); -- 2.40.0