From 17bb71c533efd6783e0b47581796e99e85178d6f Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Sat, 24 Aug 2019 06:49:36 +0000 Subject: [PATCH] [InstCombine] matchThreeWayIntCompare(): commutativity awareness Summary: `matchThreeWayIntCompare()` looks for ``` select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32 Greater) ``` but both of these selects/compares can be in it's commuted form, so out of 8 variants, only the two most basic ones is handled. This fixes regression being introduced in D66232. Reviewers: spatel, nikic, efriedma, xbolva00 Reviewed By: spatel Subscribers: hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D66607 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@369841 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../InstCombine/InstCombineCompares.cpp | 55 ++++++++++---- .../unrecognized_three-way-comparison.ll | 76 ++++++------------- 2 files changed, 66 insertions(+), 65 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 0b3edddaf93..b80e9ade75d 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2562,20 +2562,49 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, // TODO: Generalize this to work with other comparison idioms or ensure // they get canonicalized into this form. - // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32 - // Greater), where Equal, Less and Greater are placeholders for any three - // constants. - ICmpInst::Predicate PredA, PredB; - if (match(SI->getTrueValue(), m_ConstantInt(Equal)) && - match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) && - PredA == ICmpInst::ICMP_EQ && - match(SI->getFalseValue(), - m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)), - m_ConstantInt(Less), m_ConstantInt(Greater))) && - PredB == ICmpInst::ICMP_SLT) { - return true; + // select i1 (a == b), + // i32 Equal, + // i32 (select i1 (a < b), i32 Less, i32 Greater) + // where Equal, Less and Greater are placeholders for any three constants. + ICmpInst::Predicate PredA; + if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) || + !ICmpInst::isEquality(PredA)) + return false; + Value *EqualVal = SI->getTrueValue(); + Value *UnequalVal = SI->getFalseValue(); + // We still can get non-canonical predicate here, so canonicalize. + if (PredA == ICmpInst::ICMP_NE) + std::swap(EqualVal, UnequalVal); + if (!match(EqualVal, m_ConstantInt(Equal))) + return false; + ICmpInst::Predicate PredB; + Value *LHS2, *RHS2; + if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)), + m_ConstantInt(Less), m_ConstantInt(Greater)))) + return false; + // We can get predicate mismatch here, so canonicalize if possible: + // First, ensure that 'LHS' match. + if (LHS2 != LHS) { + // x sgt y <--> y slt x + std::swap(LHS2, RHS2); + PredB = ICmpInst::getSwappedPredicate(PredB); + } + if (LHS2 != LHS) + return false; + // We also need to canonicalize 'RHS'. + if (PredB == ICmpInst::ICMP_SGT && isa(RHS2)) { + // x sgt C-1 <--> x sge C <--> not(x slt C) + auto FlippedStrictness = + getFlippedStrictnessPredicateAndConstant(PredB, cast(RHS2)); + if (!FlippedStrictness) + return false; + assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && "Sanity check"); + RHS2 = FlippedStrictness->second; + // And kind-of perform the result swap. + std::swap(Less, Greater); + PredB = ICmpInst::ICMP_SLT; } - return false; + return PredB == ICmpInst::ICMP_SLT && RHS == RHS2; } Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, diff --git a/test/Transforms/InstCombine/unrecognized_three-way-comparison.ll b/test/Transforms/InstCombine/unrecognized_three-way-comparison.ll index 1ca16778299..14e0b557394 100644 --- a/test/Transforms/InstCombine/unrecognized_three-way-comparison.ll +++ b/test/Transforms/InstCombine/unrecognized_three-way-comparison.ll @@ -36,14 +36,10 @@ exit: define i32 @compare_against_zero(i32 %x) { ; CHECK-LABEL: @compare_against_zero( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[X:%.*]], 0 -; CHECK-NEXT: [[CMP2_INV:%.*]] = icmp sgt i32 [[X]], -1 -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2_INV]], i32 1, i32 -1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 0, i32 [[SELECT1]] -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X:%.*]], 0 +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 42 @@ -96,14 +92,10 @@ exit: define i32 @compare_against_two(i32 %x) { ; CHECK-LABEL: @compare_against_two( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[X:%.*]], 2 -; CHECK-NEXT: [[CMP2_INV:%.*]] = icmp sgt i32 [[X]], 1 -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2_INV]], i32 1, i32 -1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 0, i32 [[SELECT1]] -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X:%.*]], 2 +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 42 @@ -497,13 +489,10 @@ define i32 @compare_against_fortytwo_commutatibility_1(i32 %x) { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CMP1:%.*]] = icmp ne i32 [[X:%.*]], 42 ; CHECK-NEXT: call void @use1(i1 [[CMP1]]) -; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[X]], 42 -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2]], i32 -1, i32 1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 [[SELECT1]], i32 0 -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X]], 42 +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 84 @@ -527,14 +516,10 @@ exit: define i32 @compare_against_fortytwo_commutatibility_2(i32 %x) { ; CHECK-LABEL: @compare_against_fortytwo_commutatibility_2( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[X:%.*]], 42 -; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[X]], 41 -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2]], i32 1, i32 -1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 0, i32 [[SELECT1]] -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X:%.*]], 42 +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 84 @@ -559,13 +544,10 @@ define i32 @compare_against_fortytwo_commutatibility_3(i32 %x) { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CMP1:%.*]] = icmp ne i32 [[X:%.*]], 42 ; CHECK-NEXT: call void @use1(i1 [[CMP1]]) -; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[X]], 41 -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2]], i32 1, i32 -1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 [[SELECT1]], i32 0 -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X]], 42 +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 84 @@ -616,14 +598,10 @@ exit: define i32 @compare_against_arbitrary_value_commutativity1(i32 %x, i32 %c) { ; CHECK-LABEL: @compare_against_arbitrary_value_commutativity1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[X:%.*]], [[C:%.*]] -; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[C]], [[X]] -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2]], i32 -1, i32 1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 0, i32 [[SELECT1]] -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X:%.*]], [[C:%.*]] +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 42 @@ -648,13 +626,10 @@ define i32 @compare_against_arbitrary_value_commutativity2(i32 %x, i32 %c) { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CMP1:%.*]] = icmp ne i32 [[X:%.*]], [[C:%.*]] ; CHECK-NEXT: call void @use1(i1 [[CMP1]]) -; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[X]], [[C]] -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2]], i32 -1, i32 1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 [[SELECT1]], i32 0 -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X]], [[C]] +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 42 @@ -680,13 +655,10 @@ define i32 @compare_against_arbitrary_value_commutativity3(i32 %x, i32 %c) { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[CMP1:%.*]] = icmp ne i32 [[X:%.*]], [[C:%.*]] ; CHECK-NEXT: call void @use1(i1 [[CMP1]]) -; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[C]], [[X]] -; CHECK-NEXT: [[SELECT1:%.*]] = select i1 [[CMP2]], i32 -1, i32 1 -; CHECK-NEXT: [[SELECT2:%.*]] = select i1 [[CMP1]], i32 [[SELECT1]], i32 0 -; CHECK-NEXT: [[COND:%.*]] = icmp sgt i32 [[SELECT2]], 0 -; CHECK-NEXT: br i1 [[COND]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: [[TMP0:%.*]] = icmp sgt i32 [[X]], [[C]] +; CHECK-NEXT: br i1 [[TMP0]], label [[CALLFOO:%.*]], label [[EXIT:%.*]] ; CHECK: callfoo: -; CHECK-NEXT: call void @foo(i32 [[SELECT2]]) +; CHECK-NEXT: call void @foo(i32 1) ; CHECK-NEXT: br label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret i32 42 -- 2.40.0