]> granicus.if.org Git - llvm/commitdiff
DAG: Setting Masked-Expand-Load as a variant of Masked-Load node
authorElena Demikhovsky <elena.demikhovsky@intel.com>
Sun, 9 Oct 2016 10:48:52 +0000 (10:48 +0000)
committerElena Demikhovsky <elena.demikhovsky@intel.com>
Sun, 9 Oct 2016 10:48:52 +0000 (10:48 +0000)
Masked-expand-load node represents load operation that loads a variable amount of elements from memory according to amount of "true" bits in the mask and expands the loaded elements according to their position in the mask vector.
Right now, the node is used in intrinsics for VEXPAND* instructions.
The work is done towards implementation of masked.expandload and masked.compressstore intrinsics.

Differential Revision: https://reviews.llvm.org/D25322

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@283694 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/SelectionDAG.h
include/llvm/CodeGen/SelectionDAGNodes.h
lib/CodeGen/SelectionDAG/SelectionDAG.cpp
lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86InstrAVX512.td
lib/Target/X86/X86InstrFragmentsSIMD.td
lib/Target/X86/X86InstrSSE.td
test/CodeGen/X86/avx512vl-intrinsics.ll

index d934ddb8366441a1e160a7026d9d7e59a51a2ca3..9f45cc82089686139bd2a51d00c782e833125882 100644 (file)
@@ -965,11 +965,12 @@ public:
 
   SDValue getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr,
                         SDValue Mask, SDValue Src0, EVT MemVT,
-                        MachineMemOperand *MMO, ISD::LoadExtType);
+                        MachineMemOperand *MMO, ISD::LoadExtType,
+                        bool IsExpanding = false);
   SDValue getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                          SDValue Ptr, SDValue Mask, EVT MemVT,
-                         MachineMemOperand *MMO, bool IsTrunc, 
-                         bool isCompressing = false);
+                         MachineMemOperand *MMO, bool IsTruncating = false
+                         bool IsCompressing = false);
   SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
                           ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
   SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
index fb5a01fade8a006052fac3afc610c9b11d1d66c3..1d14d1228ce26cb421ce9cec3e4a5c4eb53e923d 100644 (file)
@@ -444,6 +444,7 @@ protected:
     uint16_t : NumLSBaseSDNodeBits;
 
     uint16_t ExtTy : 2; // enum ISD::LoadExtType
+    uint16_t IsExpanding : 1;
   };
 
   class StoreSDNodeBitfields {
@@ -473,7 +474,7 @@ protected:
   static_assert(sizeof(ConstantSDNodeBitfields) <= 2, "field too wide");
   static_assert(sizeof(MemSDNodeBitfields) <= 2, "field too wide");
   static_assert(sizeof(LSBaseSDNodeBitfields) <= 2, "field too wide");
-  static_assert(sizeof(LoadSDNodeBitfields) <= 2, "field too wide");
+  static_assert(sizeof(LoadSDNodeBitfields) <= 4, "field too wide");
   static_assert(sizeof(StoreSDNodeBitfields) <= 2, "field too wide");
 
 private:
@@ -1939,9 +1940,11 @@ class MaskedLoadSDNode : public MaskedLoadStoreSDNode {
 public:
   friend class SelectionDAG;
   MaskedLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
-                   ISD::LoadExtType ETy, EVT MemVT, MachineMemOperand *MMO)
+                   ISD::LoadExtType ETy, bool IsExpanding, EVT MemVT,
+                   MachineMemOperand *MMO)
       : MaskedLoadStoreSDNode(ISD::MLOAD, Order, dl, VTs, MemVT, MMO) {
     LoadSDNodeBits.ExtTy = ETy;
+    LoadSDNodeBits.IsExpanding = IsExpanding;
   }
 
   ISD::LoadExtType getExtensionType() const {
@@ -1952,6 +1955,8 @@ public:
   static bool classof(const SDNode *N) {
     return N->getOpcode() == ISD::MLOAD;
   }
+
+  bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; }
 };
 
 /// This class is used to represent an MSTORE node
index a915fe161a3c4947e5585b07fbb003a63a9b83b4..3671422be2c97684b73afd5ca2a8bd74950e09c0 100644 (file)
@@ -5347,7 +5347,7 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
 SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
                                     SDValue Ptr, SDValue Mask, SDValue Src0,
                                     EVT MemVT, MachineMemOperand *MMO,
-                                    ISD::LoadExtType ExtTy) {
+                                    ISD::LoadExtType ExtTy, bool isExpanding) {
 
   SDVTList VTs = getVTList(VT, MVT::Other);
   SDValue Ops[] = { Chain, Ptr, Mask, Src0 };
@@ -5355,7 +5355,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
   AddNodeIDNode(ID, ISD::MLOAD, VTs, Ops);
   ID.AddInteger(VT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedLoadSDNode>(
-      dl.getIROrder(), VTs, ExtTy, MemVT, MMO));
+      dl.getIROrder(), VTs, ExtTy, isExpanding, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@@ -5363,7 +5363,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
     return SDValue(E, 0);
   }
   auto *N = newSDNode<MaskedLoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
-                                        ExtTy, MemVT, MMO);
+                                        ExtTy, isExpanding, MemVT, MMO);
   createOperands(N, Ops);
 
   CSEMap.InsertNode(N, IP);
@@ -5374,7 +5374,7 @@ SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
 SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
                                      SDValue Val, SDValue Ptr, SDValue Mask,
                                      EVT MemVT, MachineMemOperand *MMO,
-                                     bool isTrunc, bool isCompress) {
+                                     bool IsTruncating, bool IsCompressing) {
   assert(Chain.getValueType() == MVT::Other &&
         "Invalid chain type");
   EVT VT = Val.getValueType();
@@ -5384,7 +5384,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
   AddNodeIDNode(ID, ISD::MSTORE, VTs, Ops);
   ID.AddInteger(VT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedStoreSDNode>(
-      dl.getIROrder(), VTs, isTrunc, isCompress, MemVT, MMO));
+      dl.getIROrder(), VTs, IsTruncating, IsCompressing, MemVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@@ -5392,7 +5392,7 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl,
     return SDValue(E, 0);
   }
   auto *N = newSDNode<MaskedStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
-                                         isTrunc, isCompress, MemVT, MMO);
+                                         IsTruncating, IsCompressing, MemVT, MMO);
   createOperands(N, Ops);
 
   CSEMap.InsertNode(N, IP);
index 4ae94e0befa55b2576a2798bf0e7df6ad2eb2cc9..2aaab4b0d873fba5c617fe4165ada66ad84c297c 100644 (file)
@@ -3821,7 +3821,7 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
                           Alignment, AAInfo, Ranges);
 
   SDValue Load = DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Mask, Src0, VT, MMO,
-                                   ISD::NON_EXTLOAD);
+                                   ISD::NON_EXTLOAD, false);
   if (AddToChain) {
     SDValue OutChain = Load.getValue(1);
     DAG.setRoot(OutChain);
index 6fbd9dcfd32349cfab9cde079f4d10fafaea3eef..32c4ffe585a0ce822bbbd1bdc4a497ff62bc3fba 100644 (file)
@@ -18854,7 +18854,8 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
     SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
 
     return DAG.getMaskedStore(Chain, dl, DataToCompress, Addr, VMask, VT,
-                              MemIntr->getMemOperand(), false, true);
+                              MemIntr->getMemOperand(),
+                              false /* truncating */, true /* compressing */);
   }
   case TRUNCATE_TO_MEM_VI8:
   case TRUNCATE_TO_MEM_VI16:
@@ -18877,7 +18878,7 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
     SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
 
     return DAG.getMaskedStore(Chain, dl, DataToTruncate, Addr, VMask, VT,
-                              MemIntr->getMemOperand(), true);
+                              MemIntr->getMemOperand(), true /* truncating */);
   }
   case EXPAND_FROM_MEM: {
     SDValue Mask = Op.getOperand(4);
@@ -18889,16 +18890,16 @@ static SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, const X86Subtarget &Subtarget,
     MemIntrinsicSDNode *MemIntr = dyn_cast<MemIntrinsicSDNode>(Op);
     assert(MemIntr && "Expected MemIntrinsicSDNode!");
 
-    SDValue DataToExpand = DAG.getLoad(VT, dl, Chain, Addr,
-                                       MemIntr->getMemOperand());
+    if (isAllOnesConstant(Mask)) // Return a regular (unmasked) vector load.
+      return DAG.getLoad(VT, dl, Chain, Addr, MemIntr->getMemOperand());
+    if (X86::isZeroNode(Mask))
+      return DAG.getUNDEF(VT);
 
-    if (isAllOnesConstant(Mask)) // return just a load
-      return DataToExpand;
-
-    SDValue Results[] = {
-      getVectorMaskingNode(DAG.getNode(IntrData->Opc0, dl, VT, DataToExpand),
-                           Mask, PassThru, Subtarget, DAG), Chain};
-    return DAG.getMergeValues(Results, dl);
+    MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
+    SDValue VMask = getMaskNode(Mask, MaskVT, Subtarget, DAG, dl);
+    return DAG.getMaskedLoad(VT, dl, Chain, Addr, VMask, PassThru, VT,
+                             MemIntr->getMemOperand(), ISD::NON_EXTLOAD,
+                             true /* expanding */);
   }
   }
 }
index d0ef44c92ba86649151c038047cc7f5d361739d4..ea1c5b09bc8b66012891be20f833cf9575afce8e 100644 (file)
@@ -7552,13 +7552,28 @@ multiclass expand_by_vec_width<bits<8> opc, X86VectorVTInfo _,
             AVX5128IBase, EVEX_CD8<_.EltSize, CD8VT1>;
 }
 
+multiclass expand_by_vec_width_lowering<X86VectorVTInfo _ > {
+
+  def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask, undef)),
+            (!cast<Instruction>(NAME#_.ZSuffix##rmkz)
+                                        _.KRCWM:$mask, addr:$src)>;
+
+  def : Pat<(_.VT (X86mExpandingLoad addr:$src, _.KRCWM:$mask,
+                                               (_.VT _.RC:$src0))),
+            (!cast<Instruction>(NAME#_.ZSuffix##rmk)
+                            _.RC:$src0, _.KRCWM:$mask, addr:$src)>;
+}
+
 multiclass expand_by_elt_width<bits<8> opc, string OpcodeStr,
                                  AVX512VLVectorVTInfo VTInfo> {
-  defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>, EVEX_V512;
+  defm Z : expand_by_vec_width<opc, VTInfo.info512, OpcodeStr>,
+           expand_by_vec_width_lowering<VTInfo.info512>, EVEX_V512;
 
   let Predicates = [HasVLX] in {
-    defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>, EVEX_V256;
-    defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>, EVEX_V128;
+    defm Z256 : expand_by_vec_width<opc, VTInfo.info256, OpcodeStr>,
+                expand_by_vec_width_lowering<VTInfo.info256>, EVEX_V256;
+    defm Z128 : expand_by_vec_width<opc, VTInfo.info128, OpcodeStr>,
+                expand_by_vec_width_lowering<VTInfo.info128>, EVEX_V128;
   }
 }
 
index 6eb5c9a45f746e4c260bae0bd144882af2663d7c..f1b9475600f1a72ccfca9be475789ea500dee92b 100644 (file)
@@ -919,30 +919,36 @@ def vinsert256_insert : PatFrag<(ops node:$bigvec, node:$smallvec,
   return X86::isVINSERT256Index(N);
 }], INSERT_get_vinsert256_imm>;
 
-def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
+def X86mload : PatFrag<(ops node:$src1, node:$src2, node:$src3),
                          (masked_load node:$src1, node:$src2, node:$src3), [{
-  if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
-    return Load->getAlignment() >= 16;
-  return false;
+  return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
+    cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+
+def masked_load_aligned128 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
+                         (X86mload node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->getAlignment() >= 16;
 }]>;
 
 def masked_load_aligned256 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-                         (masked_load node:$src1, node:$src2, node:$src3), [{
-  if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
-    return Load->getAlignment() >= 32;
-  return false;
+                         (X86mload node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->getAlignment() >= 32;
 }]>;
 
 def masked_load_aligned512 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-                         (masked_load node:$src1, node:$src2, node:$src3), [{
-  if (auto *Load = dyn_cast<MaskedLoadSDNode>(N))
-    return Load->getAlignment() >= 64;
-  return false;
+                         (X86mload node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->getAlignment() >= 64;
 }]>;
 
 def masked_load_unaligned : PatFrag<(ops node:$src1, node:$src2, node:$src3),
                          (masked_load node:$src1, node:$src2, node:$src3), [{
-  return isa<MaskedLoadSDNode>(N);
+  return !cast<MaskedLoadSDNode>(N)->isExpandingLoad() &&
+    cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+
+def X86mExpandingLoad : PatFrag<(ops node:$src1, node:$src2, node:$src3),
+                         (masked_load node:$src1, node:$src2, node:$src3), [{
+  return cast<MaskedLoadSDNode>(N)->isExpandingLoad();
 }]>;
 
 // Masked store fragments.
index 8db144af9b0832428cebe1797da1f97bdae0170b..b298a2b8812961654bc1d1fe9bc3270739266e13 100644 (file)
@@ -8622,12 +8622,12 @@ multiclass maskmov_lowering<string InstrStr, RegisterClass RC, ValueType VT,
     def: Pat<(X86mstore addr:$ptr, (MaskVT RC:$mask), (VT RC:$src)),
              (!cast<Instruction>(InstrStr#"mr") addr:$ptr, RC:$mask, RC:$src)>;
     // masked load
-    def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), undef)),
+    def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), undef)),
              (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>;
-    def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask),
+    def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask),
                               (VT (bitconvert (ZeroVT immAllZerosV))))),
              (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr)>;
-    def: Pat<(VT (masked_load addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))),
+    def: Pat<(VT (X86mload addr:$ptr, (MaskVT RC:$mask), (VT RC:$src0))),
              (!cast<Instruction>(BlendStr#"rr")
                  RC:$src0,
                  (!cast<Instruction>(InstrStr#"rm") RC:$mask, addr:$ptr),
index 0ee213d8ba20f794a9f69337c2abc4bd22fd6897..5e4e8fd529bd216f57e4c24666ce6dd253af3e11 100644 (file)
@@ -1042,6 +1042,29 @@ define <4 x i32> @expand10(<4 x i32> %data, i8 %mask) {
 
 declare <4 x i32> @llvm.x86.avx512.mask.expand.d.128(<4 x i32> %data, <4 x i32> %src0, i8 %mask)
 
+define <8 x i64> @expand11(i8* %addr) {
+; CHECK-LABEL: expand11:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    vmovups (%rdi), %zmm0 ## encoding: [0x62,0xf1,0x7c,0x48,0x10,0x07]
+; CHECK-NEXT:    retq ## encoding: [0xc3]
+  %res = call <8 x i64> @llvm.x86.avx512.mask.expand.load.q.512(i8* %addr, <8 x i64> undef, i8 -1)
+  ret <8 x i64> %res
+}
+
+define <8 x i64> @expand12(i8* %addr, i8 %mask) {
+; CHECK-LABEL: expand12:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovw %esi, %k1 ## encoding: [0xc5,0xf8,0x92,0xce]
+; CHECK-NEXT:    vpexpandq (%rdi), %zmm0 {%k1} {z} ## encoding: [0x62,0xf2,0xfd,0xc9,0x89,0x07]
+; CHECK-NEXT:    retq ## encoding: [0xc3]
+  %laddr = bitcast i8* %addr to <8 x i64>*
+  %data = load <8 x i64>, <8 x i64>* %laddr, align 1
+  %res = call <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64> %data, <8 x i64>zeroinitializer, i8 %mask)
+  ret <8 x i64> %res
+}
+
+declare <8 x i64> @llvm.x86.avx512.mask.expand.q.512(<8 x i64> , <8 x i64>, i8)
+
 define < 2 x i64> @test_mask_mul_epi32_rr_128(< 4 x i32> %a, < 4 x i32> %b) {
 ; CHECK-LABEL: test_mask_mul_epi32_rr_128:
 ; CHECK:       ## BB#0:
@@ -5250,9 +5273,9 @@ define <8 x i32>@test_int_x86_avx512_mask_psrav8_si_const() {
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    vmovdqa32 {{.*#+}} ymm0 = [2,9,4294967284,23,4294967270,37,4294967256,51]
 ; CHECK-NEXT:    ## encoding: [0x62,0xf1,0x7d,0x28,0x6f,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI309_0-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    vpsravd {{.*}}(%rip), %ymm0, %ymm0 ## encoding: [0x62,0xf2,0x7d,0x28,0x46,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI309_1-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    retq ## encoding: [0xc3]
   %res = call <8 x i32> @llvm.x86.avx512.mask.psrav8.si(<8 x i32> <i32 2, i32 9, i32 -12, i32 23, i32 -26, i32 37, i32 -40, i32 51>, <8 x i32> <i32 1, i32 18, i32 35, i32 52, i32 69, i32 15, i32 32, i32 49>, <8 x i32> zeroinitializer, i8 -1)
   ret <8 x i32> %res
@@ -5283,9 +5306,9 @@ define <2 x i64>@test_int_x86_avx512_mask_psrav_q_128_const(i8 %x3) {
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    vmovdqa64 {{.*#+}} xmm0 = [2,18446744073709551607]
 ; CHECK-NEXT:    ## encoding: [0x62,0xf1,0xfd,0x08,0x6f,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_0-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI313_0-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    vpsravq {{.*}}(%rip), %xmm0, %xmm0 ## encoding: [0x62,0xf2,0xfd,0x08,0x46,0x05,A,A,A,A]
-; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI311_1-4, kind: reloc_riprel_4byte
+; CHECK-NEXT:    ## fixup A - offset: 6, value: LCPI313_1-4, kind: reloc_riprel_4byte
 ; CHECK-NEXT:    retq ## encoding: [0xc3]
   %res = call <2 x i64> @llvm.x86.avx512.mask.psrav.q.128(<2 x i64> <i64 2, i64 -9>, <2 x i64> <i64 1, i64 90>, <2 x i64> zeroinitializer, i8 -1)
   ret <2 x i64> %res