]> granicus.if.org Git - llvm/commitdiff
[InstCombine] Recognize and simplify three way comparison idioms
authorAnna Thomas <anna@azul.com>
Fri, 23 Jun 2017 13:41:45 +0000 (13:41 +0000)
committerAnna Thomas <anna@azul.com>
Fri, 23 Jun 2017 13:41:45 +0000 (13:41 +0000)
Summary:
Many languages have a three way comparison idiom where comparing two values
produces not a boolean, but a tri-state value. Typical values (e.g. as used in
the lcmp/fcmp bytecodes from Java) are -1 for less than, 0 for equality, and +1
for greater than.

We actually do a great job already of converting three way comparisons into
binary comparisons when the result produced has one a single use. Unfortunately,
such values can have more than one use, and in that case, our existing
optimizations break down.

The patch adds a peephole which converts a three-way compare + test idiom into a
binary comparison on the original inputs. It focused on replacing the test on
the result of the three way compare and does nothing about removing the three
way compare itself. That's left to other optimizations (which do actually kick
in commonly.)
We currently recognize one idiom on signed integer compare. In the future, we
plan to recognize and simplify other comparison idioms on
other signed/unsigned datatypes such as floats, vectors etc.

This is a resurrection of Philip Reames' original patch:
https://reviews.llvm.org/D19452

Reviewers: majnemer, apilipenko, reames, sanjoy, mkazantsev

Reviewed by: mkazantsev

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D34278

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@306100 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineCompares.cpp
lib/Transforms/InstCombine/InstCombineInternal.h
test/Transforms/InstCombine/compare-3way.ll [new file with mode: 0644]

index 1ef4acfb058c4961c874d6bd1150bbf64a6e8864..bc79e4534e65a6d940f13369c2a19cf4504edd01 100644 (file)
@@ -2434,6 +2434,77 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
   return nullptr;
 }
 
+bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
+                                           Value *&RHS, ConstantInt *&Less,
+                                           ConstantInt *&Equal,
+                                           ConstantInt *&Greater) {
+  // TODO: Generalize this to work with other comparison idioms or ensure
+  // they get canonicalized into this form.
+
+  // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32
+  // Greater), where Equal, Less and Greater are placeholders for any three
+  // constants.
+  ICmpInst::Predicate PredA, PredB;
+  if (match(SI->getTrueValue(), m_ConstantInt(Equal)) &&
+      match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) &&
+      PredA == ICmpInst::ICMP_EQ &&
+      match(SI->getFalseValue(),
+            m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)),
+                     m_ConstantInt(Less), m_ConstantInt(Greater))) &&
+      PredB == ICmpInst::ICMP_SLT) {
+    return true;
+  }
+  return false;
+}
+
+Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
+                                                  Instruction *Select,
+                                                  ConstantInt *C) {
+
+  assert(C && "Cmp RHS should be a constant int!");
+  // If we're testing a constant value against the result of a three way
+  // comparison, the result can be expressed directly in terms of the
+  // original values being compared.  Note: We could possibly be more
+  // aggressive here and remove the hasOneUse test. The original select is
+  // really likely to simplify or sink when we remove a test of the result.
+  Value *OrigLHS, *OrigRHS;
+  ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan;
+  if (Cmp.hasOneUse() &&
+      matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS,
+                                 C1LessThan, C2Equal, C3GreaterThan)) {
+    assert(C1LessThan && C2Equal && C3GreaterThan);
+
+    bool TrueWhenLessThan =
+        ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C)
+            ->isAllOnesValue();
+    bool TrueWhenEqual =
+        ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C)
+            ->isAllOnesValue();
+    bool TrueWhenGreaterThan =
+        ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C)
+            ->isAllOnesValue();
+
+    // This generates the new instruction that will replace the original Cmp
+    // Instruction. Instead of enumerating the various combinations when
+    // TrueWhenLessThan, TrueWhenEqual and TrueWhenGreaterThan are true versus
+    // false, we rely on chaining of ORs and future passes of InstCombine to
+    // simplify the OR further (i.e. a s< b || a == b becomes a s<= b).
+
+    // When none of the three constants satisfy the predicate for the RHS (C),
+    // the entire original Cmp can be simplified to a false.
+    Value *Cond = Builder->getFalse();
+    if (TrueWhenLessThan)
+      Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS));
+    if (TrueWhenEqual)
+      Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS));
+    if (TrueWhenGreaterThan)
+      Cond = Builder->CreateOr(Cond, Builder->CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS));
+
+    return replaceInstUsesWith(Cmp, Cond);
+  }
+  return nullptr;
+}
+
 /// Try to fold integer comparisons with a constant operand: icmp Pred X, C
 /// where X is some kind of instruction.
 Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
@@ -2493,11 +2564,28 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
       return I;
   }
 
+  // Match against CmpInst LHS being instructions other than binary operators.
   Instruction *LHSI;
-  if (match(Cmp.getOperand(0), m_Instruction(LHSI)) &&
-      LHSI->getOpcode() == Instruction::Trunc)
-    if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C))
-      return I;
+  if (match(Cmp.getOperand(0), m_Instruction(LHSI))) {
+    switch (LHSI->getOpcode()) {
+    case Instruction::Select:
+      {
+      // For now, we only support constant integers while folding the
+      // ICMP(SELECT)) pattern. We can extend this to support vector of integers
+      // similar to the cases handled by binary ops above.
+      if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)))
+        if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS))
+          return I;
+      break;
+      }
+    case Instruction::Trunc:
+      if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C))
+        return I;
+      break;
+    default:
+      break;
+    }
+  }
 
   if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C))
     return I;
index 7f0539e52b404ddc92760bd5d4e6f15f8b7bdb70..1b0fe84dd4dda5cf5e2b996cbeea62d71edcb415 100644 (file)
@@ -603,6 +603,15 @@ private:
                           Instruction::BinaryOps, Value *, Value *, Value *,
                           Value *);
 
+  /// Match a select chain which produces one of three values based on whether
+  /// the LHS is less than, equal to, or greater than RHS respectively.
+  /// Return true if we matched a three way compare idiom. The LHS, RHS, Less,
+  /// Equal and Greater values are saved in the matching process and returned to
+  /// the caller.
+  bool matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, Value *&RHS,
+                               ConstantInt *&Less, ConstantInt *&Equal,
+                               ConstantInt *&Greater);
+
   /// \brief Attempts to replace V with a simpler value based on the demanded
   /// bits.
   Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known,
@@ -680,6 +689,8 @@ private:
   Instruction *foldICmpBinOp(ICmpInst &Cmp);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
 
+  Instruction *foldICmpSelectConstant(ICmpInst &Cmp, Instruction *Select,
+                                      ConstantInt *C);
   Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc,
                                      const APInt *C);
   Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And,
diff --git a/test/Transforms/InstCombine/compare-3way.ll b/test/Transforms/InstCombine/compare-3way.ll
new file mode 100644 (file)
index 0000000..663d470
--- /dev/null
@@ -0,0 +1,395 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+declare void @use(i32)
+
+; These 18 exercise all combinations of signed comparison
+; for each of the three values produced by your typical 
+; 3way compare function (-1, 0, 1)
+
+define void @test_low_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sgt
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_slt
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sge
+; CHECK: br i1 true, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sle
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_ne
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_eq
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sgt
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_slt
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sge
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sle
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_ne
+; CHECK: [[TMP1:%.*]] = icmp eq i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_eq
+; CHECK: icmp eq i64 %a, %b
+; CHECK: br i1 %eq, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sgt
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_slt
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sge
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sle
+; CHECK: br i1 true, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_ne
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_eq
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+; These five make sure we didn't accidentally hard code one of the
+; produced values
+
+define void @non_standard_low(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_low
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -3
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_mid(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_mid
+; CHECK: icmp eq i64 %a, %b
+; CHECK: br i1 %eq, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -2
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_high(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_high
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_bound1(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_bound1
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -20
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_bound2(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_bound2
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}