AMDGPU: Implement expansion of f16 = FP_TO_FP16 f64
authorTom Stellard <thomas.stellard@amd.com>
Tue, 1 Nov 2016 16:31:48 +0000 (16:31 +0000)
committerTom Stellard <thomas.stellard@amd.com>
Tue, 1 Nov 2016 16:31:48 +0000 (16:31 +0000)
I wanted to implement this as a target independent expansion, however when
targets say they want to expand FP_TO_FP16 what they actually want is
the unsafe math expansion when possible and expansion to a libcall in all
other cases.

The only way to make this work as a target independent would be to add logic
to target's TargetLowering construction to mark theses nodes as Expand when
LegalizeDAG can use the unsafe expansion and mark them as LibCall when it
cannot.  I think this would be possible, but I think it would be too fragile
and complex as it would require targets to keep their expansion logic up
to date with the code in LegalizeDAG.

Reviewers: bogner, ab, t.p.northover, arsenm

Subscribers: wdng, llvm-commits, nhaehnle

Differential Revision: https://reviews.llvm.org/D25999

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@285704 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/AMDGPU/AMDGPUISelLowering.cpp
lib/Target/AMDGPU/AMDGPUISelLowering.h
test/CodeGen/AMDGPU/fptrunc.ll
test/CodeGen/AMDGPU/trunc-store-f64-to-f16.ll

index a69d1afdea82190ad409b4ec00c8561d98632f39..153b95e7d8187d96876e8a5979c34a48ee28a767 100644 (file)
@@ -279,6 +279,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
   }
 
   setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
+  setOperationAction(ISD::FP_TO_FP16, MVT::f64, Custom);
 
   const MVT ScalarIntVTs[] = { MVT::i32, MVT::i64 };
   for (MVT VT : ScalarIntVTs) {
@@ -806,6 +807,7 @@ SDValue AMDGPUTargetLowering::LowerOperation(SDValue Op,
   case ISD::FFLOOR: return LowerFFLOOR(Op, DAG);
   case ISD::SINT_TO_FP: return LowerSINT_TO_FP(Op, DAG);
   case ISD::UINT_TO_FP: return LowerUINT_TO_FP(Op, DAG);
+  case ISD::FP_TO_FP16: return LowerFP_TO_FP16(Op, DAG);
   case ISD::FP_TO_SINT: return LowerFP_TO_SINT(Op, DAG);
   case ISD::FP_TO_UINT: return LowerFP_TO_UINT(Op, DAG);
   case ISD::CTLZ:
@@ -1959,6 +1961,102 @@ SDValue AMDGPUTargetLowering::LowerFP64_TO_INT(SDValue Op, SelectionDAG &DAG,
   return DAG.getNode(ISD::BITCAST, SL, MVT::i64, Result);
 }
 
+SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const {
+
+  if (getTargetMachine().Options.UnsafeFPMath) {
+    // There is a generic expand for FP_TO_FP16 with unsafe fast math.
+    return SDValue();
+  }
+
+  SDLoc DL(Op);
+  SDValue N0 = Op.getOperand(0);
+  MVT SVT = N0.getSimpleValueType();
+  assert(SVT == MVT::f64);
+
+  // f64 -> f16 conversion using round-to-nearest-even rounding mode.
+  const unsigned ExpMask = 0x7ff;
+  const unsigned ExpBiasf64 = 1023;
+  const unsigned ExpBiasf16 = 15;
+  SDValue Zero = DAG.getConstant(0, DL, MVT::i32);
+  SDValue One = DAG.getConstant(1, DL, MVT::i32);
+  SDValue U = DAG.getNode(ISD::BITCAST, DL, MVT::i64, N0);
+  SDValue UH = DAG.getNode(ISD::SRL, DL, MVT::i64, U,
+                           DAG.getConstant(32, DL, MVT::i64));
+  UH = DAG.getZExtOrTrunc(UH, DL, MVT::i32);
+  U = DAG.getZExtOrTrunc(U, DL, MVT::i32);
+  SDValue E = DAG.getNode(ISD::SRL, DL, MVT::i32, UH,
+                          DAG.getConstant(20, DL, MVT::i64));
+  E = DAG.getNode(ISD::AND, DL, MVT::i32, E,
+                  DAG.getConstant(ExpMask, DL, MVT::i32));
+  // Subtract the fp64 exponent bias (1023) to get the real exponent and
+  // add the f16 bias (15) to get the biased exponent for the f16 format.
+  E = DAG.getNode(ISD::ADD, DL, MVT::i32, E,
+                  DAG.getConstant(-ExpBiasf64 + ExpBiasf16, DL, MVT::i32));
+
+  SDValue M = DAG.getNode(ISD::SRL, DL, MVT::i32, UH,
+                          DAG.getConstant(8, DL, MVT::i32));
+  M = DAG.getNode(ISD::AND, DL, MVT::i32, M,
+                  DAG.getConstant(0xffe, DL, MVT::i32));
+
+  SDValue MaskedSig = DAG.getNode(ISD::AND, DL, MVT::i32, UH,
+                                  DAG.getConstant(0x1ff, DL, MVT::i32));
+  MaskedSig = DAG.getNode(ISD::OR, DL, MVT::i32, MaskedSig, U);
+
+  SDValue Lo40Set = DAG.getSelectCC(DL, MaskedSig, Zero, Zero, One, ISD::SETEQ);
+  M = DAG.getNode(ISD::OR, DL, MVT::i32, M, Lo40Set);
+
+  // (M != 0 ? 0x0200 : 0) | 0x7c00;
+  SDValue I = DAG.getNode(ISD::OR, DL, MVT::i32,
+      DAG.getSelectCC(DL, M, Zero, DAG.getConstant(0x0200, DL, MVT::i32),
+                      Zero, ISD::SETNE), DAG.getConstant(0x7c00, DL, MVT::i32));
+
+  // N = M | (E << 12);
+  SDValue N = DAG.getNode(ISD::OR, DL, MVT::i32, M,
+      DAG.getNode(ISD::SHL, DL, MVT::i32, E,
+                  DAG.getConstant(12, DL, MVT::i32)));
+
+  // B = clamp(1-E, 0, 13);
+  SDValue OneSubExp = DAG.getNode(ISD::SUB, DL, MVT::i32,
+                                  One, E);
+  SDValue B = DAG.getNode(ISD::SMAX, DL, MVT::i32, OneSubExp, Zero);
+  B = DAG.getNode(ISD::SMIN, DL, MVT::i32, B,
+                  DAG.getConstant(13, DL, MVT::i32));
+
+  SDValue SigSetHigh = DAG.getNode(ISD::OR, DL, MVT::i32, M,
+                                   DAG.getConstant(0x1000, DL, MVT::i32));
+
+  SDValue D = DAG.getNode(ISD::SRL, DL, MVT::i32, SigSetHigh, B);
+  SDValue D0 = DAG.getNode(ISD::SHL, DL, MVT::i32, D, B);
+  SDValue D1 = DAG.getSelectCC(DL, D0, SigSetHigh, One, Zero, ISD::SETNE);
+  D = DAG.getNode(ISD::OR, DL, MVT::i32, D, D1);
+
+  SDValue V = DAG.getSelectCC(DL, E, One, D, N, ISD::SETLT);
+  SDValue VLow3 = DAG.getNode(ISD::AND, DL, MVT::i32, V,
+                              DAG.getConstant(0x7, DL, MVT::i32));
+  V = DAG.getNode(ISD::SRL, DL, MVT::i32, V,
+                  DAG.getConstant(2, DL, MVT::i32));
+  SDValue V0 = DAG.getSelectCC(DL, VLow3, DAG.getConstant(3, DL, MVT::i32),
+                               One, Zero, ISD::SETEQ);
+  SDValue V1 = DAG.getSelectCC(DL, VLow3, DAG.getConstant(5, DL, MVT::i32),
+                               One, Zero, ISD::SETGT);
+  V1 = DAG.getNode(ISD::OR, DL, MVT::i32, V0, V1);
+  V = DAG.getNode(ISD::ADD, DL, MVT::i32, V, V1);
+
+  V = DAG.getSelectCC(DL, E, DAG.getConstant(30, DL, MVT::i32),
+                      DAG.getConstant(0x7c00, DL, MVT::i32), V, ISD::SETGT);
+  V = DAG.getSelectCC(DL, E, DAG.getConstant(1039, DL, MVT::i32),
+                      I, V, ISD::SETEQ);
+
+  // Extract the sign bit.
+  SDValue Sign = DAG.getNode(ISD::SRL, DL, MVT::i32, UH,
+                            DAG.getConstant(16, DL, MVT::i32));
+  Sign = DAG.getNode(ISD::AND, DL, MVT::i32, Sign,
+                     DAG.getConstant(0x8000, DL, MVT::i32));
+
+  V = DAG.getNode(ISD::OR, DL, MVT::i32, Sign, V);
+  return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
+}
+
 SDValue AMDGPUTargetLowering::LowerFP_TO_SINT(SDValue Op,
                                               SelectionDAG &DAG) const {
   SDValue Src = Op.getOperand(0);
index 4cc1a74d18b6bd70cc6d1c8ce1e4bd55764adb99..cddef58d14e81772070a891ac42813c464d3dce2 100644 (file)
@@ -53,6 +53,7 @@ protected:
   SDValue LowerSINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerFP64_TO_INT(SDValue Op, SelectionDAG &DAG, bool Signed) const;
+  SDValue LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_UINT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_SINT(SDValue Op, SelectionDAG &DAG) const;
 
index 385e10e7baae86e70d609667d955e062162b55c3..838e1b9e279ef1b404c556485f37d4686dea24bf 100644 (file)
@@ -1,17 +1,29 @@
-; RUN: llc -march=amdgcn -mcpu=SI -verify-machineinstrs < %s | FileCheck -check-prefix=SI -check-prefix=FUNC %s
-; RUN: llc -march=amdgcn -mcpu=tonga -verify-machineinstrs < %s | FileCheck -check-prefix=SI -check-prefix=FUNC %s
+; RUN: llc -march=amdgcn -mcpu=SI -verify-machineinstrs < %s | FileCheck -check-prefix=GCN %s
+; RUN: llc -march=amdgcn -mcpu=tonga -verify-machineinstrs < %s | FileCheck -check-prefix=GCN %s
+; RUN: llc -march=amdgcn -mcpu=tonga -enable-unsafe-fp-math -verify-machineinstrs < %s | FileCheck -check-prefixes=GCN,GCN-UNSAFE %s
 
 ; FUNC-LABEL: {{^}}fptrunc_f64_to_f32:
-; SI: v_cvt_f32_f64_e32 {{v[0-9]+}}, {{s\[[0-9]+:[0-9]+\]}}
+; GCN: v_cvt_f32_f64_e32 {{v[0-9]+}}, {{s\[[0-9]+:[0-9]+\]}}
 define void @fptrunc_f64_to_f32(float addrspace(1)* %out, double %in) {
   %result = fptrunc double %in to float
   store float %result, float addrspace(1)* %out
   ret void
 }
 
+; FUNC-LABEL: {{^}}fptrunc_f64_to_f16:
+; GCN-NOT: v_cvt
+; GCN-FAST: v_cvt_f32_f64_e32 [[F32:v[0-9]+]]
+; GCN-FAST: v_cvt_f16_f32_e32 v[0-9]+, [[F32]]
+define void @fptrunc_f64_to_f16(i16 addrspace(1)* %out, double %in) {
+  %result = fptrunc double %in to half
+  %result_i16 = bitcast half %result to i16
+  store i16 %result_i16, i16 addrspace(1)* %out
+  ret void
+}
+
 ; FUNC-LABEL: {{^}}fptrunc_v2f64_to_v2f32:
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
 define void @fptrunc_v2f64_to_v2f32(<2 x float> addrspace(1)* %out, <2 x double> %in) {
   %result = fptrunc <2 x double> %in to <2 x float>
   store <2 x float> %result, <2 x float> addrspace(1)* %out
@@ -19,10 +31,10 @@ define void @fptrunc_v2f64_to_v2f32(<2 x float> addrspace(1)* %out, <2 x double>
 }
 
 ; FUNC-LABEL: {{^}}fptrunc_v4f64_to_v4f32:
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
 define void @fptrunc_v4f64_to_v4f32(<4 x float> addrspace(1)* %out, <4 x double> %in) {
   %result = fptrunc <4 x double> %in to <4 x float>
   store <4 x float> %result, <4 x float> addrspace(1)* %out
@@ -30,14 +42,14 @@ define void @fptrunc_v4f64_to_v4f32(<4 x float> addrspace(1)* %out, <4 x double>
 }
 
 ; FUNC-LABEL: {{^}}fptrunc_v8f64_to_v8f32:
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
-; SI: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
+; GCN: v_cvt_f32_f64_e32
 define void @fptrunc_v8f64_to_v8f32(<8 x float> addrspace(1)* %out, <8 x double> %in) {
   %result = fptrunc <8 x double> %in to <8 x float>
   store <8 x float> %result, <8 x float> addrspace(1)* %out
index c29872beef861281dfef9f20e72ff619249b93a3..03b8af0610d7738bd932bc09923931fa3661ea93 100644 (file)
@@ -1,5 +1,4 @@
-; XFAIL: *
-; RUN: llc -march=amdgcn -mcpu=SI < %s
+; RUN: llc -march=amdgcn -verify-machineinstrs < %s | FileCheck -check-prefix=GCN %s
 
 ; GCN-LABEL: {{^}}global_truncstore_f64_to_f16:
 ; GCN: s_endpgm