]> granicus.if.org Git - llvm/commitdiff
[DAG Combiner] Fix the native computation of the Newton series for reciprocals
authorEvandro Menezes <e.menezes@samsung.com>
Thu, 10 Nov 2016 23:31:06 +0000 (23:31 +0000)
committerEvandro Menezes <e.menezes@samsung.com>
Thu, 10 Nov 2016 23:31:06 +0000 (23:31 +0000)
The generic infrastructure to compute the Newton series for reciprocal and
reciprocal square root was conceived to allow a target to compute the series
itself.  However, the original code did not properly consider this condition
if returned by a target.  This patch addresses the issues to allow a target
to compute the series on its own.

Differential revision: https://reviews.llvm.org/D22975

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@286523 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Target/TargetLowering.h
lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/Target/AArch64/AArch64ISelLowering.cpp
lib/Target/AArch64/AArch64ISelLowering.h
lib/Target/AMDGPU/AMDGPUISelLowering.cpp
lib/Target/AMDGPU/AMDGPUISelLowering.h
lib/Target/PowerPC/PPCISelLowering.cpp
lib/Target/PowerPC/PPCISelLowering.h
lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86ISelLowering.h

index 823d91072fc1eec40bd2e6fc5797168e038a75e1..826b4d5a1804270eecd237bb7b270ca7063abadc 100644 (file)
@@ -2986,21 +2986,24 @@ public:
   /// Hooks for building estimates in place of slower divisions and square
   /// roots.
 
-  /// Return a reciprocal square root estimate value for the input operand.
+  /// Return either a square root or its reciprocal estimate value for the input
+  /// operand.
   /// \p Enabled is a ReciprocalEstimate enum with value either 'Unspecified' or
   /// 'Enabled' as set by a potential default override attribute.
   /// If \p RefinementSteps is 'Unspecified', the number of Newton-Raphson
   /// refinement iterations required to generate a sufficient (though not
   /// necessarily IEEE-754 compliant) estimate is returned in that parameter.
   /// The boolean UseOneConstNR output is used to select a Newton-Raphson
-  /// algorithm implementation that uses one constant or two constants.
+  /// algorithm implementation that uses either one or two constants.
+  /// The boolean Reciprocal is used to select whether the estimate is for the
+  /// square root of the input operand or the reciprocal of its square root.
   /// A target may choose to implement its own refinement within this function.
   /// If that's true, then return '0' as the number of RefinementSteps to avoid
   /// any further refinement of the estimate.
   /// An empty SDValue return means no estimate sequence can be created.
-  virtual SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG,
-                                   int Enabled, int &RefinementSteps,
-                                   bool &UseOneConstNR) const {
+  virtual SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+                                  int Enabled, int &RefinementSteps,
+                                  bool &UseOneConstNR, bool Reciprocal) const {
     return SDValue();
   }
 
index 6c0f4354029d60385c4f4d8646e6d9b7029797df..63da11659ed2c72b33e564b1d3cd6b70dee1a278 100644 (file)
@@ -14928,6 +14928,12 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) {
   return S;
 }
 
+/// Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
+/// For the reciprocal, we need to find the zero of the function:
+///   F(X) = A X - 1 [which has a zero at X = 1/A]
+///     =>
+///   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
+///     does not require additional intermediate precision]
 SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) {
   if (Level >= AfterLegalizeDAG)
     return SDValue();
@@ -14947,19 +14953,13 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) {
   // refinement steps.
   int Iterations = TLI.getDivRefinementSteps(VT, MF);
   if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
+    AddToWorklist(Est.getNode());
+
     if (Iterations) {
-      // Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
-      // For the reciprocal, we need to find the zero of the function:
-      //   F(X) = A X - 1 [which has a zero at X = 1/A]
-      //     =>
-      //   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
-      //     does not require additional intermediate precision]
       EVT VT = Op.getValueType();
       SDLoc DL(Op);
       SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
 
-      AddToWorklist(Est.getNode());
-
       // Newton iterations: Est = Est + Est (1 - Arg * Est)
       for (int i = 0; i < Iterations; ++i) {
         SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est, Flags);
@@ -15100,12 +15100,30 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags,
 
   bool UseOneConstNR = false;
   if (SDValue Est =
-      TLI.getRsqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR)) {
+      TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
+                          Reciprocal)) {
     AddToWorklist(Est.getNode());
+
     if (Iterations) {
       Est = UseOneConstNR
-      ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
-      : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
+            ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
+            : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
+
+      if (!Reciprocal) {
+        // Unfortunately, Est is now NaN if the input was exactly 0.0.
+        // Select out this case and force the answer to 0.0.
+        EVT VT = Op.getValueType();
+        SDLoc DL(Op);
+
+        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+        EVT CCVT = getSetCCResultType(VT);
+        SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+        AddToWorklist(ZeroCmp.getNode());
+
+        Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
+                          ZeroCmp, FPZero, Est);
+        AddToWorklist(Est.getNode());
+      }
     }
     return Est;
   }
@@ -15118,23 +15136,7 @@ SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
 }
 
 SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
-  SDValue Est = buildSqrtEstimateImpl(Op, Flags, false);
-  if (!Est)
-    return SDValue();
-
-  // Unfortunately, Est is now NaN if the input was exactly 0.
-  // Select out this case and force the answer to 0.
-  EVT VT = Est.getValueType();
-  SDLoc DL(Op);
-  SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
-  EVT CCVT = getSetCCResultType(VT);
-  SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, Zero, ISD::SETEQ);
-  AddToWorklist(ZeroCmp.getNode());
-
-  Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, ZeroCmp,
-                    Zero, Est);
-  AddToWorklist(Est.getNode());
-  return Est;
+  return buildSqrtEstimateImpl(Op, Flags, false);
 }
 
 /// Return true if base is a frame index, which is known not to alias with
index 5bea916ad2ab44454fcb135d29d59e3c302a887a..a9d2e1fc256eee70fb5b8b67afdc923f5738073a 100644 (file)
@@ -4644,10 +4644,11 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
   return SDValue();
 }
 
-SDValue AArch64TargetLowering::getRsqrtEstimate(SDValue Operand,
-                                                SelectionDAG &DAG, int Enabled,
-                                                int &ExtraSteps,
-                                                bool &UseOneConst) const {
+SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
+                                               SelectionDAG &DAG, int Enabled,
+                                               int &ExtraSteps,
+                                               bool &UseOneConst,
+                                               bool Reciprocal) const {
   if (Enabled == ReciprocalEstimate::Enabled ||
       (Enabled == ReciprocalEstimate::Unspecified && Subtarget->useRSqrt()))
     if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRSQRTE, Operand,
index c9ae7cd94bcc0ee19585e2baeb678350adab5441..7b317d6ff5cc5c72e36d5dbcc005137d8a2c5581 100644 (file)
@@ -534,8 +534,9 @@ private:
 
   SDValue BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
                         std::vector<SDNode *> *Created) const override;
-  SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
-                           int &ExtraSteps, bool &UseOneConst) const override;
+  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+                          int &ExtraSteps, bool &UseOneConst,
+                          bool Reciprocal) const override;
   SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                            int &ExtraSteps) const override;
   unsigned combineRepeatedFPDivisors() const override;
index 5a871489acd489593cbb33fe1786aeff11e32834..a4a5de126dfb61496d860c49011287ca86ff29f6 100644 (file)
@@ -2978,10 +2978,11 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
   return nullptr;
 }
 
-SDValue AMDGPUTargetLowering::getRsqrtEstimate(SDValue Operand,
-                                               SelectionDAG &DAG, int Enabled,
-                                               int &RefinementSteps,
-                                               bool &UseOneConstNR) const {
+SDValue AMDGPUTargetLowering::getSqrtEstimate(SDValue Operand,
+                                              SelectionDAG &DAG, int Enabled,
+                                              int &RefinementSteps,
+                                              bool &UseOneConstNR,
+                                              bool Reciprocal) const {
   EVT VT = Operand.getValueType();
 
   if (VT == MVT::f32) {
index 07d2db82e328cb8798d39b963fc20dc9131ff45a..6c6fc2eed3b37d6c71cf0287ccbfbe6dfa21618f 100644 (file)
@@ -172,9 +172,9 @@ public:
   bool isFsqrtCheap(SDValue Operand, SelectionDAG &DAG) const override {
     return true;
   }
-  SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
-                           int &RefinementSteps,
-                           bool &UseOneConstNR) const override;
+  SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+                           int &RefinementSteps, bool &UseOneConstNR,
+                           bool Reciprocal) const override;
   SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                            int &RefinementSteps) const override;
 
index d54c76e52c057d389d2fcf03ca81957091561ce5..22f711096900af7e17df82e9e6b61407d94e72bd 100644 (file)
@@ -9637,9 +9637,10 @@ static int getEstimateRefinementSteps(EVT VT, const PPCSubtarget &Subtarget) {
   return RefinementSteps;
 }
 
-SDValue PPCTargetLowering::getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG,
-                                            int Enabled, int &RefinementSteps,
-                                            bool &UseOneConstNR) const {
+SDValue PPCTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+                                           int Enabled, int &RefinementSteps,
+                                           bool &UseOneConstNR,
+                                           bool Reciprocal) const {
   EVT VT = Operand.getValueType();
   if ((VT == MVT::f32 && Subtarget.hasFRSQRTES()) ||
       (VT == MVT::f64 && Subtarget.hasFRSQRTE()) ||
index 2944e99db01de864a17f6ea34d05ac4360273bbb..689a2e6bb68ae84cfda24458f244a32b6e0f2263 100644 (file)
@@ -968,9 +968,9 @@ namespace llvm {
     SDValue DAGCombineTruncBoolExt(SDNode *N, DAGCombinerInfo &DCI) const;
     SDValue combineFPToIntToFP(SDNode *N, DAGCombinerInfo &DCI) const;
 
-    SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
-                             int &RefinementSteps,
-                             bool &UseOneConstNR) const override;
+    SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+                            int &RefinementSteps, bool &UseOneConstNR,
+                            bool Reciprocal) const override;
     SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                              int &RefinementSteps) const override;
     unsigned combineRepeatedFPDivisors() const override;
index fceaecf0bc3fc8b7d0ec5643de93bb5ad5f0c578..5f4a29d9812aca7d5434d38df3a8b170e229e910 100644 (file)
@@ -15296,10 +15296,11 @@ bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const {
 
 /// The minimum architected relative accuracy is 2^-12. We need one
 /// Newton-Raphson step to have a good float result (24 bits of precision).
-SDValue X86TargetLowering::getRsqrtEstimate(SDValue Op,
-                                            SelectionDAG &DAG, int Enabled,
-                                            int &RefinementSteps,
-                                            bool &UseOneConstNR) const {
+SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
+                                           SelectionDAG &DAG, int Enabled,
+                                           int &RefinementSteps,
+                                           bool &UseOneConstNR,
+                                           bool Reciprocal) const {
   EVT VT = Op.getValueType();
 
   // SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
index dabef9d216073b6571ff6198de6938a69d2c77a5..b5903e8ce391f843956500db466625febdcc6ae0 100644 (file)
@@ -1268,9 +1268,9 @@ namespace llvm {
     bool isFsqrtCheap(SDValue Operand, SelectionDAG &DAG) const override;
 
     /// Use rsqrt* to speed up sqrt calculations.
-    SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
-                             int &RefinementSteps,
-                             bool &UseOneConstNR) const override;
+    SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+                            int &RefinementSteps, bool &UseOneConstNR,
+                            bool Reciprocal) const override;
 
     /// Use rcp* to speed up fdiv calculations.
     SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,