From bbbb2f573fc5df468971290c6ec74033f8426bb5 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Mon, 21 Aug 2017 19:02:06 +0000 Subject: [PATCH] [InstCombine] Teach foldSelectICmpAnd to recognize a (icmp slt X, 0) and (icmp sgt X, -1) as equivalent to an and with the sign bit of the truncated type This is similar to what was already done in foldSelectICmpAndOr. Ultimately I'd like to see if we can call foldSelectICmpAnd from foldSelectIntoOp if we detect a power of 2 constant. This would allow us to remove foldSelectICmpAndOr entirely. Differential Revision: https://reviews.llvm.org/D36498 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@311362 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineSelect.cpp | 68 ++++++++---- .../InstCombine/select-with-bitwise-ops.ll | 102 ++++++++++++++++++ 2 files changed, 151 insertions(+), 19 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index b5fa2dc15d3..5712fa4f09c 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -594,6 +594,15 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, /// icmp instruction with zero, and we have an 'and' with the non-constant value /// and a power of two we can turn the select into a shift on the result of the /// 'and'. +/// This folds: +/// select (icmp eq (and X, C1)), C2, C3 +/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. +/// To something like: +/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// Or: +/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 +/// With some variations depending if C3 is larger than C2, or the shift +/// isn't needed, or the bit widths don't match. static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, APInt TrueVal, APInt FalseVal, InstCombiner::BuilderTy &Builder) { @@ -603,16 +612,32 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, if (SelType->isVectorTy() != IC->getType()->isVectorTy()) return nullptr; - if (!IC->isEquality()) - return nullptr; + Value *V; + APInt AndMask; + bool CreateAnd = false; + ICmpInst::Predicate Pred = IC->getPredicate(); + if (ICmpInst::isEquality(Pred)) { + if (!match(IC->getOperand(1), m_Zero())) + return nullptr; - if (!match(IC->getOperand(1), m_Zero())) - return nullptr; + V = IC->getOperand(0); + + const APInt *AndRHS; + if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) + return nullptr; + + AndMask = *AndRHS; + } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + Pred, V, AndMask)) { + assert(ICmpInst::isEquality(Pred) && "Not equality test?"); + + if (!AndMask.isPowerOf2()) + return nullptr; - const APInt *AndRHS; - Value *LHS = IC->getOperand(0); - if (!match(LHS, m_And(m_Value(), m_Power2(AndRHS)))) + CreateAnd = true; + } else { return nullptr; + } // If both select arms are non-zero see if we have a select of the form // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic @@ -639,11 +664,15 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, // desired result. const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; unsigned ValZeros = ValC.logBase2(); - unsigned AndZeros = AndRHS->logBase2(); + unsigned AndZeros = AndMask.logBase2(); + + if (CreateAnd) { + // Insert the AND instruction on the input to the truncate. + V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); + } // If types don't match we can still convert the select by introducing a zext // or a trunc of the 'and'. - Value *V = LHS; if (ValZeros > AndZeros) { V = Builder.CreateZExtOrTrunc(V, SelType); V = Builder.CreateShl(V, ValZeros - AndZeros); @@ -656,7 +685,7 @@ static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. bool ShouldNotVal = !TrueVal.isNullValue(); - ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE; + ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); @@ -672,15 +701,6 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); - { - const APInt *TrueValC, *FalseValC; - if (match(TrueVal, m_APInt(TrueValC)) && - match(FalseVal, m_APInt(FalseValC))) - if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, - *FalseValC, Builder)) - return replaceInstUsesWith(SI, V); - } - if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; @@ -695,6 +715,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, // FIXME: Type and constness constraints could be lifted, but we have to // watch code size carefully. We should consider xor instead of // sub/add when we decide to do that. + // TODO: Merge this with foldSelectICmpAnd somehow. if (CmpLHS->getType()->isIntOrIntVectorTy() && CmpLHS->getType() == TrueVal->getType()) { const APInt *C1, *C2; @@ -725,6 +746,15 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } + { + const APInt *TrueValC, *FalseValC; + if (match(TrueVal, m_APInt(TrueValC)) && + match(FalseVal, m_APInt(FalseValC))) + if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, + *FalseValC, Builder)) + return replaceInstUsesWith(SI, V); + } + // NOTE: if we wanted to, this is where to detect integer MIN/MAX if (CmpRHS != CmpLHS && isa(CmpRHS)) { diff --git a/test/Transforms/InstCombine/select-with-bitwise-ops.ll b/test/Transforms/InstCombine/select-with-bitwise-ops.ll index 602c05478a5..6248dd03221 100644 --- a/test/Transforms/InstCombine/select-with-bitwise-ops.ll +++ b/test/Transforms/InstCombine/select-with-bitwise-ops.ll @@ -395,6 +395,108 @@ define i8 @test70(i8 %x, i8 %y) { ret i8 %select } +define i32 @test71(i32 %x) { +; CHECK-LABEL: @test71( +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 6 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = xor i32 [[TMP2]], 42 +; CHECK-NEXT: ret i32 [[TMP3]] +; + %1 = and i32 %x, 128 + %2 = icmp ne i32 %1, 0 + %3 = select i1 %2, i32 40, i32 42 + ret i32 %3 +} + +define <2 x i32> @test71vec(<2 x i32> %x) { +; CHECK-LABEL: @test71vec( +; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = and <2 x i32> [[TMP1]], +; CHECK-NEXT: [[TMP3:%.*]] = xor <2 x i32> [[TMP2]], +; CHECK-NEXT: ret <2 x i32> [[TMP3]] +; + %1 = and <2 x i32> %x, + %2 = icmp ne <2 x i32> %1, + %3 = select <2 x i1> %2, <2 x i32> , <2 x i32> + ret <2 x i32> %3 +} + +define i32 @test72(i32 %x) { +; CHECK-LABEL: @test72( +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 6 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = or i32 [[TMP2]], 40 +; CHECK-NEXT: ret i32 [[TMP3]] +; + %1 = and i32 %x, 128 + %2 = icmp eq i32 %1, 0 + %3 = select i1 %2, i32 40, i32 42 + ret i32 %3 +} + +define <2 x i32> @test72vec(<2 x i32> %x) { +; CHECK-LABEL: @test72vec( +; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = and <2 x i32> [[TMP1]], +; CHECK-NEXT: [[TMP3:%.*]] = or <2 x i32> [[TMP2]], +; CHECK-NEXT: ret <2 x i32> [[TMP3]] +; + %1 = and <2 x i32> %x, + %2 = icmp eq <2 x i32> %1, + %3 = select <2 x i1> %2, <2 x i32> , <2 x i32> + ret <2 x i32> %3 +} + +define i32 @test73(i32 %x) { +; CHECK-LABEL: @test73( +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 6 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = or i32 [[TMP2]], 40 +; CHECK-NEXT: ret i32 [[TMP3]] +; + %1 = trunc i32 %x to i8 + %2 = icmp sgt i8 %1, -1 + %3 = select i1 %2, i32 40, i32 42 + ret i32 %3 +} + +define <2 x i32> @test73vec(<2 x i32> %x) { +; CHECK-LABEL: @test73vec( +; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = and <2 x i32> [[TMP1]], +; CHECK-NEXT: [[TMP3:%.*]] = or <2 x i32> [[TMP2]], +; CHECK-NEXT: ret <2 x i32> [[TMP3]] +; + %1 = trunc <2 x i32> %x to <2 x i8> + %2 = icmp sgt <2 x i8> %1, + %3 = select <2 x i1> %2, <2 x i32> , <2 x i32> + ret <2 x i32> %3 +} + +define i32 @test74(i32 %x) { +; CHECK-LABEL: @test74( +; CHECK-NEXT: [[TMP1:%.*]] = ashr i32 [[X:%.*]], 31 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 2 +; CHECK-NEXT: [[TMP3:%.*]] = or i32 [[TMP2]], 40 +; CHECK-NEXT: ret i32 [[TMP3]] +; + %1 = icmp sgt i32 %x, -1 + %2 = select i1 %1, i32 40, i32 42 + ret i32 %2 +} + +define <2 x i32> @test74vec(<2 x i32> %x) { +; CHECK-LABEL: @test74vec( +; CHECK-NEXT: [[TMP1:%.*]] = ashr <2 x i32> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = and <2 x i32> [[TMP1]], +; CHECK-NEXT: [[TMP3:%.*]] = or <2 x i32> [[TMP2]], +; CHECK-NEXT: ret <2 x i32> [[TMP3]] +; + %1 = icmp sgt <2 x i32> %x, + %2 = select <2 x i1> %1, <2 x i32> , <2 x i32> + ret <2 x i32> %2 +} + define i32 @shift_no_xor_multiuse_or(i32 %x, i32 %y) { ; CHECK-LABEL: @shift_no_xor_multiuse_or( ; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], 2 -- 2.50.1