]> granicus.if.org Git - llvm/commitdiff
[InstCombine] refactor shift-of-shift folds; NFCI
authorSanjay Patel <spatel@rotateright.com>
Mon, 16 Jan 2017 17:27:50 +0000 (17:27 +0000)
committerSanjay Patel <spatel@rotateright.com>
Mon, 16 Jan 2017 17:27:50 +0000 (17:27 +0000)
Reduces code duplication and makes it easier to extend these folds for vectors.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@292145 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineShifts.cpp

index 63216423f7594b6003e0074d3335af5fc98f9cae..8ccbd5cf33dd1d5f8dffe0a3cacf3c65376ed67d 100644 (file)
@@ -190,6 +190,68 @@ static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
   }
 }
 
+/// Fold OuterShift (InnerShift X, C1), C2.
+/// See canEvaluateShiftedShift() for the constraints on these instructions.
+static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
+                               bool IsOuterShl,
+                               InstCombiner::BuilderTy &Builder) {
+  bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+  Type *ShType = InnerShift->getType();
+  unsigned TypeWidth = ShType->getScalarSizeInBits();
+
+  // We only accept shifts-by-a-constant in canEvaluateShifted().
+  ConstantInt *C1 = cast<ConstantInt>(InnerShift->getOperand(1));
+  unsigned InnerShAmt = C1->getZExtValue();
+
+  // Change the shift amount and clear the appropriate IR flags.
+  auto NewInnerShift = [&](unsigned ShAmt) {
+    InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
+    if (IsInnerShl) {
+      InnerShift->setHasNoUnsignedWrap(false);
+      InnerShift->setHasNoSignedWrap(false);
+    } else {
+      InnerShift->setIsExact(false);
+    }
+    return InnerShift;
+  };
+
+  // Two logical shifts in the same direction:
+  // shl (shl X, C1), C2 -->  shl X, C1 + C2
+  // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+  if (IsInnerShl == IsOuterShl) {
+    // If this is an oversized composite shift, then unsigned shifts get 0.
+    if (InnerShAmt + OuterShAmt >= TypeWidth)
+      return Constant::getNullValue(ShType);
+
+    return NewInnerShift(InnerShAmt + OuterShAmt);
+  }
+
+  // Equal shift amounts in opposite directions become bitwise 'and':
+  // lshr (shl X, C), C --> and X, C'
+  // shl (lshr X, C), C --> and X, C'
+  if (InnerShAmt == OuterShAmt) {
+    APInt Mask = IsInnerShl
+                     ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
+                     : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
+    Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
+                                   ConstantInt::get(ShType, Mask));
+    if (auto *AndI = dyn_cast<Instruction>(And)) {
+      AndI->moveBefore(InnerShift);
+      AndI->takeName(InnerShift);
+    }
+    return And;
+  }
+
+  assert(InnerShAmt > OuterShAmt &&
+         "Unexpected opposite direction logical shift pair");
+
+  // In general, we would need an 'and' for this transform, but
+  // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
+  // lshr (shl X, C1), C2 -->  shl X, C1 - C2
+  // shl (lshr X, C1), C2 --> lshr X, C1 - C2
+  return NewInnerShift(InnerShAmt - OuterShAmt);
+}
+
 /// When canEvaluateShifted() returns true for an expression, this function
 /// inserts the new computation that produces the shifted value.
 static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
@@ -223,89 +285,10 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
         1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
     return I;
 
-  case Instruction::Shl: {
-    BinaryOperator *BO = cast<BinaryOperator>(I);
-    unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
-
-    // We only accept shifts-by-a-constant in CanEvaluateShifted.
-    ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
-    // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
-    if (isLeftShift) {
-      // If this is oversized composite shift, then unsigned shifts get 0.
-      unsigned NewShAmt = NumBits+CI->getZExtValue();
-      if (NewShAmt >= TypeWidth)
-        return Constant::getNullValue(I->getType());
-
-      BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
-      BO->setHasNoUnsignedWrap(false);
-      BO->setHasNoSignedWrap(false);
-      return I;
-    }
-
-    // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have
-    // zeros.
-    if (CI->getValue() == NumBits) {
-      APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits));
-      V = IC.Builder->CreateAnd(BO->getOperand(0),
-                                ConstantInt::get(BO->getContext(), Mask));
-      if (Instruction *VI = dyn_cast<Instruction>(V)) {
-        VI->moveBefore(BO);
-        VI->takeName(BO);
-      }
-      return V;
-    }
-
-    // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that
-    // the and won't be needed.
-    assert(CI->getZExtValue() > NumBits);
-    BO->setOperand(1, ConstantInt::get(BO->getType(),
-                                       CI->getZExtValue() - NumBits));
-    BO->setHasNoUnsignedWrap(false);
-    BO->setHasNoSignedWrap(false);
-    return BO;
-  }
-  // FIXME: This is almost identical to the SHL case. Refactor both cases into
-  // a helper function.
-  case Instruction::LShr: {
-    BinaryOperator *BO = cast<BinaryOperator>(I);
-    unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
-    // We only accept shifts-by-a-constant in CanEvaluateShifted.
-    ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
-    // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
-    if (!isLeftShift) {
-      // If this is oversized composite shift, then unsigned shifts get 0.
-      unsigned NewShAmt = NumBits+CI->getZExtValue();
-      if (NewShAmt >= TypeWidth)
-        return Constant::getNullValue(BO->getType());
-
-      BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
-      BO->setIsExact(false);
-      return I;
-    }
-
-    // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have
-    // zeros.
-    if (CI->getValue() == NumBits) {
-      APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits));
-      V = IC.Builder->CreateAnd(I->getOperand(0),
-                                ConstantInt::get(BO->getContext(), Mask));
-      if (Instruction *VI = dyn_cast<Instruction>(V)) {
-        VI->moveBefore(I);
-        VI->takeName(I);
-      }
-      return V;
-    }
-
-    // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that
-    // the and won't be needed.
-    assert(CI->getZExtValue() > NumBits);
-    BO->setOperand(1, ConstantInt::get(BO->getType(),
-                                       CI->getZExtValue() - NumBits));
-    BO->setIsExact(false);
-    return BO;
-  }
+  case Instruction::Shl:
+  case Instruction::LShr:
+    return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
+                            *(IC.Builder));
 
   case Instruction::Select:
     I->setOperand(