From 23bd7de3923895f7cbae7daf9c76e1d3867ad68a Mon Sep 17 00:00:00 2001
From: Elena Demikhovsky <elena.demikhovsky@intel.com>
Date: Sun, 9 Oct 2016 10:48:52 +0000
Subject: [PATCH] DAG: Setting Masked-Expand-Load as a variant of Masked-Load node

A masked-expand-load node represents a load operation that loads a variable
number of elements from memory, according to the number of "true" bits in the
mask, and expands the loaded elements into the result vector according to
their positions in the mask vector.

Right now, the node is used by the intrinsics for the VEXPAND* instructions.
This is a step towards implementing the masked.expandload and
masked.compressstore intrinsics.

Differential Revision: https://reviews.llvm.org/D25322

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@283694 91177308-0d34-0410-b5e6-96231b3b80d8
---
 include/llvm/CodeGen/SelectionDAG.h        |  7 ++--
 include/llvm/CodeGen/SelectionDAGNodes.h   |  9 ++++--
 lib/CodeGen/SelectionDAG/SelectionDAG.cpp  | 12 +++----
 .../SelectionDAG/SelectionDAGBuilder.cpp   |  2 +-
 lib/Target/X86/X86ISelLowering.cpp         | 23 ++++++-------
 lib/Target/X86/X86InstrAVX512.td           | 21 ++++++++++--
 lib/Target/X86/X86InstrFragmentsSIMD.td    | 32 +++++++++++--------
 lib/Target/X86/X86InstrSSE.td              |  6 ++--
 test/CodeGen/X86/avx512vl-intrinsics.ll    | 31 +++++++++++++++---
 9 files changed, 97 insertions(+), 46 deletions(-)

diff --git a/include/llvm/CodeGen/SelectionDAG.h b/include/llvm/CodeGen/SelectionDAG.h
index d934ddb8366..9f45cc82089 100644
--- a/include/llvm/CodeGen/SelectionDAG.h
+++ b/include/llvm/CodeGen/SelectionDAG.h
@@ -965,11 +965,12 @@ public:
   SDValue getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr,
                         SDValue Mask, SDValue Src0, EVT MemVT,
-                        MachineMemOperand *MMO, ISD::LoadExtType);
+                        MachineMemOperand *MMO, ISD::LoadExtType,
+                        bool IsExpanding = false);
   SDValue getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                          SDValue Ptr, SDValue Mask, EVT MemVT,
-                         MachineMemOperand *MMO, bool IsTrunc,
-                         bool isCompressing = false);
+                         MachineMemOperand *MMO, bool IsTruncating = false,
+                         bool IsCompressing = false);
   SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
                           ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
   SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
diff --git a/include/llvm/CodeGen/SelectionDAGNodes.h b/include/llvm/CodeGen/SelectionDAGNodes.h
index fb5a01fade8..1d14d1228ce 100644
--- a/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -444,6 +444,7 @@ protected:
     uint16_t : NumLSBaseSDNodeBits;
 
     uint16_t ExtTy : 2; // enum ISD::LoadExtType
+    uint16_t IsExpanding : 1;
   };
 
   class StoreSDNodeBitfields {
@@ -473,7 +474,7 @@ protected:
   static_assert(sizeof(ConstantSDNodeBitfields) <= 2, "field too wide");
   static_assert(sizeof(MemSDNodeBitfields) <= 2, "field too wide");
   static_assert(sizeof(LSBaseSDNodeBitfields) <= 2, "field too wide");
-  static_assert(sizeof(LoadSDNodeBitfields) <= 2, "field too wide");
+  static_assert(sizeof(LoadSDNodeBitfields) <= 4, "field too wide");
   static_assert(sizeof(StoreSDNodeBitfields) <= 2, "field too wide");
 
 private:
@@ -1939,9 +1940,11 @@ class MaskedLoadSDNode : public MaskedLoadStoreSDNode {
 public:
   friend class SelectionDAG;
   MaskedLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
-                   ISD::LoadExtType ETy, EVT MemVT, MachineMemOperand *MMO)
+                   ISD::LoadExtType ETy, bool IsExpanding, EVT MemVT,
+                   MachineMemOperand *MMO)
       : MaskedLoadStoreSDNode(ISD::MLOAD, Order, dl, VTs, MemVT, MMO) {
     LoadSDNodeBits.ExtTy = ETy;
+    LoadSDNodeBits.IsExpanding = IsExpanding;
   }
 
   ISD::LoadExtType getExtensionType() const {
@@ -1952,6 +1955,8 @@ public:
   static bool classof(const SDNode *N) {
     return N->getOpcode() == ISD::MLOAD;
   }
+
+  bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; }
 };
 
 /// This class is used to represent an MSTORE node
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index a915fe161a3..3671422be2c 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5347,7 +5347,7 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
 SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
                                     SDValue Ptr, SDValue Mask, SDValue Src0,
                                     EVT MemVT, MachineMemOperand *MMO,
-                                    ISD::LoadExtType ExtTy) {
+                                    ISD::LoadExtType ExtTy, bool isExpanding) {
 
   SDVTList VTs = getVTList(VT, MVT::Other);
   SDValue Ops[] = { Chain, Ptr, Mask, Src0 };
@@ -5355,7 +5355,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
   AddNodeIDNode(ID, ISD::MLOAD, VTs, Ops);
   ID.AddInteger(VT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedLoadSDNode>(
-      dl.getIROrder(), VTs, ExtTy, MemVT, MMO));
+      dl.getIROrder(), VTs, ExtTy, isExpanding, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@@ -5363,7 +5363,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
     return SDValue(E, 0);
   }
   auto *N = newSDNode<MaskedLoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
-                                        ExtTy, MemVT, MMO);
+                                        ExtTy, isExpanding, MemVT, MMO);
   createOperands(N, Ops);
 
   CSEMap.InsertNode(N, IP);
@@ -5374,7 +5374,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
 SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
                                      SDValue Val, SDValue Ptr, SDValue Mask,
                                      EVT MemVT, MachineMemOperand *MMO,
-                                     bool isTrunc, bool isCompress) {
+                                     bool IsTruncating, bool IsCompressing) {
   assert(Chain.getValueType() == MVT::Other &&
         "Invalid chain type");
   EVT VT = Val.getValueType();
@@ -5384,7 +5384,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
   AddNodeIDNode(ID, ISD::MSTORE, VTs, Ops);
   ID.AddInteger(VT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedStoreSDNode>(
-      dl.getIROrder(), VTs, isTrunc, isCompress, MemVT, MMO));
+      dl.getIROrder(), VTs, IsTruncating, IsCompressing, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@@ -5392,7 +5392,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
     return SDValue(E, 0);
   }
   auto *N = newSDNode<MaskedStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
-                                         isTrunc, isCompress, MemVT, MMO);
+                                         IsTruncating, IsCompressing, MemVT, MMO);
   createOperands(N, Ops);
 
   CSEMap.InsertNode(N, IP);
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 4ae94e0befa..2aaab4b0d87 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3821,7 +3821,7 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
                           Alignment, AAInfo, Ranges);
 
   SDValue Load = DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Mask, Src0, VT, MMO,
-                                   ISD::NON_EXTLOAD);
+                                   ISD::NON_EXTLOAD, false);
   if (AddToChain) {
     SDValue OutChain = Load.getValue(1);
     DAG.setRoot(OutChain);
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 6fbd9dcfd32..32c4ffe585a 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -18854,7 +18854,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
     SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
 
     return DAG.getMaskedStore(Chain, dl, DataToCompress, Addr, VMask, VT,
-                              MemIntr->getMemOperand(), false, true);
+                              MemIntr->getMemOperand(),
+                              false /* truncating */, true /* compressing */);
   }
   case TRUNCATE_TO_MEM_VI8:
   case TRUNCATE_TO_MEM_VI16:
@@ -18877,7 +18878,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
     SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
 
     return DAG.getMaskedStore(Chain, dl, DataToTruncate, Addr, VMask, VT,
-                              MemIntr->getMemOperand(), true);
+                              MemIntr->getMemOperand(), true /* truncating */);
   }
   case EXPAND_FROM_MEM: {
     SDValue Mask = Op.getOperand(4);
@@ -18889,16 +18890,16 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
     MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op);
     assert(MemIntr && "Expected MemIntrinsicSDNode!");
 
-    SDValue DataToExpand = DAG.getLoad(VT, dl, Chain, Addr,
-                                       MemIntr->getMemOperand());
+    if (isAllOnesConstant(Mask)) // Return a regular (unmasked) vector load.
+      return DAG.getLoad(VT, dl, Chain, Addr, MemIntr->getMemOperand());
+    if (X86::isZeroNode(Mask))
+      return DAG.getUNDEF(VT);
 
-    if (isAllOnesConstant(Mask)) // return just a load
-      return DataToExpand;
-
-    SDValue Results[] = {
-      getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, DataToExpand),
-                           Mask, PassThru, Subtarget, DAG), Chain};
-    return DAG.getMergeValues(Results, dl);
+    MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
+    SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
+    return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT,
+                             MemIntr->getMemOperand(), ISD::NON_EXTLOAD,
+                             true /* expanding */);
   }
   }
 }
diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td
index d0ef44c92ba..ea1c5b09bc8 100644
--- a/lib/Target/X86/X86InstrAVX512.td
+++ b/lib/Target/X86/X86InstrAVX512.td
@@ -7552,13 +7552,28 @@ multiclass expand_by_vec_width<bits<8> opc, X86VectorVTInfo _,
               AVX5128IBase, EVEX_CD8<_.EltSize, CD8VT1>;
 }
 
+multiclass expand_by_vec_width_lowering<X86VectorVTInfo _> {
+
+  def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask, undef)),
+            (!cast<Instruction>(NAME#_.ZSuffix##rmkz)
+                                _.KRCWM:$mask, addr:$src)>;
+
+  def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask,
+                                     (_.VT _.RC:$src0))),
+            (!cast<Instruction>(NAME#_.ZSuffix##rmk)
+                                _.RC:$src0, _.KRCWM:$mask, addr:$src)>;
+}
+
 multiclass expand_by_elt_width<bits<8> opc, string OpcodeStr,
                                AVX512VLVectorVTInfo VTInfo> {
-  defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>, EVEX_V512;
+  defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>,
+           expand_by_vec_width_lowering<VTInfo.info512>, EVEX_V512;
 
   let Predicates = [HasVLX] in {
-    defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>, EVEX_V256;
-    defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>, EVEX_V128;
+    defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>,
+                expand_by_vec_width_lowering<VTInfo.info256>, EVEX_V256;
+    defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>,
+                expand_by_vec_width_lowering<VTInfo.info128>, EVEX_V128;
   }
 }
diff --git a/lib/Target/X86/X86InstrFragmentsSIMD.td b/lib/Target/X86/X86InstrFragmentsSIMD.td
index 6eb5c9a45f7..f1b9475600f 100644
--- a/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -919,30 +919,36 @@ def vinsert256_insert : PatFrag<(ops node:$bigvec, node:$smallvec,
                                    return X86::isVINSERT256Index(N);
                                  }], INSERT_get_vinsert256_imm>;
 
-def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
+def X86mload : PatFrag<(ops node:$src1, node:$src2, node:$src3),
                          (masked_load node:$src1, node:$src2, node:$src3), [{
-  if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
-    return Load->getAlignment() >= 16;
-  return false;
+  return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
+         cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+
+def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
+                         (X86mload node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->getAlignment() >= 16;
 }]>;
 
 def masked_load_aligned256 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-                         (masked_load node:$src1, node:$src2, node:$src3), [{
-  if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
-    return Load->getAlignment() >= 32;
-  return false;
+                         (X86mload node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->getAlignment() >= 32;
 }]>;
 
 def masked_load_aligned512 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-                         (masked_load node:$src1, node:$src2, node:$src3), [{
-  if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
-    return Load->getAlignment() >= 64;
-  return false;
+                         (X86mload node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->getAlignment() >= 64;
 }]>;
 
 def masked_load_unaligned : PatFrag<(ops node:$src1, node:$src2, node:$src3),
                        (masked_load node:$src1, node:$src2, node:$src3), [{
-  return isa<MaskedLoadSDNode>(N);
+  return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
+         cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+
+def X86mExpandingLoad : PatFrag<(ops node:$src1, node:$src2, node:$src3),
+                         (masked_load node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->isExpandingLoad();
 }]>;
 
 // Masked store fragments.
diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td
index 8db144af9b0..b298a2b8812 100644
--- a/lib/Target/X86/X86InstrSSE.td
+++ b/lib/Target/X86/X86InstrSSE.td
@@ -8622,12 +8622,12 @@ multiclass maskmov_lowering<string InstrStr, RegisterClass RC, ValueType VT,
             (!cast<Instruction>(InstrStr#"mr") addr:$ptr, RC:$mask, RC:$src)>;
   // masked load
-  def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), undef)),
+  def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), undef)),
            (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>;
-  def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask),
+  def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask),
                             (VT (bitconvert (ZeroVT immAllZerosV))))),
            (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>;
-  def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))),
+  def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))),
            (!cast<Instruction>(BlendStr#"rr") RC:$src0,
             (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr),
diff --git a/test/CodeGen/X86/avx512vl-intrinsics.ll b/test/CodeGen/X86/avx512vl-intrinsics.ll
index 0ee213d8ba2..5e4e8fd529b 100644
--- a/test/CodeGen/X86/avx512vl-intrinsics.ll
+++ b/test/CodeGen/X86/avx512vl-intrinsics.ll
@@ -1042,6 +1042,29 @@ define <4 x i32> @expand10(<4 x i32> %data, i8 %mask) {
 
 declare <4 x i32> @llvm.x86.avx512.mask.expand.d.128(<4 x i32> %data, <4 x i32> %src0, i8 %mask)
 
+define <8 x i64> @expand11(i8* %addr) {
+; CHECK-LABEL: expand11:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    vmovups (%rdi), %zmm0 ## encoding: [0x62,0xf1,0x7c,0x48,0x10,0x07]
+; CHECK-NEXT:    retq ## encoding: [0xc3]
+  %res = call <8 x i64> @llvm.x86.avx512.mask.expand.load.q.512(i8* %addr, <8 x i64> undef, i8 -1)
+  ret <8 x i64> %res
+}
+
+define <8 x i64> @expand12(i8* %addr, i8 %mask) {
+; CHECK-LABEL: expand12:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovw %esi, %k1 ## encoding: [0xc5,0xf8,0x92,0xce]
+; CHECK-NEXT:    vpexpandq (%rdi), %zmm0 {%k1} {z} ## encoding: [0x62,0xf2,0xfd,0xc9,0x89,0x07]
+; CHECK-NEXT:    retq ## encoding: [0xc3]
+  %laddr = bitcast i8* %addr to <8 x i64>*
+  %data = load <8 x i64>, <8 x i64>* %laddr, align 1
+  %res = call <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64> %data, <8 x i64> zeroinitializer, i8 %mask)
+  ret <8 x i64> %res
+}
+
+declare <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64>, <8 x i64>, i8)
+
 define < 2 x i64> @test_mask_mul_epi32_rr_128(< 4 x i32> %a, < 4 x i32> %b) {
 ; CHECK-LABEL: test_mask_mul_epi32_rr_128:
 ; CHECK:       ## BB#0:
@@ -5250,9 +5273,9 @@ define <8 x i32>@test_int_x86_avx512_mask_psrav8_si_const() {
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    vmovdqa32 {{.*#+}} ymm0 = [2,9,4294967284,23,4294967270,37,4294967256,51]
 ; CHECK-NEXT:    ## encoding: [0x62,0xf1,0x7d,0x28,0x6f,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI309_0-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    vpsravd {{.*}}(%rip), %ymm0, %ymm0 ## encoding: [0x62,0xf2,0x7d,0x28,0x46,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI309_1-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    retq ## encoding: [0xc3]
   %res = call <8 x i32> @llvm.x86.avx512.mask.psrav8.si(<8 x i32> , <8 x i32> , <8 x i32> zeroinitializer, i8 -1)
   ret <8 x i32> %res
@@ -5283,9 +5306,9 @@ define <2 x i64>@test_int_x86_avx512_mask_psrav_q_128_const(i8 %x3) {
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    vmovdqa64 {{.*#+}} xmm0 = [2,18446744073709551607]
 ; CHECK-NEXT:    ## encoding: [0x62,0xf1,0xfd,0x08,0x6f,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI313_0-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    vpsravq {{.*}}(%rip), %xmm0, %xmm0 ## encoding: [0x62,0xf2,0xfd,0x08,0x46,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI313_1-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    retq ## encoding: [0xc3]
   %res = call <2 x i64> @llvm.x86.avx512.mask.psrav.q.128(<2 x i64> , <2 x i64> , <2 x i64> zeroinitializer, i8 -1)
   ret <2 x i64> %res
-- 
2.50.0
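Note (outside the patch): the following is a minimal sketch, not committed code, of how a
target lowering could build the new expanding masked load through the extended
SelectionDAG::getMaskedLoad API, mirroring the EXPAND_FROM_MEM change above. The helper
name and its parameter list are illustrative only. Semantically, with a mask of <1,0,1,0>,
two consecutive elements are loaded from memory into lanes 0 and 2, while lanes 1 and 3
are taken from the pass-through operand.

  #include "llvm/CodeGen/SelectionDAG.h"
  #include "llvm/CodeGen/SelectionDAGNodes.h"

  using namespace llvm;

  // Illustrative helper (not part of LLVM): create a masked load that expands
  // consecutive elements from Addr into the result lanes whose mask bit is set
  // and fills the remaining lanes from PassThru.
  static SDValue buildExpandingLoad(SelectionDAG &DAG, const SDLoc &dl, EVT VT,
                                    SDValue Chain, SDValue Addr, SDValue VMask,
                                    SDValue PassThru, MachineMemOperand *MMO) {
    return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT, MMO,
                             ISD::NON_EXTLOAD, /*IsExpanding=*/true);
  }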