]> granicus.if.org Git - llvm/commitdiff
[InstCombine] use m_APInt instead of faking it
authorSanjay Patel <spatel@rotateright.com>
Mon, 16 Jan 2017 21:24:41 +0000 (21:24 +0000)
committerSanjay Patel <spatel@rotateright.com>
Mon, 16 Jan 2017 21:24:41 +0000 (21:24 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@292164 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineShifts.cpp

index 1df0afc05e555a626b8cf25bc3ac7756737fd0f9..bf0ab82e89d8d1f64d7e2811d1220d64f1031f97 100644 (file)
@@ -312,7 +312,7 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
 /// Try to fold (X << C1) << C2, where the shifts are some combination of
 /// shl/ashr/lshr.
 static Instruction *
-foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1,
+foldShiftByConstOfShiftByConst(BinaryOperator &I, const APInt *COp1,
                                InstCombiner::BuilderTy *Builder) {
   Value *Op0 = I.getOperand(0);
   uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
@@ -475,33 +475,26 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
                                                BinaryOperator &I) {
   bool isLeftShift = I.getOpcode() == Instruction::Shl;
 
-  ConstantInt *COp1 = nullptr;
-  if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1))
-    COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
-  else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1))
-    COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
-  else
-    COp1 = dyn_cast<ConstantInt>(Op1);
-
-  if (!COp1)
+  const APInt *Op1C;
+  if (!match(Op1, m_APInt(Op1C)))
     return nullptr;
 
   // See if we can propagate this shift into the input, this covers the trivial
   // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
   if (I.getOpcode() != Instruction::AShr &&
-      canEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) {
+      canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
     DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
               " to eliminate shift:\n  IN: " << *Op0 << "\n  SH: " << I <<"\n");
 
     return replaceInstUsesWith(
-        I, getShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL));
+        I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
   }
 
   // See if we can simplify any instructions used by the instruction whose sole
   // purpose is to compute bits we don't care about.
-  uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
+  unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
 
-  assert(!COp1->uge(TypeBits) &&
+  assert(!Op1C->uge(TypeBits) &&
          "Shift over the type width should have been removed already");
 
   // ((X*C1) << C2) == (X * (C1 << C2))
@@ -525,7 +518,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
     if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
         isa<ConstantInt>(TrOp->getOperand(1))) {
       // Okay, we'll do this xform.  Make the shift of shift.
-      Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType());
+      Constant *ShAmt =
+          ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
       // (shift2 (shift1 & 0x00FF), c2)
       Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName());
 
@@ -542,10 +536,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
       // shift.  We know that it is a logical shift by a constant, so adjust the
       // mask as appropriate.
       if (I.getOpcode() == Instruction::Shl)
-        MaskV <<= COp1->getZExtValue();
+        MaskV <<= Op1C->getZExtValue();
       else {
         assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
-        MaskV = MaskV.lshr(COp1->getZExtValue());
+        MaskV = MaskV.lshr(Op1C->getZExtValue());
       }
 
       // shift1 & 0x00FF
@@ -579,7 +573,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
           // (X + (Y << C))
           Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,
                                           Op0BO->getOperand(1)->getName());
-          uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
 
           APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
           Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -615,7 +609,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
           // (X + (Y << C))
           Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,
                                           Op0BO->getOperand(0)->getName());
-          uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+          unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
 
           APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
           Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -686,7 +680,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
     }
   }
 
-  if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder))
+  if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, Op1C, Builder))
     return Folded;
 
   return nullptr;