]> granicus.if.org Git - llvm/commitdiff
[DAGCombine] GetNegatedExpression - add FMA\FMAD support
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 23 Aug 2019 10:49:46 +0000 (10:49 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 23 Aug 2019 10:49:46 +0000 (10:49 +0000)
If the accumulator and either of the multiply operands are negatable then we can we negate the entire expression.

Differential Revision: https://reviews.llvm.org/D63141

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@369746 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
test/CodeGen/X86/fma-fneg-combine-2.ll

index 081544818f6e5494201ef97618e64240e16f4401..ec47f2ba2c4b48d6a898e7d1e17113cb64ed4a6c 100644 (file)
@@ -876,6 +876,27 @@ static char isNegatibleForFree(SDValue Op, bool LegalOperations,
     return isNegatibleForFree(Op.getOperand(1), LegalOperations, TLI, Options,
                               ForCodeSize, Depth + 1);
 
+  case ISD::FMA:
+  case ISD::FMAD: {
+    if (!Options->NoSignedZerosFPMath && !Flags.hasNoSignedZeros())
+      return 0;
+
+    // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
+    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
+    char V2 = isNegatibleForFree(Op.getOperand(2), LegalOperations, TLI,
+                                 Options, ForCodeSize, Depth + 1);
+    if (!V2)
+      return 0;
+
+    // One of Op0/Op1 must be cheaply negatible, then select the cheapest.
+    char V0 = isNegatibleForFree(Op.getOperand(0), LegalOperations, TLI,
+                                 Options, ForCodeSize, Depth + 1);
+    char V1 = isNegatibleForFree(Op.getOperand(1), LegalOperations, TLI,
+                                 Options, ForCodeSize, Depth + 1);
+    char V01 = std::max(V0, V1);
+    return V01 ? std::max(V01, V2) : 0;
+  }
+
   case ISD::FP_EXTEND:
   case ISD::FP_ROUND:
   case ISD::FSIN:
@@ -917,7 +938,8 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
     return DAG.getBuildVector(Op.getValueType(), SDLoc(Op), Ops);
   }
   case ISD::FADD:
-    assert(Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros());
+    assert((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
+           "Expected NSZ fp-flag");
 
     // fold (fneg (fadd A, B)) -> (fsub (fneg A), B)
     if (isNegatibleForFree(Op.getOperand(0), LegalOperations,
@@ -964,6 +986,35 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
                                             LegalOperations, ForCodeSize,
                                             Depth + 1), Flags);
 
+  case ISD::FMA:
+  case ISD::FMAD: {
+    assert((Options.NoSignedZerosFPMath || Flags.hasNoSignedZeros()) &&
+           "Expected NSZ fp-flag");
+
+    SDValue Neg2 = GetNegatedExpression(Op.getOperand(2), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1);
+
+    char V0 = isNegatibleForFree(Op.getOperand(0), LegalOperations,
+                                 DAG.getTargetLoweringInfo(), &Options,
+                                 ForCodeSize, Depth + 1);
+    char V1 = isNegatibleForFree(Op.getOperand(1), LegalOperations,
+                                 DAG.getTargetLoweringInfo(), &Options,
+                                 ForCodeSize, Depth + 1);
+    if (V0 >= V1) {
+      // fold (fneg (fma X, Y, Z)) -> (fma (fneg X), Y, (fneg Z))
+      SDValue Neg0 = GetNegatedExpression(
+          Op.getOperand(0), DAG, LegalOperations, ForCodeSize, Depth + 1);
+      return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), Neg0,
+                         Op.getOperand(1), Neg2, Flags);
+    }
+
+    // fold (fneg (fma X, Y, Z)) -> (fma X, (fneg Y), (fneg Z))
+    SDValue Neg1 = GetNegatedExpression(Op.getOperand(1), DAG, LegalOperations,
+                                        ForCodeSize, Depth + 1);
+    return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
+                       Op.getOperand(0), Neg1, Neg2, Flags);
+  }
+
   case ISD::FP_EXTEND:
   case ISD::FSIN:
     return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
index 1f520f9a398c5d9ae3e1ed93cd7deb1188e33988..f9e87955270be70fdf8d4eace326540ef3ec0624 100644 (file)
@@ -5,14 +5,14 @@
 define float @test_fneg_fma_subx_y_negz_f32(float %w, float %x, float %y, float %z)  {
 ; FMA3-LABEL: test_fneg_fma_subx_y_negz_f32:
 ; FMA3:       # %bb.0: # %entry
-; FMA3-NEXT:    vsubss %xmm1, %xmm0, %xmm0
-; FMA3-NEXT:    vfnmadd213ss {{.*#+}} xmm0 = -(xmm2 * xmm0) + xmm3
+; FMA3-NEXT:    vsubss %xmm0, %xmm1, %xmm0
+; FMA3-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm2 * xmm0) + xmm3
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: test_fneg_fma_subx_y_negz_f32:
 ; FMA4:       # %bb.0: # %entry
-; FMA4-NEXT:    vsubss %xmm1, %xmm0, %xmm0
-; FMA4-NEXT:    vfnmaddss %xmm3, %xmm2, %xmm0, %xmm0
+; FMA4-NEXT:    vsubss %xmm0, %xmm1, %xmm0
+; FMA4-NEXT:    vfmaddss %xmm3, %xmm2, %xmm0, %xmm0
 ; FMA4-NEXT:    retq
 entry:
   %subx = fsub nsz float %w, %x
@@ -25,14 +25,14 @@ entry:
 define float @test_fneg_fma_x_suby_negz_f32(float %w, float %x, float %y, float %z)  {
 ; FMA3-LABEL: test_fneg_fma_x_suby_negz_f32:
 ; FMA3:       # %bb.0: # %entry
-; FMA3-NEXT:    vsubss %xmm2, %xmm0, %xmm0
-; FMA3-NEXT:    vfnmadd213ss {{.*#+}} xmm0 = -(xmm1 * xmm0) + xmm3
+; FMA3-NEXT:    vsubss %xmm0, %xmm2, %xmm0
+; FMA3-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm3
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: test_fneg_fma_x_suby_negz_f32:
 ; FMA4:       # %bb.0: # %entry
-; FMA4-NEXT:    vsubss %xmm2, %xmm0, %xmm0
-; FMA4-NEXT:    vfnmaddss %xmm3, %xmm0, %xmm1, %xmm0
+; FMA4-NEXT:    vsubss %xmm0, %xmm2, %xmm0
+; FMA4-NEXT:    vfmaddss %xmm3, %xmm0, %xmm1, %xmm0
 ; FMA4-NEXT:    retq
 entry:
   %suby = fsub nsz float %w, %y
@@ -45,16 +45,16 @@ entry:
 define float @test_fneg_fma_subx_suby_negz_f32(float %w, float %x, float %y, float %z)  {
 ; FMA3-LABEL: test_fneg_fma_subx_suby_negz_f32:
 ; FMA3:       # %bb.0: # %entry
-; FMA3-NEXT:    vsubss %xmm1, %xmm0, %xmm1
+; FMA3-NEXT:    vsubss %xmm0, %xmm1, %xmm1
 ; FMA3-NEXT:    vsubss %xmm2, %xmm0, %xmm0
-; FMA3-NEXT:    vfnmadd213ss {{.*#+}} xmm0 = -(xmm1 * xmm0) + xmm3
+; FMA3-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm3
 ; FMA3-NEXT:    retq
 ;
 ; FMA4-LABEL: test_fneg_fma_subx_suby_negz_f32:
 ; FMA4:       # %bb.0: # %entry
-; FMA4-NEXT:    vsubss %xmm1, %xmm0, %xmm1
+; FMA4-NEXT:    vsubss %xmm0, %xmm1, %xmm1
 ; FMA4-NEXT:    vsubss %xmm2, %xmm0, %xmm0
-; FMA4-NEXT:    vfnmaddss %xmm3, %xmm0, %xmm1, %xmm0
+; FMA4-NEXT:    vfmaddss %xmm3, %xmm0, %xmm1, %xmm0
 ; FMA4-NEXT:    retq
 entry:
   %subx = fsub nsz float %w, %x