]> granicus.if.org Git - llvm/commitdiff
[InstCombine] Teach foldSelectICmpAnd to recognize a (icmp slt X, 0) and (icmp sgt...
authorCraig Topper <craig.topper@intel.com>
Mon, 21 Aug 2017 19:02:06 +0000 (19:02 +0000)
committerCraig Topper <craig.topper@intel.com>
Mon, 21 Aug 2017 19:02:06 +0000 (19:02 +0000)
This is similar to what was already done in foldSelectICmpAndOr. Ultimately I'd like to see if we can call foldSelectICmpAnd from foldSelectIntoOp if we detect a power of 2 constant. This would allow us to remove foldSelectICmpAndOr entirely.

Differential Revision: https://reviews.llvm.org/D36498

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@311362 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineSelect.cpp
test/Transforms/InstCombine/select-with-bitwise-ops.ll

index b5fa2dc15d3c8aebba7a5cf2e89406c299569ea6..5712fa4f09c4d47c0dcf4d1da36c856b90ac28ce 100644 (file)
@@ -594,6 +594,15 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
 /// icmp instruction with zero, and we have an 'and' with the non-constant value
 /// and a power of two we can turn the select into a shift on the result of the
 /// 'and'.
+/// This folds:
+///  select (icmp eq (and X, C1)), C2, C3
+///    iff C1 is a power 2 and the difference between C2 and C3 is a power of 2.
+/// To something like:
+///  (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3
+/// Or:
+///  (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3
+/// With some variations depending if C3 is larger than C2, or the shift
+/// isn't needed, or the bit widths don't match.
 static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC,
                                 APInt TrueVal, APInt FalseVal,
                                 InstCombiner::BuilderTy &Builder) {
@@ -603,16 +612,32 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC,
   if (SelType->isVectorTy() != IC->getType()->isVectorTy())
     return nullptr;
 
-  if (!IC->isEquality())
-    return nullptr;
+  Value *V;
+  APInt AndMask;
+  bool CreateAnd = false;
+  ICmpInst::Predicate Pred = IC->getPredicate();
+  if (ICmpInst::isEquality(Pred)) {
+    if (!match(IC->getOperand(1), m_Zero()))
+      return nullptr;
 
-  if (!match(IC->getOperand(1), m_Zero()))
-    return nullptr;
+    V = IC->getOperand(0);
+
+    const APInt *AndRHS;
+    if (!match(V, m_And(m_Value(), m_Power2(AndRHS))))
+      return nullptr;
+
+    AndMask = *AndRHS;
+  } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1),
+                                  Pred, V, AndMask)) {
+    assert(ICmpInst::isEquality(Pred) && "Not equality test?");
+
+    if (!AndMask.isPowerOf2())
+      return nullptr;
 
-  const APInt *AndRHS;
-  Value *LHS = IC->getOperand(0);
-  if (!match(LHS, m_And(m_Value(), m_Power2(AndRHS))))
+    CreateAnd = true;
+  } else {
     return nullptr;
+  }
 
   // If both select arms are non-zero see if we have a select of the form
   // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic
@@ -639,11 +664,15 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC,
   // desired result.
   const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal;
   unsigned ValZeros = ValC.logBase2();
-  unsigned AndZeros = AndRHS->logBase2();
+  unsigned AndZeros = AndMask.logBase2();
+
+  if (CreateAnd) {
+    // Insert the AND instruction on the input to the truncate.
+    V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask));
+  }
 
   // If types don't match we can still convert the select by introducing a zext
   // or a trunc of the 'and'.
-  Value *V = LHS;
   if (ValZeros > AndZeros) {
     V = Builder.CreateZExtOrTrunc(V, SelType);
     V = Builder.CreateShl(V, ValZeros - AndZeros);
@@ -656,7 +685,7 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC,
   // Okay, now we know that everything is set up, we just don't know whether we
   // have a icmp_ne or icmp_eq and whether the true or false val is the zero.
   bool ShouldNotVal = !TrueVal.isNullValue();
-  ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE;
+  ShouldNotVal ^= Pred == ICmpInst::ICMP_NE;
   if (ShouldNotVal)
     V = Builder.CreateXor(V, ValC);
 
@@ -672,15 +701,6 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
   Value *TrueVal = SI.getTrueValue();
   Value *FalseVal = SI.getFalseValue();
 
-  {
-    const APInt *TrueValC, *FalseValC;
-    if (match(TrueVal, m_APInt(TrueValC)) &&
-        match(FalseVal, m_APInt(FalseValC)))
-      if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC,
-                                       *FalseValC, Builder))
-        return replaceInstUsesWith(SI, V);
-  }
-
   if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder))
     return NewSel;
 
@@ -695,6 +715,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
   // FIXME: Type and constness constraints could be lifted, but we have to
   //        watch code size carefully. We should consider xor instead of
   //        sub/add when we decide to do that.
+  // TODO: Merge this with foldSelectICmpAnd somehow.
   if (CmpLHS->getType()->isIntOrIntVectorTy() &&
       CmpLHS->getType() == TrueVal->getType()) {
     const APInt *C1, *C2;
@@ -725,6 +746,15 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI,
     }
   }
 
+  {
+    const APInt *TrueValC, *FalseValC;
+    if (match(TrueVal, m_APInt(TrueValC)) &&
+        match(FalseVal, m_APInt(FalseValC)))
+      if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC,
+                                       *FalseValC, Builder))
+        return replaceInstUsesWith(SI, V);
+  }
+
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
 
   if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
index 602c05478a5b3a512188b8582ef39d284136a3df..6248dd0322152e8a0ff507ed15c69e24124e6fc4 100644 (file)
@@ -395,6 +395,108 @@ define i8 @test70(i8 %x, i8 %y) {
   ret i8 %select
 }
 
+define i32 @test71(i32 %x) {
+; CHECK-LABEL: @test71(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 6
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = xor i32 [[TMP2]], 42
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = and i32 %x, 128
+  %2 = icmp ne i32 %1, 0
+  %3 = select i1 %2, i32 40, i32 42
+  ret i32 %3
+}
+
+define <2 x i32> @test71vec(<2 x i32> %x) {
+; CHECK-LABEL: @test71vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 6, i32 6>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i32> [[TMP1]], <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP3:%.*]] = xor <2 x i32> [[TMP2]], <i32 42, i32 42>
+; CHECK-NEXT:    ret <2 x i32> [[TMP3]]
+;
+  %1 = and <2 x i32> %x, <i32 128, i32 128>
+  %2 = icmp ne <2 x i32> %1, <i32 0, i32 0>
+  %3 = select <2 x i1> %2, <2 x i32> <i32 40, i32 40>, <2 x i32> <i32 42, i32 42>
+  ret <2 x i32> %3
+}
+
+define i32 @test72(i32 %x) {
+; CHECK-LABEL: @test72(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 6
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = or i32 [[TMP2]], 40
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = and i32 %x, 128
+  %2 = icmp eq i32 %1, 0
+  %3 = select i1 %2, i32 40, i32 42
+  ret i32 %3
+}
+
+define <2 x i32> @test72vec(<2 x i32> %x) {
+; CHECK-LABEL: @test72vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 6, i32 6>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i32> [[TMP1]], <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP3:%.*]] = or <2 x i32> [[TMP2]], <i32 40, i32 40>
+; CHECK-NEXT:    ret <2 x i32> [[TMP3]]
+;
+  %1 = and <2 x i32> %x, <i32 128, i32 128>
+  %2 = icmp eq <2 x i32> %1, <i32 0, i32 0>
+  %3 = select <2 x i1> %2, <2 x i32> <i32 40, i32 40>, <2 x i32> <i32 42, i32 42>
+  ret <2 x i32> %3
+}
+
+define i32 @test73(i32 %x) {
+; CHECK-LABEL: @test73(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 6
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = or i32 [[TMP2]], 40
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = trunc i32 %x to i8
+  %2 = icmp sgt i8 %1, -1
+  %3 = select i1 %2, i32 40, i32 42
+  ret i32 %3
+}
+
+define <2 x i32> @test73vec(<2 x i32> %x) {
+; CHECK-LABEL: @test73vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 6, i32 6>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i32> [[TMP1]], <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP3:%.*]] = or <2 x i32> [[TMP2]], <i32 40, i32 40>
+; CHECK-NEXT:    ret <2 x i32> [[TMP3]]
+;
+  %1 = trunc <2 x i32> %x to <2 x i8>
+  %2 = icmp sgt <2 x i8> %1, <i8 -1, i8 -1>
+  %3 = select <2 x i1> %2, <2 x i32> <i32 40, i32 40>, <2 x i32> <i32 42, i32 42>
+  ret <2 x i32> %3
+}
+
+define i32 @test74(i32 %x) {
+; CHECK-LABEL: @test74(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[X:%.*]], 31
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = or i32 [[TMP2]], 40
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = icmp sgt i32 %x, -1
+  %2 = select i1 %1, i32 40, i32 42
+  ret i32 %2
+}
+
+define <2 x i32> @test74vec(<2 x i32> %x) {
+; CHECK-LABEL: @test74vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr <2 x i32> [[X:%.*]], <i32 31, i32 31>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i32> [[TMP1]], <i32 2, i32 2>
+; CHECK-NEXT:    [[TMP3:%.*]] = or <2 x i32> [[TMP2]], <i32 40, i32 40>
+; CHECK-NEXT:    ret <2 x i32> [[TMP3]]
+;
+  %1 = icmp sgt <2 x i32> %x, <i32 -1, i32 -1>
+  %2 = select <2 x i1> %1, <2 x i32> <i32 40, i32 40>, <2 x i32> <i32 42, i32 42>
+  ret <2 x i32> %2
+}
+
 define i32 @shift_no_xor_multiuse_or(i32 %x, i32 %y) {
 ; CHECK-LABEL: @shift_no_xor_multiuse_or(
 ; CHECK-NEXT:    [[OR:%.*]] = or i32 [[Y:%.*]], 2