From 2fd63302fec84fcc7e02aeff2ea0959ea25a9775 Mon Sep 17 00:00:00 2001
From: Elena Demikhovsky
Date: Sat, 29 Oct 2016 08:44:46 +0000
Subject: [PATCH] Fixed FMA + FNEG combine.

The masked form of FMA should be skipped by this optimization.

Differential Revision: https://reviews.llvm.org/D25984

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@285492 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Target/X86/X86ISelLowering.cpp   |  16 +++-
 test/CodeGen/X86/fma-fneg-combine.ll | 106 +++++++++++++++++++++++++++
 2 files changed, 119 insertions(+), 3 deletions(-)

diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index c48158072e0..4ab3bd97377 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -31277,6 +31277,15 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
   SDValue B = N->getOperand(1);
   SDValue C = N->getOperand(2);
 
+  auto isScalarMaskedNode = [&](SDValue &V) {
+    if (V.hasOneUse())
+      return false;
+    for (auto User : V.getNode()->uses())
+      if (User->getOpcode() == X86ISD::SELECTS && N->isOperandOf(User))
+        return true;
+    return false;
+  };
+
   auto invertIfNegative = [](SDValue &V) {
     if (SDValue NegVal = isFNEG(V.getNode())) {
       V = NegVal;
@@ -31285,9 +31294,10 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
     return false;
   };
 
-  bool NegA = invertIfNegative(A);
-  bool NegB = invertIfNegative(B);
-  bool NegC = invertIfNegative(C);
+  // Do not convert scalar masked operations.
+  bool NegA = !isScalarMaskedNode(A) && invertIfNegative(A);
+  bool NegB = !isScalarMaskedNode(B) && invertIfNegative(B);
+  bool NegC = !isScalarMaskedNode(C) && invertIfNegative(C);
 
   // Negative multiplication when NegA xor NegB
   bool NegMul = (NegA != NegB);
diff --git a/test/CodeGen/X86/fma-fneg-combine.ll b/test/CodeGen/X86/fma-fneg-combine.ll
index edcd780a486..76d8cb5a644 100644
--- a/test/CodeGen/X86/fma-fneg-combine.ll
+++ b/test/CodeGen/X86/fma-fneg-combine.ll
@@ -137,3 +137,109 @@
 declare <2 x double> @llvm.x86.avx512.mask.vfmadd.sd(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8, i32)
 
+define <4 x float> @test11(<4 x float> %a, <4 x float> %b, <4 x float> %c, i8 zeroext %mask) local_unnamed_addr #0 {
+; SKX-LABEL: test11:
+; SKX: # BB#0: # %entry
+; SKX-NEXT: vxorps {{.*}}(%rip){1to4}, %xmm2, %xmm0
+; SKX-NEXT: andl $1, %edi
+; SKX-NEXT: kmovw %edi, %k1
+; SKX-NEXT: vfmadd231ss %xmm1, %xmm1, %xmm0 {%k1}
+; SKX-NEXT: retq
+;
+; KNL-LABEL: test11:
+; KNL: # BB#0: # %entry
+; KNL-NEXT: vbroadcastss {{.*}}(%rip), %xmm0
+; KNL-NEXT: vxorps %xmm0, %xmm2, %xmm0
+; KNL-NEXT: andl $1, %edi
+; KNL-NEXT: kmovw %edi, %k1
+; KNL-NEXT: vfmadd231ss %xmm1, %xmm1, %xmm0 {%k1}
+; KNL-NEXT: retq
+entry:
+  %sub.i = fsub <4 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
+  %0 = tail call <4 x float> @llvm.x86.avx512.mask3.vfmadd.ss(<4 x float> %b, <4 x float> %b, <4 x float> %sub.i, i8 %mask, i32 4) #10
+  ret <4 x float> %0
+}
+
+declare <4 x float> @llvm.x86.avx512.mask3.vfmadd.ss(<4 x float>, <4 x float>, <4 x float>, i8, i32)
+
+define <8 x double> @test12(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8 %mask) {
+; SKX-LABEL: test12:
+; SKX: # BB#0: # %entry
+; SKX-NEXT: kmovb %edi, %k1
+; SKX-NEXT: vfmadd132pd %zmm1, %zmm2, %zmm0 {%k1}
+; SKX-NEXT: vxorpd {{.*}}(%rip){1to8}, %zmm0, %zmm0
+; SKX-NEXT: retq
+;
+; KNL-LABEL: test12:
+; KNL: # BB#0: # %entry
+; KNL-NEXT: kmovw %edi, %k1
+; KNL-NEXT: vfmadd132pd %zmm1, %zmm2, %zmm0 {%k1}
+; KNL-NEXT: vpxorq {{.*}}(%rip){1to8}, %zmm0, %zmm0
+; KNL-NEXT: retq
+entry:
+  %0 = tail call <8 x double> @llvm.x86.avx512.mask.vfmadd.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8 %mask, i32 4) #2
+  %sub.i = fsub <8 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %0
+  ret <8 x double> %sub.i
+}
+
+define <2 x double> @test13(<2 x double> %a, <2 x double> %b, <2 x double> %c, i8 %mask) {
+; CHECK-LABEL: test13:
+; CHECK: # BB#0: # %entry
+; CHECK-NEXT: vxorps {{.*}}(%rip), %xmm0, %xmm0
+; CHECK-NEXT: andl $1, %edi
+; CHECK-NEXT: kmovw %edi, %k1
+; CHECK-NEXT: vfmadd132sd %xmm1, %xmm2, %xmm0 {%k1}
+; CHECK-NEXT: retq
+entry:
+  %sub.i = fsub <2 x double> <double -0.000000e+00, double -0.000000e+00>, %a
+  %0 = tail call <2 x double> @llvm.x86.avx512.mask.vfmadd.sd(<2 x double> %sub.i, <2 x double> %b, <2 x double> %c, i8 %mask, i32 4)
+  ret <2 x double> %0
+}
+
+define <16 x float> @test14(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 %mask) {
+; SKX-LABEL: test14:
+; SKX: # BB#0: # %entry
+; SKX-NEXT: kmovw %edi, %k1
+; SKX-NEXT: vfnmsub132ps {ru-sae}, %zmm1, %zmm2, %zmm0 {%k1}
+; SKX-NEXT: vxorps {{.*}}(%rip){1to16}, %zmm0, %zmm0
+; SKX-NEXT: retq
+;
+; KNL-LABEL: test14:
+; KNL: # BB#0: # %entry
+; KNL-NEXT: kmovw %edi, %k1
+; KNL-NEXT: vfnmsub132ps {ru-sae}, %zmm1, %zmm2, %zmm0 {%k1}
+; KNL-NEXT: vpxord {{.*}}(%rip){1to16}, %zmm0, %zmm0
+; KNL-NEXT: retq
+entry:
+  %0 = tail call <16 x float> @llvm.x86.avx512.mask.vfnmsub.ps.512(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 %mask, i32 2) #2
+  %sub.i = fsub <16 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %0
+  ret <16 x float> %sub.i
+}
+
+define <16 x float> @test15(<16 x float> %a, <16 x float> %b, <16 x float> %c, i16 %mask) {
+; SKX-LABEL: test15:
+; SKX: # BB#0: # %entry
+; SKX-NEXT: kmovw %edi, %k1
+; SKX-NEXT: vxorps {{.*}}(%rip){1to16}, %zmm0, %zmm3
+; SKX-NEXT: vfnmadd213ps {ru-sae}, %zmm2, %zmm0, %zmm1
+; SKX-NEXT: vblendmps %zmm1, %zmm3, %zmm1 {%k1}
+; SKX-NEXT: vfnmadd132ps {rd-sae}, %zmm0, %zmm2, %zmm1 {%k1}
+; SKX-NEXT: vmovaps %zmm1, %zmm0
+; SKX-NEXT: retq
+;
+; KNL-LABEL: test15:
+; KNL: # BB#0: # %entry
+; KNL-NEXT: kmovw %edi, %k1
+; KNL-NEXT: vpxord {{.*}}(%rip){1to16}, %zmm0, %zmm3
+; KNL-NEXT: vfnmadd213ps {ru-sae}, %zmm2, %zmm0, %zmm1
+; KNL-NEXT: vblendmps %zmm1, %zmm3, %zmm1 {%k1}
+; KNL-NEXT: vfnmadd132ps {rd-sae}, %zmm0, %zmm2, %zmm1 {%k1}
+; KNL-NEXT: vmovaps %zmm1, %zmm0
+; KNL-NEXT: retq
+entry:
+  %sub.i = fsub <16 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %a
+  %0 = tail call <16 x float> @llvm.x86.avx512.mask.vfmadd.ps.512(<16 x float> %sub.i, <16 x float> %b, <16 x float> %c, i16 %mask, i32 2)
+  %1 = tail call <16 x float> @llvm.x86.avx512.mask.vfmadd.ps.512(<16 x float> %0, <16 x float> %sub.i, <16 x float> %c, i16 %mask, i32 1)
+  ret <16 x float> %1
+}
-- 
2.40.0
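
For readers tracing the DAG-level logic, the sketch below is a self-contained toy model of the guard added in combineFMA above. It is not LLVM code: the Node struct, the makeNode helper, the opcode names, and main() are invented for illustration. It rebuilds the shape of test11, where the FNEG-ed operand also feeds the X86ISD::SELECTS that masks the FMA result (it is the select's passthru), and shows that the guard blocks the fold there while still allowing it for a plain, unmasked FMA + FNEG.

#include <algorithm>
#include <cstdio>
#include <memory>
#include <vector>

enum Opcode { FNEG, FMA, SELECTS, OTHER };

struct Node {
  Opcode Op;
  std::vector<Node *> Operands;
  std::vector<Node *> Users; // reverse edges, like SDNode::uses()
};

static std::vector<std::unique_ptr<Node>> Arena; // owns all nodes

static Node *makeNode(Opcode Op, std::vector<Node *> Ops) {
  Arena.emplace_back(new Node{Op, std::move(Ops), {}});
  Node *N = Arena.back().get();
  for (Node *O : N->Operands)
    O->Users.push_back(N);
  return N;
}

// Mirrors the isScalarMaskedNode lambda from the patch: operand V of the
// FMA node N is "scalar masked" if, besides feeding N, it also feeds a
// SELECTS that consumes N's result.
static bool isScalarMaskedNode(Node *V, Node *N) {
  if (V->Users.size() == 1) // V.hasOneUse()
    return false;
  for (Node *User : V->Users)
    if (User->Op == SELECTS && // N->isOperandOf(User)
        std::find(User->Operands.begin(), User->Operands.end(), N) !=
            User->Operands.end())
      return true;
  return false;
}

int main() {
  // Masked shape (test11): %sub.i = fneg(%c) feeds the FMA *and* is the
  // passthru of the SELECTS that masks the FMA's result.
  Node *C = makeNode(OTHER, {});
  Node *NegC = makeNode(FNEG, {C});
  Node *B = makeNode(OTHER, {});
  Node *Fma = makeNode(FMA, {B, B, NegC});
  Node *Mask = makeNode(OTHER, {});
  makeNode(SELECTS, {Mask, Fma, NegC});
  std::printf("masked:   fold blocked = %d\n", isScalarMaskedNode(NegC, Fma));

  // Unmasked shape: the FNEG's only user is the FMA, so turning
  // FMA(b, b, FNEG(c)) into FMSUB(b, b, c) remains legal.
  Node *C2 = makeNode(OTHER, {});
  Node *NegC2 = makeNode(FNEG, {C2});
  Node *Fma2 = makeNode(FMA, {B, B, NegC2});
  std::printf("unmasked: fold blocked = %d\n", isScalarMaskedNode(NegC2, Fma2));
  return 0;
}

Built with any C++14 compiler, this prints "fold blocked = 1" for the masked shape and "= 0" for the unmasked one. Roughly, the fold is unsafe in the masked case because the scalar masked instruction ties its destination to the operand that supplies the masked-off and upper lanes; stripping the FNEG from that operand would change those lanes, which is what test11 through test15 pin down.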