From 0f258723e9769f9f1769ad8e36d7f1a9099ca073 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Mon, 16 Jan 2017 17:27:50 +0000 Subject: [PATCH] [InstCombine] refactor shift-of-shift folds; NFCI Reduces code duplication and makes it easier to extend these folds for vectors. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@292145 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineShifts.cpp | 149 ++++++++---------- 1 file changed, 66 insertions(+), 83 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 63216423f75..8ccbd5cf33d 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -190,6 +190,68 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, } } +/// Fold OuterShift (InnerShift X, C1), C2. +/// See canEvaluateShiftedShift() for the constraints on these instructions. +static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt, + bool IsOuterShl, + InstCombiner::BuilderTy &Builder) { + bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; + Type *ShType = InnerShift->getType(); + unsigned TypeWidth = ShType->getScalarSizeInBits(); + + // We only accept shifts-by-a-constant in canEvaluateShifted(). + ConstantInt *C1 = cast(InnerShift->getOperand(1)); + unsigned InnerShAmt = C1->getZExtValue(); + + // Change the shift amount and clear the appropriate IR flags. + auto NewInnerShift = [&](unsigned ShAmt) { + InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt)); + if (IsInnerShl) { + InnerShift->setHasNoUnsignedWrap(false); + InnerShift->setHasNoSignedWrap(false); + } else { + InnerShift->setIsExact(false); + } + return InnerShift; + }; + + // Two logical shifts in the same direction: + // shl (shl X, C1), C2 --> shl X, C1 + C2 + // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 + if (IsInnerShl == IsOuterShl) { + // If this is an oversized composite shift, then unsigned shifts get 0. + if (InnerShAmt + OuterShAmt >= TypeWidth) + return Constant::getNullValue(ShType); + + return NewInnerShift(InnerShAmt + OuterShAmt); + } + + // Equal shift amounts in opposite directions become bitwise 'and': + // lshr (shl X, C), C --> and X, C' + // shl (lshr X, C), C --> and X, C' + if (InnerShAmt == OuterShAmt) { + APInt Mask = IsInnerShl + ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt) + : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt); + Value *And = Builder.CreateAnd(InnerShift->getOperand(0), + ConstantInt::get(ShType, Mask)); + if (auto *AndI = dyn_cast(And)) { + AndI->moveBefore(InnerShift); + AndI->takeName(InnerShift); + } + return And; + } + + assert(InnerShAmt > OuterShAmt && + "Unexpected opposite direction logical shift pair"); + + // In general, we would need an 'and' for this transform, but + // canEvaluateShiftedShift() guarantees that the masked-off bits are not used. + // lshr (shl X, C1), C2 --> shl X, C1 - C2 + // shl (lshr X, C1), C2 --> lshr X, C1 - C2 + return NewInnerShift(InnerShAmt - OuterShAmt); +} + /// When canEvaluateShifted() returns true for an expression, this function /// inserts the new computation that produces the shifted value. static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, @@ -223,89 +285,10 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); return I; - case Instruction::Shl: { - BinaryOperator *BO = cast(I); - unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); - - // We only accept shifts-by-a-constant in CanEvaluateShifted. - ConstantInt *CI = cast(BO->getOperand(1)); - - // We can always fold shl(c1)+shl(c2) -> shl(c1+c2). - if (isLeftShift) { - // If this is oversized composite shift, then unsigned shifts get 0. - unsigned NewShAmt = NumBits+CI->getZExtValue(); - if (NewShAmt >= TypeWidth) - return Constant::getNullValue(I->getType()); - - BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt)); - BO->setHasNoUnsignedWrap(false); - BO->setHasNoSignedWrap(false); - return I; - } - - // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have - // zeros. - if (CI->getValue() == NumBits) { - APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits)); - V = IC.Builder->CreateAnd(BO->getOperand(0), - ConstantInt::get(BO->getContext(), Mask)); - if (Instruction *VI = dyn_cast(V)) { - VI->moveBefore(BO); - VI->takeName(BO); - } - return V; - } - - // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that - // the and won't be needed. - assert(CI->getZExtValue() > NumBits); - BO->setOperand(1, ConstantInt::get(BO->getType(), - CI->getZExtValue() - NumBits)); - BO->setHasNoUnsignedWrap(false); - BO->setHasNoSignedWrap(false); - return BO; - } - // FIXME: This is almost identical to the SHL case. Refactor both cases into - // a helper function. - case Instruction::LShr: { - BinaryOperator *BO = cast(I); - unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); - // We only accept shifts-by-a-constant in CanEvaluateShifted. - ConstantInt *CI = cast(BO->getOperand(1)); - - // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2). - if (!isLeftShift) { - // If this is oversized composite shift, then unsigned shifts get 0. - unsigned NewShAmt = NumBits+CI->getZExtValue(); - if (NewShAmt >= TypeWidth) - return Constant::getNullValue(BO->getType()); - - BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt)); - BO->setIsExact(false); - return I; - } - - // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have - // zeros. - if (CI->getValue() == NumBits) { - APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits)); - V = IC.Builder->CreateAnd(I->getOperand(0), - ConstantInt::get(BO->getContext(), Mask)); - if (Instruction *VI = dyn_cast(V)) { - VI->moveBefore(I); - VI->takeName(I); - } - return V; - } - - // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that - // the and won't be needed. - assert(CI->getZExtValue() > NumBits); - BO->setOperand(1, ConstantInt::get(BO->getType(), - CI->getZExtValue() - NumBits)); - BO->setIsExact(false); - return BO; - } + case Instruction::Shl: + case Instruction::LShr: + return foldShiftedShift(cast(I), NumBits, isLeftShift, + *(IC.Builder)); case Instruction::Select: I->setOperand( -- 2.40.0