From 347e991ccaf6cf892f1edb95a9339ac2c8a6103b Mon Sep 17 00:00:00 2001
From: Elena Demikhovsky
Date: Thu, 22 Jun 2017 06:47:41 +0000
Subject: [PATCH] AVX-512: Lowering Masked Gather intrinsic - fixed a bug

Masked gather for vector length 2 is lowered incorrectly for element type i32:
the type <2 x i32> was automatically extended to <2 x i64> and we generated
VPGATHERQQ instead of VPGATHERQD. The type <2 x float> is extended to
<4 x float>, so there is no bug for this type, but the generated sequence can
still be improved.

In this patch I'm fixing the <2 x i32> bug and optimizing the <2 x float>
sequence for GATHERs only. The same fix should be done for Scatters as well.

Differential revision: https://reviews.llvm.org/D34343

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@305987 91177308-0d34-0410-b5e6-96231b3b80d8
---
 include/llvm/CodeGen/SelectionDAGNodes.h  |  2 +-
 lib/Target/X86/X86ISelDAGToDAG.cpp        | 28 ++++++---
 lib/Target/X86/X86ISelLowering.cpp        | 52 +++++++++++++++++
 lib/Target/X86/X86ISelLowering.h          | 18 +++++-
 lib/Target/X86/X86InstrAVX512.td          |  2 +-
 lib/Target/X86/X86InstrFragmentsSIMD.td   | 12 ++++
 test/CodeGen/X86/masked_gather_scatter.ll | 69 ++++++++++++++++++++---
 7 files changed, 165 insertions(+), 18 deletions(-)

diff --git a/include/llvm/CodeGen/SelectionDAGNodes.h b/include/llvm/CodeGen/SelectionDAGNodes.h
index 0cd26d35a48..af418d3050e 100644
--- a/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -2107,7 +2107,7 @@ class MaskedGatherScatterSDNode : public MemSDNode {
 public:
   friend class SelectionDAG;
 
-  MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order,
+  MaskedGatherScatterSDNode(unsigned NodeTy, unsigned Order,
                             const DebugLoc &dl, SDVTList VTs, EVT MemVT,
                             MachineMemOperand *MMO)
       : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {}
diff --git a/lib/Target/X86/X86ISelDAGToDAG.cpp b/lib/Target/X86/X86ISelDAGToDAG.cpp
index 2a1633de0a2..3c4589ab18f 100644
--- a/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -204,6 +204,11 @@ namespace {
     bool selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
                           SDValue &Scale, SDValue &Index, SDValue &Disp,
                           SDValue &Segment);
+    template <class GatherScatterSDNode>
+    bool selectAddrOfGatherScatterNode(GatherScatterSDNode *Parent, SDValue N,
+                                       SDValue &Base, SDValue &Scale,
+                                       SDValue &Index, SDValue &Disp,
+                                       SDValue &Segment);
     bool selectMOV64Imm32(SDValue N, SDValue &Imm);
     bool selectLEAAddr(SDValue N, SDValue &Base, SDValue &Scale,
                        SDValue &Index, SDValue &Disp,
@@ -1415,13 +1420,10 @@ bool X86DAGToDAGISel::matchAddressBase(SDValue N, X86ISelAddressMode &AM) {
   return false;
 }
 
-bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
-                                       SDValue &Scale, SDValue &Index,
-                                       SDValue &Disp, SDValue &Segment) {
-
-  MaskedGatherScatterSDNode *Mgs = dyn_cast<MaskedGatherScatterSDNode>(Parent);
-  if (!Mgs)
-    return false;
+template <class GatherScatterSDNode>
+bool X86DAGToDAGISel::selectAddrOfGatherScatterNode(
+    GatherScatterSDNode *Mgs, SDValue N, SDValue &Base, SDValue &Scale,
+    SDValue &Index, SDValue &Disp, SDValue &Segment) {
   X86ISelAddressMode AM;
   unsigned AddrSpace = Mgs->getPointerInfo().getAddrSpace();
   // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.
@@ -1453,6 +1455,18 @@ bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
   return true;
 }
 
+bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
+                                       SDValue &Scale, SDValue &Index,
+                                       SDValue &Disp, SDValue &Segment) {
+  if (auto Mgs = dyn_cast<MaskedGatherScatterSDNode>(Parent))
+    return selectAddrOfGatherScatterNode<MaskedGatherScatterSDNode>(
+        Mgs, N, Base, Scale, Index, Disp, Segment);
+  if (auto X86Gather = dyn_cast<X86MaskedGatherSDNode>(Parent))
+    return selectAddrOfGatherScatterNode<X86MaskedGatherSDNode>(
+        X86Gather, N, Base, Scale, Index, Disp, Segment);
+  return false;
+}
+
 /// Returns true if it is able to pattern match an addressing mode.
 /// It returns the operands which make up the maximal addressing mode it can
 /// match by reference.
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 6d6c44bbbca..6487b4aa626 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -23708,6 +23708,57 @@ static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
     SDValue RetOps[] = {Exract, NewGather.getValue(1)};
     return DAG.getMergeValues(RetOps, dl);
   }
+  if (N->getMemoryVT() == MVT::v2i32 && Subtarget.hasVLX()) {
+    // There is a special case when the return type v2i32 is illegal and
+    // the type legalizer extended it to v2i64. Without this conversion we end
+    // up with VPGATHERQQ (reading q-words from memory) instead of VPGATHERQD.
+    // In order to avoid this situation, we'll build an X86-specific gather
+    // node with index v2i64 and value type v4i32.
+    assert(VT == MVT::v2i64 && Src0.getValueType() == MVT::v2i64 &&
+           "Unexpected type in masked gather");
+    Src0 = DAG.getVectorShuffle(MVT::v4i32, dl,
+                                DAG.getBitcast(MVT::v4i32, Src0),
+                                DAG.getUNDEF(MVT::v4i32), { 0, 2, -1, -1 });
+    // The mask should match the destination type. Extending the mask with
+    // zeroes is not necessary since the instruction itself reads only two
+    // values from memory.
+    Mask = ExtendToType(Mask, MVT::v4i1, DAG, false);
+    SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+    SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
+        DAG.getVTList(MVT::v4i32, MVT::Other), Ops, dl, N->getMemoryVT(),
+        N->getMemOperand());
+
+    SDValue Sext = getExtendInVec(X86ISD::VSEXT, dl, MVT::v2i64,
+                                  NewGather.getValue(0), DAG);
+    SDValue RetOps[] = { Sext, NewGather.getValue(1) };
+    return DAG.getMergeValues(RetOps, dl);
+  }
+  if (N->getMemoryVT() == MVT::v2f32 && Subtarget.hasVLX()) {
+    // This transformation is for optimization only.
+    // The type legalizer extended the mask and index to a 4-element vector
+    // in order to match the requirements of the common gather node - the
+    // same vector width for index and value. The X86 gather node allows the
+    // widths to mismatch so that a more optimal instruction is selected in
+    // the end.
+    assert(VT == MVT::v4f32 && Src0.getValueType() == MVT::v4f32 &&
+           "Unexpected type in masked gather");
+    if (Mask.getOpcode() == ISD::CONCAT_VECTORS &&
+        ISD::isBuildVectorAllZeros(Mask.getOperand(1).getNode()) &&
+        Index.getOpcode() == ISD::CONCAT_VECTORS &&
+        Index.getOperand(1).isUndef()) {
+      Mask = ExtendToType(Mask.getOperand(0), MVT::v4i1, DAG, false);
+      Index = Index.getOperand(0);
+    } else
+      return Op;
+    SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index };
+    SDValue NewGather = DAG.getTargetMemSDNode<X86MaskedGatherSDNode>(
+        DAG.getVTList(MVT::v4f32, MVT::Other), Ops, dl, N->getMemoryVT(),
+        N->getMemOperand());
+
+    SDValue RetOps[] = { NewGather.getValue(0), NewGather.getValue(1) };
+    return DAG.getMergeValues(RetOps, dl);
+
+  }
   return Op;
 }
 
@@ -24511,6 +24562,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   case X86ISD::CVTS2SI_RND:        return "X86ISD::CVTS2SI_RND";
   case X86ISD::CVTS2UI_RND:        return "X86ISD::CVTS2UI_RND";
   case X86ISD::LWPINS:             return "X86ISD::LWPINS";
+  case X86ISD::MGATHER:            return "X86ISD::MGATHER";
   }
   return nullptr;
 }
diff --git a/lib/Target/X86/X86ISelLowering.h b/lib/Target/X86/X86ISelLowering.h
index f51b6641db2..dd0f84cf7b6 100644
--- a/lib/Target/X86/X86ISelLowering.h
+++ b/lib/Target/X86/X86ISelLowering.h
@@ -615,7 +615,10 @@ namespace llvm {
       // Vector truncating store with unsigned/signed saturation
       VTRUNCSTOREUS, VTRUNCSTORES,
       // Vector truncating masked store with unsigned/signed saturation
-      VMTRUNCSTOREUS, VMTRUNCSTORES
+      VMTRUNCSTOREUS, VMTRUNCSTORES,
+
+      // X86 specific gather
+      MGATHER
 
       // WARNING: Do not add anything in the end unless you want the node to
       // have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
@@ -1397,6 +1400,19 @@ namespace llvm {
     }
   };
 
+  // X86 specific Gather node.
+  class X86MaskedGatherSDNode : public MaskedGatherScatterSDNode {
+  public:
+    X86MaskedGatherSDNode(unsigned Order,
+                          const DebugLoc &dl, SDVTList VTs, EVT MemVT,
+                          MachineMemOperand *MMO)
+      : MaskedGatherScatterSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT, MMO)
+    {}
+    static bool classof(const SDNode *N) {
+      return N->getOpcode() == X86ISD::MGATHER;
+    }
+  };
+
 } // end namespace llvm
 
 #endif // LLVM_LIB_TARGET_X86_X86ISELLOWERING_H
diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td
index 2620679df25..d46262573f7 100644
--- a/lib/Target/X86/X86InstrAVX512.td
+++ b/lib/Target/X86/X86InstrAVX512.td
@@ -7869,7 +7869,7 @@ let Predicates = [HasVLX] in {
   defm NAME##D##SUFF##Z128: avx512_gather<dopc, OpcodeStr##"d", _.info128,
                                           vx128xmem, mgatherv4i32>, EVEX_V128;
   defm NAME##Q##SUFF##Z128: avx512_gather<qopc, OpcodeStr##"q", _.info128,
-                                          vx64xmem, mgatherv2i64>, EVEX_V128;
+                                          vx64xmem, X86mgatherv2i64>, EVEX_V128;
 }
 }
 
diff --git a/lib/Target/X86/X86InstrFragmentsSIMD.td b/lib/Target/X86/X86InstrFragmentsSIMD.td
index c28b35b2297..1cd6320a55e 100644
--- a/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -773,6 +773,9 @@ def memop64 : PatFrag<(ops node:$ptr), (load node:$ptr), [{
 
 def memopmmx : PatFrag<(ops node:$ptr), (x86mmx (memop64 node:$ptr))>;
 
+def X86masked_gather : SDNode<"X86ISD::MGATHER", SDTMaskedGather,
+                              [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
 def mgatherv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (masked_gather node:$src1, node:$src2, node:$src3) , [{
   if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
@@ -796,6 +799,15 @@ def mgatherv2i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
             Mgt->getBasePtr().getValueType() == MVT::v2i64);
   return false;
 }]>;
+def X86mgatherv2i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
+  (X86masked_gather node:$src1, node:$src2, node:$src3) , [{
+  if (X86MaskedGatherSDNode *Mgt = dyn_cast<X86MaskedGatherSDNode>(N))
+    return (Mgt->getIndex().getValueType() == MVT::v2i64 ||
+            Mgt->getBasePtr().getValueType() == MVT::v2i64) &&
+           (Mgt->getMemoryVT() == MVT::v2i32 ||
+            Mgt->getMemoryVT() == MVT::v2f32);
+  return false;
+}]>;
 def mgatherv4i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (masked_gather node:$src1, node:$src2, node:$src3) , [{
   if (MaskedGatherSDNode *Mgt = dyn_cast<MaskedGatherSDNode>(N))
diff --git a/test/CodeGen/X86/masked_gather_scatter.ll b/test/CodeGen/X86/masked_gather_scatter.ll
index 91087f650ad..77254ba6760 100644
--- a/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/test/CodeGen/X86/masked_gather_scatter.ll
@@ -1226,6 +1226,57 @@ define <2 x float> @test22(float* %base, <2 x i32> %ind, <2 x i1> %mask, <2 x fl
   ret <2 x float>%res
 }
 
+define <2 x float> @test22a(float* %base, <2 x i64> %ind, <2 x i1> %mask, <2 x float> %src0) {
+; KNL_64-LABEL: test22a:
+; KNL_64:       # BB#0:
+; KNL_64-NEXT:    # kill: %XMM2<def> %XMM2<kill> %YMM2<def>
+; KNL_64-NEXT:    # kill: %XMM0<def> %XMM0<kill> %ZMM0<def>
+; KNL_64-NEXT:    vinsertps {{.*#+}} xmm1 = xmm1[0,2],zero,zero
+; KNL_64-NEXT:    vpxor %ymm3, %ymm3, %ymm3
+; KNL_64-NEXT:    vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm3[4,5,6,7]
+; KNL_64-NEXT:    vpslld $31, %ymm1, %ymm1
+; KNL_64-NEXT:    vptestmd %zmm1, %zmm1, %k1
+; KNL_64-NEXT:    vgatherqps (%rdi,%zmm0,4), %ymm2 {%k1}
+; KNL_64-NEXT:    vmovaps %xmm2, %xmm0
+; KNL_64-NEXT:    vzeroupper
+; KNL_64-NEXT:    retq
+;
+; KNL_32-LABEL: test22a:
+; KNL_32:       # BB#0:
+; KNL_32-NEXT:    # kill: %XMM2<def> %XMM2<kill> %YMM2<def>
+; KNL_32-NEXT:    # kill: %XMM0<def> %XMM0<kill> %ZMM0<def>
+; KNL_32-NEXT:    vinsertps {{.*#+}} xmm1 = xmm1[0,2],zero,zero
+; KNL_32-NEXT:    vpxor %ymm3, %ymm3, %ymm3
+; KNL_32-NEXT:    vpblendd {{.*#+}} ymm1 = ymm1[0,1,2,3],ymm3[4,5,6,7]
+; KNL_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; KNL_32-NEXT:    vpslld $31, %ymm1, %ymm1
+; KNL_32-NEXT:    vptestmd %zmm1, %zmm1, %k1
+; KNL_32-NEXT:    vgatherqps (%eax,%zmm0,4), %ymm2 {%k1}
+; KNL_32-NEXT:    vmovaps %xmm2, %xmm0
+; KNL_32-NEXT:    vzeroupper
+; KNL_32-NEXT:    retl
+;
+; SKX-LABEL: test22a:
+; SKX:       # BB#0:
+; SKX-NEXT:    vpsllq $63, %xmm1, %xmm1
+; SKX-NEXT:    vptestmq %xmm1, %xmm1, %k1
+; SKX-NEXT:    vgatherqps (%rdi,%xmm0,4), %xmm2 {%k1}
+; SKX-NEXT:    vmovaps %xmm2, %xmm0
+; SKX-NEXT:    retq
+;
+; SKX_32-LABEL: test22a:
+; SKX_32:       # BB#0:
+; SKX_32-NEXT:    vpsllq $63, %xmm1, %xmm1
+; SKX_32-NEXT:    vptestmq %xmm1, %xmm1, %k1
+; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
+; SKX_32-NEXT:    vgatherqps (%eax,%xmm0,4), %xmm2 {%k1}
+; SKX_32-NEXT:    vmovaps %xmm2, %xmm0
+; SKX_32-NEXT:    retl
+  %gep.random = getelementptr float, float* %base, <2 x i64> %ind
+  %res = call <2 x float> @llvm.masked.gather.v2f32.v2p0f32(<2 x float*> %gep.random, i32 4, <2 x i1> %mask, <2 x float> %src0)
+  ret <2 x float>%res
+}
+
 declare <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*>, i32, <2 x i1>, <2 x i32>)
 declare <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*>, i32, <2 x i1>, <2 x i64>)
 
@@ -1262,8 +1313,9 @@ define <2 x i32> @test23(i32* %base, <2 x i32> %ind, <2 x i1> %mask, <2 x i32> %
 ; SKX:       # BB#0:
 ; SKX-NEXT:    vpsllq $63, %xmm1, %xmm1
 ; SKX-NEXT:    vptestmq %xmm1, %xmm1, %k1
-; SKX-NEXT:    vpgatherqq (%rdi,%xmm0,8), %xmm2 {%k1}
-; SKX-NEXT:    vmovdqa %xmm2, %xmm0
+; SKX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm2[0,2,2,3]
+; SKX-NEXT:    vpgatherqd (%rdi,%xmm0,4), %xmm1 {%k1}
+; SKX-NEXT:    vpmovsxdq %xmm1, %xmm0
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: test23:
@@ -1271,8 +1323,9 @@ define <2 x i32> @test23(i32* %base, <2 x i32> %ind, <2 x i1> %mask, <2 x i32> %
 ; SKX_32-NEXT:    vpsllq $63, %xmm1, %xmm1
 ; SKX_32-NEXT:    vptestmq %xmm1, %xmm1, %k1
 ; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; SKX_32-NEXT:    vpgatherqq (%eax,%xmm0,8), %xmm2 {%k1}
-; SKX_32-NEXT:    vmovdqa %xmm2, %xmm0
+; SKX_32-NEXT:    vpshufd {{.*#+}} xmm1 = xmm2[0,2,2,3]
+; SKX_32-NEXT:    vpgatherqd (%eax,%xmm0,4), %xmm1 {%k1}
+; SKX_32-NEXT:    vpmovsxdq %xmm1, %xmm0
 ; SKX_32-NEXT:    retl
   %sext_ind = sext <2 x i32> %ind to <2 x i64>
   %gep.random = getelementptr i32, i32* %base, <2 x i64> %sext_ind
@@ -1307,16 +1360,16 @@ define <2 x i32> @test24(i32* %base, <2 x i32> %ind) {
 ; SKX-LABEL: test24:
 ; SKX:       # BB#0:
 ; SKX-NEXT:    kxnorw %k0, %k0, %k1
-; SKX-NEXT:    vpgatherqq (%rdi,%xmm0,8), %xmm1 {%k1}
-; SKX-NEXT:    vmovdqa %xmm1, %xmm0
+; SKX-NEXT:    vpgatherqd (%rdi,%xmm0,4), %xmm1 {%k1}
+; SKX-NEXT:    vpmovsxdq %xmm1, %xmm0
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: test24:
 ; SKX_32:       # BB#0:
 ; SKX_32-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; SKX_32-NEXT:    kxnorw %k0, %k0, %k1
-; SKX_32-NEXT:    vpgatherqq (%eax,%xmm0,8), %xmm1 {%k1}
-; SKX_32-NEXT:    vmovdqa %xmm1, %xmm0
+; SKX_32-NEXT:    vpgatherqd (%eax,%xmm0,4), %xmm1 {%k1}
+; SKX_32-NEXT:    vpmovsxdq %xmm1, %xmm0
 ; SKX_32-NEXT:    retl
   %sext_ind = sext <2 x i32> %ind to <2 x i64>
   %gep.random = getelementptr i32, i32* %base, <2 x i64> %sext_ind
-- 
2.50.1
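
Note on the instruction choice discussed in the commit message: the following
standalone C++ sketch is not part of the patch and does not reflect how LLVM
emits these instructions; it only models, under made-up data and hypothetical
helper names, what the two gather forms read per lane. VPGATHERQD loads a
32-bit element for each 64-bit index, while VPGATHERQQ loads a 64-bit element,
which is the wrong width for <2 x i32> data.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reference model of VPGATHERQD: 64-bit indices, 32-bit elements.
static void gather_qd(const uint8_t *Base, const int64_t Index[2],
                      const bool Mask[2], uint32_t Dst[2], int Scale) {
  for (int I = 0; I < 2; ++I)
    if (Mask[I]) {
      uint32_t V;
      std::memcpy(&V, Base + Index[I] * Scale, sizeof(V)); // 4-byte load
      Dst[I] = V;
    }
}

// Reference model of VPGATHERQQ: 64-bit indices, 64-bit elements. Using this
// form for i32 data reads 8 bytes per lane instead of 4, which is the width
// mismatch the patch avoids.
static void gather_qq(const uint8_t *Base, const int64_t Index[2],
                      const bool Mask[2], uint64_t Dst[2], int Scale) {
  for (int I = 0; I < 2; ++I)
    if (Mask[I]) {
      uint64_t V;
      std::memcpy(&V, Base + Index[I] * Scale, sizeof(V)); // 8-byte load
      Dst[I] = V;
    }
}

int main() {
  // Six dwords so the 8-byte reads below stay in bounds for this demo; with a
  // buffer of exactly four dwords the qq form would read past the end.
  uint32_t Data[6] = {10, 20, 30, 40, 50, 60};
  int64_t Index[2] = {1, 3};
  bool Mask[2] = {true, true};

  uint32_t D32[2] = {0, 0};
  gather_qd(reinterpret_cast<const uint8_t *>(Data), Index, Mask, D32, 4);
  std::printf("qd: %u %u\n", D32[0], D32[1]); // 20 40 - the intended dwords

  uint64_t D64[2] = {0, 0};
  gather_qq(reinterpret_cast<const uint8_t *>(Data), Index, Mask, D64, 4);
  // Each lane now holds two adjacent dwords fused into one qword.
  std::printf("qq: %llu %llu\n", static_cast<unsigned long long>(D64[0]),
              static_cast<unsigned long long>(D64[1]));
  return 0;
}

In the updated test23/test24 checks above, the same change shows up as both the
instruction switching from vpgatherqq to vpgatherqd and the address scale
changing from 8 to 4.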