]> granicus.if.org Git - llvm/commitdiff
[InstCombine] Limit FMul constant folding for fma simplifications.
authorFlorian Hahn <flo@fhahn.com>
Wed, 25 Sep 2019 17:03:20 +0000 (17:03 +0000)
committerFlorian Hahn <flo@fhahn.com>
Wed, 25 Sep 2019 17:03:20 +0000 (17:03 +0000)
As @reames pointed out post-commit, rL371518 adds additional rounding
in some cases, when doing constant folding of the multiplication.
This breaks a guarantee llvm.fma makes and must be avoided.

This patch reapplies rL371518, but splits off the simplifications not
requiring rounding from SimplifFMulInst as SimplifyFMAFMul.

Reviewers: spatel, lebedev.ri, reames, scanon

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D67434

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@372899 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/InstructionSimplify.h
lib/Analysis/InstructionSimplify.cpp
lib/Transforms/InstCombine/InstCombineCalls.cpp
test/Transforms/InstCombine/fma.ll

index db92130c8632820164099ee6d7998755f0f06fe8..a5ffca13046b7ff131bffdacbaf16c3fd1800240 100644 (file)
@@ -142,6 +142,13 @@ Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF,
 Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF,
                         const SimplifyQuery &Q);
 
+/// Given operands for the multiplication of a FMA, fold the result or return
+/// null. In contrast to SimplifyFMulInst, this function will not perform
+/// simplifications whose unrounded results differ when rounded to the argument
+/// type.
+Value *SimplifyFMAFMul(Value *LHS, Value *RHS, FastMathFlags FMF,
+                       const SimplifyQuery &Q);
+
 /// Given operands for a Mul, fold the result or return null.
 Value *SimplifyMulInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 
index 67b06ea40bc74d9e2b712899547d75dd17c2d05a..4ae052eb14b21762b114246013af57b7ab20f2fb 100644 (file)
@@ -4576,15 +4576,8 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   return nullptr;
 }
 
-/// Given the operands for an FMul, see if we can fold the result
-static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
-                               const SimplifyQuery &Q, unsigned MaxRecurse) {
-  if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
-    return C;
-
-  if (Constant *C = simplifyFPBinop(Op0, Op1))
-    return C;
-
+static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+                              const SimplifyQuery &Q, unsigned MaxRecurse) {
   // fmul X, 1.0 ==> X
   if (match(Op1, m_FPOne()))
     return Op0;
@@ -4605,6 +4598,19 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   return nullptr;
 }
 
+/// Given the operands for an FMul, see if we can fold the result
+static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+                               const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
+    return C;
+
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
+  // Now apply simplifications that do not require rounding.
+  return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse);
+}
+
 Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                               const SimplifyQuery &Q) {
   return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit);
@@ -4621,6 +4627,11 @@ Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit);
 }
 
+Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+                             const SimplifyQuery &Q) {
+  return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
 static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const SimplifyQuery &Q, unsigned) {
   if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
index 5c0e648f490a08b097e03b3f6508a71d2c84e4bc..51a21e37e8b4e9b6172edac9ff17924dbb47d263 100644 (file)
@@ -2234,6 +2234,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       return replaceInstUsesWith(*II, Add);
     }
 
+    // Try to simplify the underlying FMul.
+    if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
+                                    II->getFastMathFlags(),
+                                    SQ.getWithInstruction(II))) {
+      auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
+      FAdd->copyFastMathFlags(II);
+      return FAdd;
+    }
+
     LLVM_FALLTHROUGH;
   }
   case Intrinsic::fma: {
@@ -2258,9 +2267,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       return II;
     }
 
-    // fma x, 1, z -> fadd x, z
-    if (match(Src1, m_FPOne())) {
-      auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2));
+    // Try to simplify the underlying FMul. We can only apply simplifications
+    // that do not require rounding.
+    if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1),
+                                   II->getFastMathFlags(),
+                                   SQ.getWithInstruction(II))) {
+      auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
       FAdd->copyFastMathFlags(II);
       return FAdd;
     }
index 89fdc6bc9a5a841dc143c7b2b158d1115aebf506..29d372fca05f86278e26c8423a33dcbb1aa6c90c 100644 (file)
@@ -372,8 +372,7 @@ define float @fmuladd_x_1_z_fast(float %x, float %z) {
 define <2 x double> @fmuladd_a_0_b(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @fmuladd_a_0_b(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    ret <2 x double> [[B:%.*]]
 ;
 entry:
   %res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b)
@@ -383,8 +382,7 @@ entry:
 define <2 x double> @fmuladd_0_a_b(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @fmuladd_0_a_b(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    ret <2 x double> [[B:%.*]]
 ;
 entry:
   %res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b)
@@ -407,8 +405,7 @@ declare <2 x double> @llvm.fmuladd.v2f64(<2 x double>, <2 x double>, <2 x double
 define <2 x double> @fma_a_0_b(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @fma_a_0_b(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    ret <2 x double> [[B:%.*]]
 ;
 entry:
   %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b)
@@ -418,8 +415,7 @@ entry:
 define <2 x double> @fma_0_a_b(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @fma_0_a_b(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT:    ret <2 x double> [[RES]]
+; CHECK-NEXT:    ret <2 x double> [[B:%.*]]
 ;
 entry:
   %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b)
@@ -440,8 +436,7 @@ entry:
 define <2 x double> @fma_sqrt(<2 x double> %a, <2 x double> %b) {
 ; CHECK-LABEL: @fma_sqrt(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[A:%.*]])
-; CHECK-NEXT:    [[RES:%.*]] = call fast <2 x double> @llvm.fma.v2f64(<2 x double> [[SQRT]], <2 x double> [[SQRT]], <2 x double> [[B:%.*]])
+; CHECK-NEXT:    [[RES:%.*]] = fadd fast <2 x double> [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    ret <2 x double> [[RES]]
 ;
 entry:
@@ -450,6 +445,71 @@ entry:
   ret <2 x double> %res
 }
 
+; We do not fold constant multiplies in FMAs, as they could require rounding, unless either constant is 0.0 or 1.0.
+define <2 x double> @fma_const_fmul(<2 x double> %b) {
+; CHECK-LABEL: @fma_const_fmul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> <double 1.291820e-08, double 9.123000e-06>, <2 x double> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0000000129182, double 0.000009123>, <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_zero(<2 x double> %b) {
+; CHECK-LABEL: @fma_const_fmul_zero(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> zeroinitializer, <2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 0.0, double 0.0>, <2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_zero2(<2 x double> %b) {
+; CHECK-LABEL: @fma_const_fmul_zero2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret <2 x double> [[B:%.*]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0, double 0.0>, <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_one(<2 x double> %b) {
+; CHECK-LABEL: @fma_const_fmul_one(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1.0, double 1.0>, <2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_one2(<2 x double> %b) {
+; CHECK-LABEL: @fma_const_fmul_one2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], <double 0x4131233302898702, double 0x40C387800000D6C0>
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 1.0, double 1.0>, <2 x double> %b)
+  ret <2 x double> %res
+}
+
+define <2 x double> @fmuladd_const_fmul(<2 x double> %b) {
+; CHECK-LABEL: @fmuladd_const_fmul(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], <double 0x3F8DB6C076AD949B, double 0x3FB75A405B6E6D69>
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+entry:
+  %res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0000000129182, double 0.000009123>, <2 x double> %b)
+  ret <2 x double> %res
+}
 
 declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
 declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)