]> granicus.if.org Git - llvm/commitdiff
Lower consecutive select instructions correctly.
authorDehao Chen <dehao@google.com>
Mon, 12 Sep 2016 20:23:28 +0000 (20:23 +0000)
committerDehao Chen <dehao@google.com>
Mon, 12 Sep 2016 20:23:28 +0000 (20:23 +0000)
Summary: If consecutive select instructions are lowered separately in CGP, it will introduce redundant condition check and branches that cannot be removed by later optimization phases. This patch lowers all consecutive select instructions at the same to to avoid inefficent code as demonstrated in https://llvm.org/bugs/show_bug.cgi?id=29095

Reviewers: davidxl

Subscribers: vsk, llvm-commits

Differential Revision: https://reviews.llvm.org/D24147

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@281252 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/CodeGenPrepare.cpp
test/CodeGen/X86/pseudo_cmov_lower2.ll

index 3bdf60c172a1a43cc3e9b108c46896b54694c7e9..fc27f0ee7b2da857872b8594818cc9c4adda1eda 100644 (file)
@@ -4578,10 +4578,45 @@ static bool isFormingBranchFromSelectProfitable(const TargetTransformInfo *TTI,
   return false;
 }
 
+/// If \p isTrue is true, return the true value of \p SI, otherwise return
+/// false value of \p SI. If the true/false value of \p SI is defined by any
+/// select instructions in \p Selects, look through the defining select
+/// instruction until the true/false value is not defined in \p Selects.
+static Value *getTrueOrFalseValue(
+    SelectInst *SI, bool isTrue,
+    const SmallPtrSet<const Instruction *, 2> &Selects) {
+  Value *V;
+
+  for (SelectInst *DefSI = SI; DefSI != nullptr && Selects.count(DefSI);
+       DefSI = dyn_cast<SelectInst>(V)) {
+    assert(DefSI.getCondition() == SI->getCondition() &&
+           "The condition of DefSI does not match with SI");
+    V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
+  }
+  return V;
+}
 
 /// If we have a SelectInst that will likely profit from branch prediction,
 /// turn it into a branch.
 bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
+  // Find all consecutive select instructions that share the same condition.
+  SmallVector<SelectInst *, 2> ASI;
+  ASI.push_back(SI);
+  for (BasicBlock::iterator It = ++BasicBlock::iterator(SI);
+       It != SI->getParent()->end(); ++It) {
+    SelectInst *I = dyn_cast<SelectInst>(&*It);
+    if (I && SI->getCondition() == I->getCondition()) {
+      ASI.push_back(I);
+    } else {
+      break;
+    }
+  }
+
+  SelectInst *LastSI = ASI.back();
+  // Increment the current iterator to skip all the rest of select instructions
+  // because they will be either "not lowered" or "all lowered" to branch.
+  CurInstIterator = std::next(LastSI->getIterator());
+
   bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
 
   // Can we convert the 'select' to CF ?
@@ -4628,7 +4663,7 @@ bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
 
   // First, we split the block containing the select into 2 blocks.
   BasicBlock *StartBlock = SI->getParent();
-  BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(SI));
+  BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI));
   BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end");
 
   // Delete the unconditional branch that was just created by the split.
@@ -4638,22 +4673,30 @@ bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
   // At least one will become an actual new basic block.
   BasicBlock *TrueBlock = nullptr;
   BasicBlock *FalseBlock = nullptr;
+  BranchInst *TrueBranch = nullptr;
+  BranchInst *FalseBranch = nullptr;
 
   // Sink expensive instructions into the conditional blocks to avoid executing
   // them speculatively.
-  if (sinkSelectOperand(TTI, SI->getTrueValue())) {
-    TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink",
-                                   EndBlock->getParent(), EndBlock);
-    auto *TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
-    auto *TrueInst = cast<Instruction>(SI->getTrueValue());
-    TrueInst->moveBefore(TrueBranch);
-  }
-  if (sinkSelectOperand(TTI, SI->getFalseValue())) {
-    FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink",
-                                    EndBlock->getParent(), EndBlock);
-    auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
-    auto *FalseInst = cast<Instruction>(SI->getFalseValue());
-    FalseInst->moveBefore(FalseBranch);
+  for (SelectInst *SI : ASI) {
+    if (sinkSelectOperand(TTI, SI->getTrueValue())) {
+      if (TrueBlock == nullptr) {
+        TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink",
+                                       EndBlock->getParent(), EndBlock);
+        TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
+      }
+      auto *TrueInst = cast<Instruction>(SI->getTrueValue());
+      TrueInst->moveBefore(TrueBranch);
+    }
+    if (sinkSelectOperand(TTI, SI->getFalseValue())) {
+      if (FalseBlock == nullptr) {
+        FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink",
+                                        EndBlock->getParent(), EndBlock);
+        FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
+      }
+      auto *FalseInst = cast<Instruction>(SI->getFalseValue());
+      FalseInst->moveBefore(FalseBranch);
+    }
   }
 
   // If there was nothing to sink, then arbitrarily choose the 'false' side
@@ -4687,18 +4730,27 @@ bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
   }
   IRBuilder<>(SI).CreateCondBr(SI->getCondition(), TT, FT, SI);
 
-  // The select itself is replaced with a PHI Node.
-  PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
-  PN->takeName(SI);
-  PN->addIncoming(SI->getTrueValue(), TrueBlock);
-  PN->addIncoming(SI->getFalseValue(), FalseBlock);
-
-  SI->replaceAllUsesWith(PN);
-  SI->eraseFromParent();
+  SmallPtrSet<const Instruction *, 2> INS;
+  INS.insert(ASI.begin(), ASI.end());
+  // Use reverse iterator because later select may use the value of the
+  // earlier select, and we need to propagate value through earlier select
+  // to get the PHI operand.
+  for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
+    SelectInst *SI = *It;
+    // The select itself is replaced with a PHI Node.
+    PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
+    PN->takeName(SI);
+    PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock);
+    PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock);
+
+    SI->replaceAllUsesWith(PN);
+    SI->eraseFromParent();
+    INS.erase(SI);
+    ++NumSelectsExpanded;
+  }
 
   // Instruct OptimizeBlock to skip to the next block.
   CurInstIterator = StartBlock->end();
-  ++NumSelectsExpanded;
   return true;
 }
 
index 0133963b36d042e2c75c9292c7facb0018840a38..38712a96b2bf734cede2d7cc87a19bc581daab7c 100644 (file)
@@ -98,3 +98,47 @@ entry:
   %d5 = fdiv double %d4, %d3
   ret double %d5
 }
+
+; This test checks that only a single jae gets generated in the final code
+; for lowering the CMOV pseudos that get created for this IR.  The tricky part
+; of this test is that it tests the special code in CodeGenPrepare.
+;
+; CHECK-LABEL: foo5:
+; CHECK: jae
+; CHECK-NOT: jae
+define double @foo5(float %p1, double %p2, double %p3) nounwind {
+entry:
+  %c1 = fcmp oge float %p1, 0.000000e+00
+  %d0 = fadd double %p2, 1.25e0
+  %d1 = fadd double %p3, 1.25e0
+  %d2 = select i1 %c1, double %d0, double %d1, !prof !0
+  %d3 = select i1 %c1, double %d2, double %p2, !prof !0
+  %d4 = select i1 %c1, double %d3, double %p3, !prof !0
+  %d5 = fsub double %d2, %d3
+  %d6 = fadd double %d5, %d4
+  ret double %d6
+}
+
+; We should expand select instructions into 3 conditional branches as their
+; condtions are different.
+;
+; CHECK-LABEL: foo6:
+; CHECK: jae
+; CHECK: jae
+; CHECK: jae
+define double @foo6(float %p1, double %p2, double %p3) nounwind {
+entry:
+  %c1 = fcmp oge float %p1, 0.000000e+00
+  %c2 = fcmp oge float %p1, 1.000000e+00
+  %c3 = fcmp oge float %p1, 2.000000e+00
+  %d0 = fadd double %p2, 1.25e0
+  %d1 = fadd double %p3, 1.25e0
+  %d2 = select i1 %c1, double %d0, double %d1, !prof !0
+  %d3 = select i1 %c2, double %d2, double %p2, !prof !0
+  %d4 = select i1 %c3, double %d3, double %p3, !prof !0
+  %d5 = fsub double %d2, %d3
+  %d6 = fadd double %d5, %d4
+  ret double %d6
+}
+
+!0 = !{!"branch_weights", i32 1, i32 2000}