From 091e96c54f2050bb89b1cf51eea088ddc4c41382 Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Fri, 16 Jun 2017 05:10:37 +0000
Subject: [PATCH] [InstCombine] Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) if K1 and K2 are a 1-bit mask

Summary: This is the demorganed version of the case we already handle for the OR of iszero.

Reviewers: spatel

Reviewed By: spatel

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D34244

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@305548 91177308-0d34-0410-b5e6-96231b3b80d8
---
 .../InstCombine/InstCombineAndOrXor.cpp      | 89 ++++++++++++-------
 .../InstCombine/InstCombineInternal.h        |  4 +-
 test/Transforms/InstCombine/onehot_merge.ll  | 20 ++---
 3 files changed, 69 insertions(+), 44 deletions(-)

diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8c9acc1147e..a881bda5ba9 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -763,8 +763,54 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS,
   return nullptr;
 }
 
+// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
+// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
+Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
+                                                   bool JoinedByAnd,
+                                                   Instruction &CxtI) {
+  ICmpInst::Predicate Pred = LHS->getPredicate();
+  if (Pred != RHS->getPredicate())
+    return nullptr;
+  if (JoinedByAnd && Pred != ICmpInst::ICMP_NE)
+    return nullptr;
+  if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // TODO support vector splats
+  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
+  if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero())
+    return nullptr;
+
+  Value *A, *B, *C, *D;
+  if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) &&
+      match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) {
+    if (A == D || B == D)
+      std::swap(C, D);
+    if (B == C)
+      std::swap(A, B);
+
+    if (A == C &&
+        isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
+        isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
+      Value *Mask = Builder->CreateOr(B, D);
+      Value *Masked = Builder->CreateAnd(A, Mask);
+      auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+      return Builder->CreateICmp(NewPred, Masked, Mask);
+    }
+  }
+
+  return nullptr;
+}
+
 /// Fold (icmp)&(icmp) if possible.
-Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+                                    Instruction &CxtI) {
+  // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
+  // if K1 and K2 are a one-bit mask.
+  if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI))
+    return V;
+
   ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
 
   // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
@@ -1127,7 +1173,7 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) {
     ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src);
     ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src);
     if (ICmp0 && ICmp1) {
-      Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1)
+      Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1, I)
                                                 : foldOrOfICmps(ICmp0, ICmp1, I);
       if (Res)
         return CastInst::Create(CastOpcode, Res, DestTy);
@@ -1426,7 +1472,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
     ICmpInst *LHS = dyn_cast<ICmpInst>(Op0);
     ICmpInst *RHS = dyn_cast<ICmpInst>(Op1);
     if (LHS && RHS)
-      if (Value *Res = foldAndOfICmps(LHS, RHS))
+      if (Value *Res = foldAndOfICmps(LHS, RHS, I))
        return replaceInstUsesWith(I, Res);
 
     // TODO: Make this recursive; it's a little tricky because an arbitrary
@@ -1434,18 +1480,18 @@
     Value *X, *Y;
     if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
       if (auto *Cmp = dyn_cast<ICmpInst>(X))
-        if (Value *Res = foldAndOfICmps(LHS, Cmp))
+        if (Value *Res = foldAndOfICmps(LHS, Cmp, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
       if (auto *Cmp = dyn_cast<ICmpInst>(Y))
-        if (Value *Res = foldAndOfICmps(LHS, Cmp))
+        if (Value *Res = foldAndOfICmps(LHS, Cmp, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, X));
     }
     if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
       if (auto *Cmp = dyn_cast<ICmpInst>(X))
-        if (Value *Res = foldAndOfICmps(Cmp, RHS))
+        if (Value *Res = foldAndOfICmps(Cmp, RHS, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
       if (auto *Cmp = dyn_cast<ICmpInst>(Y))
-        if (Value *Res = foldAndOfICmps(Cmp, RHS))
+        if (Value *Res = foldAndOfICmps(Cmp, RHS, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, X));
     }
   }
@@ -1592,34 +1638,15 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D,
 /// Fold (icmp)|(icmp) if possible.
 Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
                                    Instruction &CxtI) {
-  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
-
   // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
   // if K1 and K2 are a one-bit mask.
-  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
-  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
-
-  // TODO support vector splats
-  if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() &&
-      RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
+  if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI))
+    return V;
 
-    Value *A, *B, *C, *D;
-    if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) &&
-        match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) {
-      if (A == D || B == D)
-        std::swap(C, D);
-      if (B == C)
-        std::swap(A, B);
+  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
 
-      if (A == C &&
-          isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
-          isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
-        Value *Mask = Builder->CreateOr(B, D);
-        Value *Masked = Builder->CreateAnd(A, Mask);
-        return Builder->CreateICmp(ICmpInst::ICMP_NE, Masked, Mask);
-      }
-    }
-  }
+  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
 
   // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
   // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index 9a52a83e37f..8c5aeb2a0cb 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -447,12 +447,14 @@ private:
   Instruction::CastOps isEliminableCastPair(const CastInst *CI1,
                                             const CastInst *CI2);
 
-  Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+  Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
   Value *foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
   Value *foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+  Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
+                                       bool JoinedByAnd, Instruction &CxtI);
 
 public:
   /// \brief Inserts an instruction \p New before instruction \p Old
   ///
diff --git a/test/Transforms/InstCombine/onehot_merge.ll b/test/Transforms/InstCombine/onehot_merge.ll
index 086de3ef70c..47a4ca4b628 100644
--- a/test/Transforms/InstCombine/onehot_merge.ll
+++ b/test/Transforms/InstCombine/onehot_merge.ll
@@ -73,12 +73,10 @@ define i1 @foo1_or(i32 %k, i32 %c1, i32 %c2) {
 ; CHECK-LABEL: @foo1_or(
 ; CHECK-NEXT:    [[TMP:%.*]] = shl i32 1, [[C1:%.*]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP]], [[K:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
-; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], [[K]]
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i32 [[TMP5]], 0
-; CHECK-NEXT:    [[OR:%.*]] = and i1 [[TMP2]], [[TMP6]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[TMP]], [[TMP4]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[K:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[TMP3]]
 ;
   %tmp = shl i32 1, %c1
   %tmp4 = lshr i32 -2147483648, %c2
@@ -96,12 +94,10 @@ define i1 @foo1_or_commuted(i32 %k, i32 %c1, i32 %c2) {
 ; CHECK-NEXT:    [[K2:%.*]] = mul i32 [[K:%.*]], [[K]]
 ; CHECK-NEXT:    [[TMP:%.*]] = shl i32 1, [[C1:%.*]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[K2]], [[TMP]]
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
-; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], [[K2]]
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i32 [[TMP5]], 0
-; CHECK-NEXT:    [[OR:%.*]] = and i1 [[TMP2]], [[TMP6]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[TMP]], [[TMP4]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[K2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[TMP3]]
 ;
   %k2 = mul i32 %k, %k ; to trick the complexity sorting
   %tmp = shl i32 1, %c1
-- 
2.50.1
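
A minimal IR sketch of the new and-joined fold, mirroring the @foo1_or test above (the function and value names below are illustrative only and not part of the patch; the masks are built with shl/lshr of a one-bit constant so isKnownToBeAPowerOfTwo can prove them single-bit):

; Two single-bit tests of %x joined by 'and'.
define i1 @onehot_and_sketch(i32 %x, i32 %bit1, i32 %bit2) {
  %m1 = shl i32 1, %bit1              ; one-bit mask, bit %bit1
  %m2 = lshr i32 -2147483648, %bit2   ; one-bit mask, bit (31 - %bit2)
  %a1 = and i32 %m1, %x
  %t1 = icmp ne i32 %a1, 0            ; !iszero(x & m1)
  %a2 = and i32 %m2, %x
  %t2 = icmp ne i32 %a2, 0            ; !iszero(x & m2)
  %r = and i1 %t1, %t2
  ret i1 %r
}

; Expected shape after instcombine, per the updated CHECK lines above:
;   %mask   = or i32 %m1, %m2
;   %masked = and i32 %mask, %x
;   %r      = icmp eq i32 %masked, %mask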