]> granicus.if.org Git - llvm/commitdiff
[NVPTX] Fix the codegen for llvm.round.
authorBixia Zheng <bixia@google.com>
Mon, 1 Apr 2019 16:10:26 +0000 (16:10 +0000)
committerBixia Zheng <bixia@google.com>
Mon, 1 Apr 2019 16:10:26 +0000 (16:10 +0000)
Summary:
Previously, we translate llvm.round to PTX cvt.rni, which rounds to the
even interger when the source is equidistant between two integers. This
is not correct as llvm.round should round away from zero. This change
replaces llvm.round with a round away from zero implementation through
target specific custom lowering.

Modify a few affected tests to not check for cvt.rni. Instead, we check
for the use of a few constants used in implementing round. We are also
adding CUDA runnable tests to check for the values produced by
llvm.round to test-suites/External/CUDA.

Reviewers: tra

Subscribers: jholewinski, sanjoy, jlebar, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D59947

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@357407 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/NVPTX/NVPTXISelLowering.cpp
lib/Target/NVPTX/NVPTXISelLowering.h
lib/Target/NVPTX/NVPTXInstrInfo.td
test/CodeGen/NVPTX/f16-instructions.ll
test/CodeGen/NVPTX/f16x2-instructions.ll
test/CodeGen/NVPTX/math-intrins.ll

index 2e955f89049c3d8d3d47803431e372d558efca80..cae94d497596077cc0ee39742c52c4f72b73c5bc 100644 (file)
@@ -546,13 +546,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // These map to conversion instructions for scalar FP types.
   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
-                         ISD::FROUND, ISD::FTRUNC}) {
+                         ISD::FTRUNC}) {
     setOperationAction(Op, MVT::f16, Legal);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
   }
 
+  setOperationAction(ISD::FROUND, MVT::f16, Promote);
+  setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
+  setOperationAction(ISD::FROUND, MVT::f32, Custom);
+  setOperationAction(ISD::FROUND, MVT::f64, Custom);
+
+
   // 'Expand' implements FCOPYSIGN without calling an external library.
   setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
@@ -2068,6 +2074,100 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
   }
 }
 
+SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+
+  if (VT == MVT::f32)
+    return LowerFROUND32(Op, DAG);
+
+  if (VT == MVT::f64)
+    return LowerFROUND64(Op, DAG);
+
+  llvm_unreachable("unhandled type");
+}
+
+// This is the the rounding method used in CUDA libdevice in C like code:
+// float roundf(float A)
+// {
+//   float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
+//   RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
+//   return abs(A) < 0.5 ? (float)(int)A : RoundedA;
+// }
+SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  SDValue A = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+
+  SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
+
+  // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
+  SDValue Bitcast  = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
+  const int SignBitMask = 0x80000000;
+  SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
+                             DAG.getConstant(SignBitMask, SL, MVT::i32));
+  const int PointFiveInBits = 0x3F000000;
+  SDValue PointFiveWithSignRaw =
+      DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
+                  DAG.getConstant(PointFiveInBits, SL, MVT::i32));
+  SDValue PointFiveWithSign =
+      DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
+  SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
+  SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
+
+  // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
+  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue IsLarge =
+      DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
+                   ISD::SETOGT);
+  RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
+
+  // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
+  SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
+                                DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
+  SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
+  return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
+}
+
+// The implementation of round(double) is similar to that of round(float) in
+// that they both separate the value range into three regions and use a method
+// specific to the region to round the values. However, round(double) first
+// calculates the round of the absolute value and then adds the sign back while
+// round(float) directly rounds the value with sign.
+SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  SDValue A = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+
+  SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
+
+  // double RoundedA = (double) (int) (abs(A) + 0.5f);
+  SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
+                                  DAG.getConstantFP(0.5, SL, VT));
+  SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
+
+  // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
+  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue IsSmall =DAG.getSetCC(SL, SetCCVT, AbsA,
+                                DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
+  RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
+                         DAG.getConstantFP(0, SL, VT),
+                         RoundedA);
+
+  // Add sign to rounded_A
+  RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
+  DAG.getNode(ISD::FTRUNC, SL, VT, A);
+
+  // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
+  SDValue IsLarge =
+      DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
+                   ISD::SETOGT);
+  return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
+}
+
+
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -2098,6 +2198,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerShiftRightParts(Op, DAG);
   case ISD::SELECT:
     return LowerSelect(Op, DAG);
+  case ISD::FROUND:
+    return LowerFROUND(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
index bbcc35f49d99c72a1546b28d70f3443645de62e6..ef645fc1e541f86f2fc026d2976bbe4bc64b2b3a 100644 (file)
@@ -556,6 +556,10 @@ private:
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
 
index 1bf556c9287a7c7bb62a8849174b87e6fdcd8943..2ee90abb411057acb1b87f21b793365482fb3d89 100644 (file)
@@ -3002,15 +3002,6 @@ def : Pat<(ffloor Float32Regs:$a),
 def : Pat<(ffloor Float64Regs:$a),
           (CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
 
-def : Pat<(f16 (fround Float16Regs:$a)),
-          (CVT_f16_f16 Float16Regs:$a, CvtRNI)>;
-def : Pat<(fround Float32Regs:$a),
-          (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
-def : Pat<(f32 (fround Float32Regs:$a)),
-          (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
-def : Pat<(f64 (fround Float64Regs:$a)),
-          (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
-
 def : Pat<(ftrunc Float16Regs:$a),
           (CVT_f16_f16 Float16Regs:$a, CvtRZI)>;
 def : Pat<(ftrunc Float32Regs:$a),
index 7788adc8698934d87f5d9c590d9680f9c2910857..9aa81dac1262a11046723ae0df8b2dba48713bf5 100644 (file)
@@ -1107,9 +1107,11 @@ define half @test_nearbyint(half %a) #0 {
 }\r
 \r
 ; CHECK-LABEL: test_round(\r
-; CHECK:      ld.param.b16    [[A:%h[0-9]+]], [test_round_param_0];\r
-; CHECK:      cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]];\r
-; CHECK:      st.param.b16    [func_retval0+0], [[R]];\r
+; CHECK:      ld.param.b16    {{.*}}, [test_round_param_0];\r
+; check the use of sign mask and 0.5 to implement round\r
+; CHECK:      and.b32 [[R:%r[0-9]+]], {{.*}}, -2147483648;\r
+; CHECK:      or.b32 {{.*}}, [[R]], 1056964608;\r
+; CHECK:      st.param.b16    [func_retval0+0], {{.*}};\r
 ; CHECK:      ret;\r
 define half @test_round(half %a) #0 {\r
   %r = call half @llvm.round.f16(half %a)\r
index a8996815af4134ebc3aa76859b1634d09c269a6e..44dda09a902d9028c27e98139cfd4540491373dd 100644 (file)
@@ -1378,12 +1378,13 @@ define <2 x half> @test_nearbyint(<2 x half> %a) #0 {
 }\r
 \r
 ; CHECK-LABEL: test_round(\r
-; CHECK:      ld.param.b32    [[A:%hh[0-9]+]], [test_round_param_0];\r
-; CHECK-DAG:  mov.b32         {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]];\r
-; CHECK-DAG:  cvt.rni.f16.f16 [[R1:%h[0-9]+]], [[A1]];\r
-; CHECK-DAG:  cvt.rni.f16.f16 [[R0:%h[0-9]+]], [[A0]];\r
-; CHECK:      mov.b32         [[R:%hh[0-9]+]], {[[R0]], [[R1]]}\r
-; CHECK:      st.param.b32    [func_retval0+0], [[R]];\r
+; CHECK:      ld.param.b32    {{.*}}, [test_round_param_0];\r
+; check the use of sign mask and 0.5 to implement round\r
+; CHECK:      and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;\r
+; CHECK:      or.b32 {{.*}}, [[R1]], 1056964608;\r
+; CHECK:      and.b32 [[R2:%r[0-9]+]], {{.*}}, -2147483648;\r
+; CHECK:      or.b32 {{.*}}, [[R2]], 1056964608;\r
+; CHECK:      st.param.b32    [func_retval0+0], {{.*}};\r
 ; CHECK:      ret;\r
 define <2 x half> @test_round(<2 x half> %a) #0 {\r
   %r = call <2 x half> @llvm.round.f16(<2 x half> %a)\r
index 828a8807dcfaabab89b023bda8cfa095bfb0d10e..412b25c7a3befe0e37f9ee4b250c9b964f9c6a6d 100644 (file)
@@ -74,21 +74,27 @@ define double @floor_double(double %a) {
 
 ; CHECK-LABEL: round_float
 define float @round_float(float %a) {
-  ; CHECK: cvt.rni.f32.f32
+; check the use of sign mask and 0.5 to implement round
+; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
   %b = call float @llvm.round.f32(float %a)
   ret float %b
 }
 
 ; CHECK-LABEL: round_float_ftz
 define float @round_float_ftz(float %a) #1 {
-  ; CHECK: cvt.rni.ftz.f32.f32
+; check the use of sign mask and 0.5 to implement round
+; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
   %b = call float @llvm.round.f32(float %a)
   ret float %b
 }
 
 ; CHECK-LABEL: round_double
 define double @round_double(double %a) {
-  ; CHECK: cvt.rni.f64.f64
+; check the use of 0.5 to implement round
+; CHECK: setp.lt.f64 {{.*}}, [[R:%fd[0-9]+]], 0d3FE0000000000000;
+; CHECK: add.rn.f64 {{.*}}, [[R]], 0d3FE0000000000000;
   %b = call double @llvm.round.f64(double %a)
   ret double %b
 }