]> granicus.if.org Git - llvm/commitdiff
[InstCombine] use m_APInt to allow ashr folds for vectors with splat constants
authorSanjay Patel <spatel@rotateright.com>
Sat, 21 Jan 2017 17:59:59 +0000 (17:59 +0000)
committerSanjay Patel <spatel@rotateright.com>
Sat, 21 Jan 2017 17:59:59 +0000 (17:59 +0000)
We may be able to assert that no shl-shl or lshr-lshr pairs ever get here
because we should have already handled those in foldShiftedShift().

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@292726 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineShifts.cpp
test/Transforms/InstCombine/shift-sra.ll

index fcd7b601f7fb6d001c6af05a5a65d2af9e3f2c18..21a975857b99495aa4192d9565f8189899c17ee3 100644 (file)
@@ -315,14 +315,32 @@ static Instruction *
 foldShiftByConstOfShiftByConst(BinaryOperator &I, const APInt *COp1,
                                InstCombiner::BuilderTy *Builder) {
   Value *Op0 = I.getOperand(0);
-  uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
+  unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
 
   // Find out if this is a shift of a shift by a constant.
   BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0);
-  if (!ShiftOp || !ShiftOp->isShift() ||
-      !isa<ConstantInt>(ShiftOp->getOperand(1)))
+  if (!ShiftOp || !ShiftOp->isShift())
+    return nullptr;
+
+  const APInt *ShAmt1;
+  if (!match(ShiftOp->getOperand(1), m_APInt(ShAmt1)))
     return nullptr;
 
+  // Check for (X << c1) << c2  and  (X >> c1) >> c2
+  if (I.getOpcode() == ShiftOp->getOpcode()) {
+    unsigned AmtSum = (*ShAmt1 + *COp1).getZExtValue();
+    // If this is an oversized composite shift, then unsigned shifts become
+    // zero (handled in InstSimplify) and ashr saturates.
+    if (AmtSum >= TypeBits) {
+      if (I.getOpcode() != Instruction::AShr)
+        return nullptr;
+      AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr.
+    }
+
+    return BinaryOperator::Create(I.getOpcode(), ShiftOp->getOperand(0),
+                                  ConstantInt::get(I.getType(), AmtSum));
+  }
+
   // This is a constant shift of a constant shift. Be careful about hiding
   // shl instructions behind bit masks. They are used to represent multiplies
   // by a constant, and it is important that simple arithmetic expressions
@@ -335,31 +353,20 @@ foldShiftByConstOfShiftByConst(BinaryOperator &I, const APInt *COp1,
   // Combinations of right and left shifts will still be optimized in
   // DAGCombine where scalar evolution no longer applies.
 
-  ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1));
+  // FIXME: Everything under here should be extended to work with vector types.
+
+  auto *ShiftAmt1C = dyn_cast<ConstantInt>(ShiftOp->getOperand(1));
+  if (!ShiftAmt1C)
+    return nullptr;
+
   uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
   uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
   assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
   if (ShiftAmt1 == 0)
     return nullptr; // Will be simplified in the future.
-  Value *X = ShiftOp->getOperand(0);
 
+  Value *X = ShiftOp->getOperand(0);
   IntegerType *Ty = cast<IntegerType>(I.getType());
-
-  // Check for (X << c1) << c2  and  (X >> c1) >> c2
-  if (I.getOpcode() == ShiftOp->getOpcode()) {
-    uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift.
-    // If this is an oversized composite shift, then unsigned shifts become
-    // zero (handled in InstSimplify) and ashr saturates.
-    if (AmtSum >= TypeBits) {
-      if (I.getOpcode() != Instruction::AShr)
-        return nullptr;
-      AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr.
-    }
-
-    return BinaryOperator::Create(I.getOpcode(), X,
-                                  ConstantInt::get(Ty, AmtSum));
-  }
-
   if (ShiftAmt1 == ShiftAmt2) {
     // If we have ((X << C) >>u C), turn this into X & (-1 >>u C).
     if (I.getOpcode() == Instruction::LShr &&
index 1d6f79f54b18cc7211759e0b9915fecdd946716c..4483e60b506a02641f67bb532fe81e0a0c854605 100644 (file)
@@ -139,13 +139,11 @@ define i32 @ashr_overshift(i32 %x) {
   ret i32 %sh2
 }
 
-; FIXME:
 ; (X >>s C1) >>s C2 --> X >>s (C1 + C2)
 
 define <2 x i32> @ashr_ashr_splat_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @ashr_ashr_splat_vec(
-; CHECK-NEXT:    [[SH1:%.*]] = ashr <2 x i32> %x, <i32 5, i32 5>
-; CHECK-NEXT:    [[SH2:%.*]] = ashr <2 x i32> [[SH1]], <i32 7, i32 7>
+; CHECK-NEXT:    [[SH2:%.*]] = ashr <2 x i32> %x, <i32 12, i32 12>
 ; CHECK-NEXT:    ret <2 x i32> [[SH2]]
 ;
   %sh1 = ashr <2 x i32> %x, <i32 5, i32 5>
@@ -153,13 +151,11 @@ define <2 x i32> @ashr_ashr_splat_vec(<2 x i32> %x) {
   ret <2 x i32> %sh2
 }
 
-; FIXME:
 ; (X >>s C1) >>s C2 --> X >>s (Bitwidth - 1)
 
 define <2 x i32> @ashr_overshift_splat_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @ashr_overshift_splat_vec(
-; CHECK-NEXT:    [[SH1:%.*]] = ashr <2 x i32> %x, <i32 15, i32 15>
-; CHECK-NEXT:    [[SH2:%.*]] = ashr <2 x i32> [[SH1]], <i32 17, i32 17>
+; CHECK-NEXT:    [[SH2:%.*]] = ashr <2 x i32> %x, <i32 31, i32 31>
 ; CHECK-NEXT:    ret <2 x i32> [[SH2]]
 ;
   %sh1 = ashr <2 x i32> %x, <i32 15, i32 15>