From 6f09ea3f571cd2f4c39fff52deecec7cfc62d0c4 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Tue, 31 Jan 2017 23:08:57 +0000 Subject: [PATCH] [NVPTX] Compute approx sqrt as 1/rsqrt(x) rather than x*rsqrt(x). x*rsqrt(x) returns NaN for x == 0, whereas 1/rsqrt(x) returns 0, as desired. Verified that the particular nvptx approximate instructions here do in fact return 0 for x = 0. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@293713 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/NVPTX/NVPTXISelLowering.cpp | 11 ++++++++--- test/CodeGen/NVPTX/fast-math.ll | 4 ++-- test/CodeGen/NVPTX/sqrt-approx.ll | 8 +++++--- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp index 194e46b0448..9584776e185 100644 --- a/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1080,9 +1080,14 @@ SDValue NVPTXTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, return MakeIntrinsicCall(Ftz ? Intrinsic::nvvm_sqrt_approx_ftz_f : Intrinsic::nvvm_sqrt_approx_f); else { - // There's no sqrt.approx.f64 instruction, so we emit x * rsqrt(x). - return DAG.getNode(ISD::FMUL, DL, VT, Operand, - MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d)); + // There's no sqrt.approx.f64 instruction, so we emit + // reciprocal(rsqrt(x)). This is faster than + // select(x == 0, 0, x * rsqrt(x)). (In fact, it's faster than plain + // x * rsqrt(x).) + return DAG.getNode( + ISD::INTRINSIC_WO_CHAIN, DL, VT, + DAG.getConstant(Intrinsic::nvvm_rcp_approx_ftz_d, DL, MVT::i32), + MakeIntrinsicCall(Intrinsic::nvvm_rsqrt_approx_d)); } } } diff --git a/test/CodeGen/NVPTX/fast-math.ll b/test/CodeGen/NVPTX/fast-math.ll index 528d2c02df5..f925d67434c 100644 --- a/test/CodeGen/NVPTX/fast-math.ll +++ b/test/CodeGen/NVPTX/fast-math.ll @@ -40,11 +40,11 @@ define float @sqrt_div_fast_ftz(float %a, float %b) #0 #1 { } ; There are no fast-math or ftz versions of sqrt and div for f64. We use -; x * rsqrt(x) for sqrt(x), and emit a vanilla divide. +; reciprocal(rsqrt(x)) for sqrt(x), and emit a vanilla divide. ; CHECK-LABEL: sqrt_div_fast_ftz_f64( ; CHECK: rsqrt.approx.f64 -; CHECK: mul.f64 +; CHECK: rcp.approx.ftz.f64 ; CHECK: div.rn.f64 define double @sqrt_div_fast_ftz_f64(double %a, double %b) #0 #1 { %t1 = tail call double @llvm.sqrt.f64(double %a) diff --git a/test/CodeGen/NVPTX/sqrt-approx.ll b/test/CodeGen/NVPTX/sqrt-approx.ll index 5edf9e28a93..1e28db44b80 100644 --- a/test/CodeGen/NVPTX/sqrt-approx.ll +++ b/test/CodeGen/NVPTX/sqrt-approx.ll @@ -59,9 +59,11 @@ define float @test_sqrt_ftz(float %a) #0 #1 { ; CHECK-LABEL test_sqrt64 define double @test_sqrt64(double %a) #0 { -; There's no sqrt.approx.f64 instruction; we emit x * rsqrt.approx.f64(x). +; There's no sqrt.approx.f64 instruction; we emit +; reciprocal(rsqrt.approx.f64(x)). There's no non-ftz approximate reciprocal, +; so we just use the ftz version. ; CHECK: rsqrt.approx.f64 -; CHECK: mul.f64 +; CHECK: rcp.approx.ftz.f64 %ret = tail call double @llvm.sqrt.f64(double %a) ret double %ret } @@ -70,7 +72,7 @@ define double @test_sqrt64(double %a) #0 { define double @test_sqrt64_ftz(double %a) #0 #1 { ; There's no sqrt.approx.ftz.f64 instruction; we just use the non-ftz version. ; CHECK: rsqrt.approx.f64 -; CHECK: mul.f64 +; CHECK: rcp.approx.ftz.f64 %ret = tail call double @llvm.sqrt.f64(double %a) ret double %ret } -- 2.40.0