From: Anna Thomas Date: Fri, 23 Jun 2017 13:41:45 +0000 (+0000) Subject: [InstCombine] Recognize and simplify three way comparison idioms X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=2cfdb4aa6ccbdbd2bf91a865e3db0e490eb690aa;p=llvm [InstCombine] Recognize and simplify three way comparison idioms Summary: Many languages have a three way comparison idiom where comparing two values produces not a boolean, but a tri-state value. Typical values (e.g. as used in the lcmp/fcmp bytecodes from Java) are -1 for less than, 0 for equality, and +1 for greater than. We actually do a great job already of converting three way comparisons into binary comparisons when the result produced has only a single use. Unfortunately, such values can have more than one use, and in that case, our existing optimizations break down. The patch adds a peephole which converts a three-way compare + test idiom into a binary comparison on the original inputs. It focuses on replacing the test on the result of the three way compare and does nothing about removing the three way compare itself. That's left to other optimizations (which do actually kick in commonly.) We currently recognize one idiom on signed integer compare. In the future, we plan to recognize and simplify other comparison idioms on other signed/unsigned datatypes such as floats, vectors etc. 
This is a resurrection of Philip Reames' original patch: https://reviews.llvm.org/D19452 Reviewers: majnemer, apilipenko, reames, sanjoy, mkazantsev Reviewed by: mkazantsev Subscribers: llvm-commits Differential Revision: https://reviews.llvm.org/D34278 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@306100 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 1ef4acfb058..bc79e4534e6 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2434,6 +2434,77 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return nullptr; } +bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, + Value *&RHS, ConstantInt *&Less, + ConstantInt *&Equal, + ConstantInt *&Greater) { + // TODO: Generalize this to work with other comparison idioms or ensure + // they get canonicalized into this form. + + // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32 + // Greater), where Equal, Less and Greater are placeholders for any three + // constants. + ICmpInst::Predicate PredA, PredB; + if (match(SI->getTrueValue(), m_ConstantInt(Equal)) && + match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) && + PredA == ICmpInst::ICMP_EQ && + match(SI->getFalseValue(), + m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)), + m_ConstantInt(Less), m_ConstantInt(Greater))) && + PredB == ICmpInst::ICMP_SLT) { + return true; + } + return false; +} + +Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, + Instruction *Select, + ConstantInt *C) { + + assert(C && "Cmp RHS should be a constant int!"); + // If we're testing a constant value against the result of a three way + // comparison, the result can be expressed directly in terms of the + // original values being compared. 
Note: We could possibly be more + // aggressive here and remove the hasOneUse test. The original select is + // really likely to simplify or sink when we remove a test of the result. + Value *OrigLHS, *OrigRHS; + ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; + if (Cmp.hasOneUse() && + matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS, + C1LessThan, C2Equal, C3GreaterThan)) { + assert(C1LessThan && C2Equal && C3GreaterThan); + + bool TrueWhenLessThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C) + ->isAllOnesValue(); + bool TrueWhenEqual = + ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C) + ->isAllOnesValue(); + bool TrueWhenGreaterThan = + ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C) + ->isAllOnesValue(); + + // This generates the new instruction that will replace the original Cmp + // Instruction. Instead of enumerating the various combinations when + // TrueWhenLessThan, TrueWhenEqual and TrueWhenGreaterThan are true versus + // false, we rely on chaining of ORs and future passes of InstCombine to + // simplify the OR further (i.e. a s< b || a == b becomes a s<= b). + + // When none of the three constants satisfy the predicate for the RHS (C), + // the entire original Cmp can be simplified to a false. + Value *Cond = Builder->getFalse(); + if (TrueWhenLessThan) + Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS)); + if (TrueWhenEqual) + Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS)); + if (TrueWhenGreaterThan) + Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS)); + + return replaceInstUsesWith(Cmp, Cond); + } + return nullptr; +} + /// Try to fold integer comparisons with a constant operand: icmp Pred X, C /// where X is some kind of instruction. 
Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { @@ -2493,11 +2564,28 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { return I; } + // Match against CmpInst LHS being instructions other than binary operators. Instruction *LHSI; - if (match(Cmp.getOperand(0), m_Instruction(LHSI)) && - LHSI->getOpcode() == Instruction::Trunc) - if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) - return I; + if (match(Cmp.getOperand(0), m_Instruction(LHSI))) { + switch (LHSI->getOpcode()) { + case Instruction::Select: + { + // For now, we only support constant integers while folding the + // ICMP(SELECT) pattern. We can extend this to support vector of integers + // similar to the cases handled by binary ops above. + if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS)) + return I; + break; + } + case Instruction::Trunc: + if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + return I; + break; + default: + break; + } + } if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) return I; diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 7f0539e52b4..1b0fe84dd4d 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -603,6 +603,15 @@ private: Instruction::BinaryOps, Value *, Value *, Value *, Value *); + /// Match a select chain which produces one of three values based on whether + /// the LHS is less than, equal to, or greater than RHS respectively. + /// Return true if we matched a three way compare idiom. The LHS, RHS, Less, + /// Equal and Greater values are saved in the matching process and returned to + /// the caller. 
+ bool matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, Value *&RHS, + ConstantInt *&Less, ConstantInt *&Equal, + ConstantInt *&Greater); + /// \brief Attempts to replace V with a simpler value based on the demanded /// bits. Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known, @@ -680,6 +689,8 @@ private: Instruction *foldICmpBinOp(ICmpInst &Cmp); Instruction *foldICmpEquality(ICmpInst &Cmp); + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, Instruction *Select, + ConstantInt *C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc, const APInt *C); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, diff --git a/test/Transforms/InstCombine/compare-3way.ll b/test/Transforms/InstCombine/compare-3way.ll new file mode 100644 index 00000000000..663d470df87 --- /dev/null +++ b/test/Transforms/InstCombine/compare-3way.ll @@ -0,0 +1,395 @@ +; RUN: opt -S -instcombine < %s | FileCheck %s + +declare void @use(i32) + +; These 18 exercise all combinations of signed comparison +; for each of the three values produced by your typical +; 3way compare function (-1, 0, 1) + +define void @test_low_sgt(i64 %a, i64 %b) { +; CHECK-LABEL: @test_low_sgt +; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %normal, label %unreached + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sgt i32 %result, -1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_low_slt(i64 %a, i64 %b) { +; CHECK-LABEL: @test_low_slt +; CHECK: br i1 false, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. 
+ %cmp = icmp slt i32 %result, -1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_low_sge(i64 %a, i64 %b) { +; CHECK-LABEL: @test_low_sge +; CHECK: br i1 true, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sge i32 %result, -1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_low_sle(i64 %a, i64 %b) { +; CHECK-LABEL: @test_low_sle +; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sle i32 %result, -1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_low_ne(i64 %a, i64 %b) { +; CHECK-LABEL: @test_low_ne +; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %normal, label %unreached + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp ne i32 %result, -1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_low_eq(i64 %a, i64 %b) { +; CHECK-LABEL: @test_low_eq +; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. 
+ %cmp = icmp eq i32 %result, -1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_mid_sgt(i64 %a, i64 %b) { +; CHECK-LABEL: @test_mid_sgt +; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sgt i32 %result, 0 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_mid_slt(i64 %a, i64 %b) { +; CHECK-LABEL: @test_mid_slt +; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp slt i32 %result, 0 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_mid_sge(i64 %a, i64 %b) { +; CHECK-LABEL: @test_mid_sge +; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %normal, label %unreached + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sge i32 %result, 0 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_mid_sle(i64 %a, i64 %b) { +; CHECK-LABEL: @test_mid_sle +; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %normal, label %unreached + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. 
+ %cmp = icmp sle i32 %result, 0 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_mid_ne(i64 %a, i64 %b) { +; CHECK-LABEL: @test_mid_ne +; CHECK: [[TMP1:%.*]] = icmp eq i64 %a, %b +; CHECK: br i1 [[TMP1]], label %normal, label %unreached + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp ne i32 %result, 0 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_mid_eq(i64 %a, i64 %b) { +; CHECK-LABEL: @test_mid_eq +; CHECK: icmp eq i64 %a, %b +; CHECK: br i1 %eq, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp eq i32 %result, 0 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_high_sgt(i64 %a, i64 %b) { +; CHECK-LABEL: @test_high_sgt +; CHECK: br i1 false, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sgt i32 %result, 1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_high_slt(i64 %a, i64 %b) { +; CHECK-LABEL: @test_high_slt +; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %normal, label %unreached + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. 
+ %cmp = icmp slt i32 %result, 1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_high_sge(i64 %a, i64 %b) { +; CHECK-LABEL: @test_high_sge +; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sge i32 %result, 1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_high_sle(i64 %a, i64 %b) { +; CHECK-LABEL: @test_high_sle +; CHECK: br i1 true, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp sle i32 %result, 1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_high_ne(i64 %a, i64 %b) { +; CHECK-LABEL: @test_high_ne +; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %normal, label %unreached + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. + %cmp = icmp ne i32 %result, 1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @test_high_eq(i64 %a, i64 %b) { +; CHECK-LABEL: @test_high_eq +; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -1, i32 1 + %result = select i1 %eq, i32 0, i32 %. 
+ %cmp = icmp eq i32 %result, 1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +; These five make sure we didn't accidentally hard code one of the +; produced values + +define void @non_standard_low(i64 %a, i64 %b) { +; CHECK-LABEL: @non_standard_low +; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -3, i32 -1 + %result = select i1 %eq, i32 -2, i32 %. + %cmp = icmp eq i32 %result, -3 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @non_standard_mid(i64 %a, i64 %b) { +; CHECK-LABEL: @non_standard_mid +; CHECK: icmp eq i64 %a, %b +; CHECK: br i1 %eq, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -3, i32 -1 + %result = select i1 %eq, i32 -2, i32 %. + %cmp = icmp eq i32 %result, -2 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @non_standard_high(i64 %a, i64 %b) { +; CHECK-LABEL: @non_standard_high +; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b +; CHECK: br i1 [[TMP1]], label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -3, i32 -1 + %result = select i1 %eq, i32 -2, i32 %. + %cmp = icmp eq i32 %result, -1 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @non_standard_bound1(i64 %a, i64 %b) { +; CHECK-LABEL: @non_standard_bound1 +; CHECK: br i1 false, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -3, i32 -1 + %result = select i1 %eq, i32 -2, i32 %. 
+ %cmp = icmp eq i32 %result, -20 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +} + +define void @non_standard_bound2(i64 %a, i64 %b) { +; CHECK-LABEL: @non_standard_bound2 +; CHECK: br i1 false, label %unreached, label %normal + %eq = icmp eq i64 %a, %b + %slt = icmp slt i64 %a, %b + %. = select i1 %slt, i32 -3, i32 -1 + %result = select i1 %eq, i32 -2, i32 %. + %cmp = icmp eq i32 %result, 0 + br i1 %cmp, label %unreached, label %normal +normal: + ret void +unreached: + call void @use(i32 %result) + ret void +}