]> granicus.if.org Git - llvm/commitdiff
[X86] Add an X86ISD::MSCATTER node for consistency with the X86ISD::MGATHER.
authorCraig Topper <craig.topper@intel.com>
Wed, 22 Nov 2017 08:10:54 +0000 (08:10 +0000)
committerCraig Topper <craig.topper@intel.com>
Wed, 22 Nov 2017 08:10:54 +0000 (08:10 +0000)
This makes the fact that X86 needs an explicit mask output not part of the type constraint for the ISD::MSCATTER.

This also gives the X86ISD::MGATHER/MSCATTER nodes a common base class simplifying the address selection code in X86ISelDAGToDAG.cpp

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@318823 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelDAGToDAG.cpp
lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86ISelLowering.h
lib/Target/X86/X86InstrFragmentsSIMD.td

index 71ae97d7e92caae860842106410a5984e6b1c63f..504482a5e2a270775e6da3adf688501bb4a5343d 100644 (file)
@@ -1522,14 +1522,9 @@ bool X86DAGToDAGISel::selectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base,
                                        SDValue &Scale, SDValue &Index,
                                        SDValue &Disp, SDValue &Segment) {
   X86ISelAddressMode AM;
-  if (auto Mgs = dyn_cast<MaskedGatherScatterSDNode>(Parent)) {
-    AM.IndexReg = Mgs->getIndex();
-    AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
-  } else {
-    auto X86Gather = cast<X86MaskedGatherSDNode>(Parent);
-    AM.IndexReg = X86Gather->getIndex();
-    AM.Scale = X86Gather->getValue().getScalarValueSizeInBits() / 8;
-  }
+  auto *Mgs = cast<X86MaskedGatherScatterSDNode>(Parent);
+  AM.IndexReg = Mgs->getIndex();
+  AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
 
   unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace();
   // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.
index b3afe081834b264f5dc9a40de26ac985c37ad0b2..d6436eeac689cc6138b0435bb97f2945dcf461a1 100644 (file)
@@ -24112,19 +24112,12 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
   assert(Subtarget.hasAVX512() &&
          "MGATHER/MSCATTER are supported on AVX-512 arch only");
 
-  // X86 scatter kills mask register, so its type should be added to
-  // the list of return values.
-  // If the "scatter" has 2 return values, it is already handled.
-  if (Op.getNode()->getNumValues() == 2)
-    return Op;
-
   MaskedScatterSDNode *N = cast<MaskedScatterSDNode>(Op.getNode());
   SDValue Src = N->getValue();
   MVT VT = Src.getSimpleValueType();
   assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
   SDLoc dl(Op);
 
-  SDValue NewScatter;
   SDValue Index = N->getIndex();
   SDValue Mask = N->getMask();
   SDValue Chain = N->getChain();
@@ -24195,8 +24188,8 @@ static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget &Subtarget,
   // The mask is killed by scatter, add it to the values
   SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other);
   SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index};
-  NewScatter = DAG.getMaskedScatter(VTs, N->getMemoryVT(), dl, Ops,
-                                    N->getMemOperand());
+  SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+      VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
   DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
   return SDValue(NewScatter.getNode(), 1);
 }
@@ -25261,6 +25254,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   case X86ISD::CVTS2UI_RND:        return "X86ISD::CVTS2UI_RND";
   case X86ISD::LWPINS:             return "X86ISD::LWPINS";
   case X86ISD::MGATHER:            return "X86ISD::MGATHER";
+  case X86ISD::MSCATTER:           return "X86ISD::MSCATTER";
   case X86ISD::VPDPBUSD:           return "X86ISD::VPDPBUSD";
   case X86ISD::VPDPBUSDS:          return "X86ISD::VPDPBUSDS";
   case X86ISD::VPDPWSSD:           return "X86ISD::VPDPWSSD";
index b79addfe198dcea2a983d38776301fdecc19d48c..fc8519bb973fdf1d188254d9e39e7cefe55df8e7 100644 (file)
@@ -637,8 +637,8 @@ namespace llvm {
       // Vector truncating masked store with unsigned/signed saturation
       VMTRUNCSTOREUS, VMTRUNCSTORES,
 
-      // X86 specific gather
-      MGATHER
+      // X86 specific gather and scatter
+      MGATHER, MSCATTER,
 
       // WARNING: Do not add anything in the end unless you want the node to
       // have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
@@ -1423,27 +1423,51 @@ namespace llvm {
     }
   };
 
-  // X86 specific Gather node.
-  // The class has the same order of operands as MaskedGatherSDNode for
+  // X86 specific Gather/Scatter nodes.
+  // The class has the same order of operands as MaskedGatherScatterSDNode for
   // convenience.
-  class X86MaskedGatherSDNode : public MemSDNode {
+  class X86MaskedGatherScatterSDNode : public MemSDNode {
   public:
-    X86MaskedGatherSDNode(unsigned Order,
-                          const DebugLoc &dl, SDVTList VTs, EVT MemVT,
-                          MachineMemOperand *MMO)
-      : MemSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT, MMO)
-    {}
+    X86MaskedGatherScatterSDNode(unsigned Opc, unsigned Order,
+                                 const DebugLoc &dl, SDVTList VTs, EVT MemVT,
+                                 MachineMemOperand *MMO)
+        : MemSDNode(Opc, Order, dl, VTs, MemVT, MMO) {}
 
     const SDValue &getBasePtr() const { return getOperand(3); }
     const SDValue &getIndex()   const { return getOperand(4); }
     const SDValue &getMask()    const { return getOperand(2); }
     const SDValue &getValue()   const { return getOperand(1); }
 
+    static bool classof(const SDNode *N) {
+      return N->getOpcode() == X86ISD::MGATHER ||
+             N->getOpcode() == X86ISD::MSCATTER;
+    }
+  };
+
+  class X86MaskedGatherSDNode : public X86MaskedGatherScatterSDNode {
+  public:
+    X86MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+                          EVT MemVT, MachineMemOperand *MMO)
+        : X86MaskedGatherScatterSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT,
+                                       MMO) {}
+
     static bool classof(const SDNode *N) {
       return N->getOpcode() == X86ISD::MGATHER;
     }
   };
 
+  class X86MaskedScatterSDNode : public X86MaskedGatherScatterSDNode {
+  public:
+    X86MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+                           EVT MemVT, MachineMemOperand *MMO)
+        : X86MaskedGatherScatterSDNode(X86ISD::MSCATTER, Order, dl, VTs, MemVT,
+                                       MMO) {}
+
+    static bool classof(const SDNode *N) {
+      return N->getOpcode() == X86ISD::MSCATTER;
+    }
+  };
+
   /// Generate unpacklo/unpackhi shuffle mask.
   template <typename T = int>
   void createUnpackShuffleMask(MVT VT, SmallVectorImpl<T> &Mask, bool Lo,
index c38b2c730c7f5642abf9610edcf8e89eec9f17a6..2eb735abd697b37147e0d5347902be7ee1d577a4 100644 (file)
@@ -781,6 +781,13 @@ def X86masked_gather : SDNode<"X86ISD::MGATHER",
                                                    SDTCisPtrTy<4>]>,
                              [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 
+def X86masked_scatter : SDNode<"X86ISD::MSCATTER",
+                              SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisVec<1>,
+                                                   SDTCisSameAs<0, 2>,
+                                                   SDTCVecEltisVT<0, i1>,
+                                                   SDTCisPtrTy<3>]>,
+                             [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
 def mgatherv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
   (X86masked_gather node:$src1, node:$src2, node:$src3) , [{
   X86MaskedGatherSDNode *Mgt = cast<X86MaskedGatherSDNode>(N);
@@ -815,37 +822,37 @@ def mgatherv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
 }]>;
 
 def mscatterv2i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v2i64;
 }]>;
 
 def mscatterv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v4i32;
 }]>;
 
 def mscatterv4i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v4i64;
 }]>;
 
 def mscatterv8i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v8i32;
 }]>;
 
 def mscatterv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v8i64;
 }]>;
 def mscatterv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
-  (masked_scatter node:$src1, node:$src2, node:$src3) , [{
-  MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+  (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+  X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
   return Sc->getIndex().getValueType() == MVT::v16i32;
 }]>;