SDValue &Scale, SDValue &Index,
SDValue &Disp, SDValue &Segment) {
X86ISelAddressMode AM;
- if (auto Mgs = dyn_cast<MaskedGatherScatterSDNode>(Parent)) {
- AM.IndexReg = Mgs->getIndex();
- AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
- } else {
- auto X86Gather = cast<X86MaskedGatherSDNode>(Parent);
- AM.IndexReg = X86Gather->getIndex();
- AM.Scale = X86Gather->getValue().getScalarValueSizeInBits() / 8;
- }
+ auto *Mgs = cast<X86MaskedGatherScatterSDNode>(Parent);
+ AM.IndexReg = Mgs->getIndex();
+ AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8;
unsigned AddrSpace = cast<MemSDNode>(Parent)->getPointerInfo().getAddrSpace();
// AddrSpace 256 -> GS, 257 -> FS, 258 -> SS.
assert(Subtarget.hasAVX512() &&
"MGATHER/MSCATTER are supported on AVX-512 arch only");
- // X86 scatter kills mask register, so its type should be added to
- // the list of return values.
- // If the "scatter" has 2 return values, it is already handled.
- if (Op.getNode()->getNumValues() == 2)
- return Op;
-
MaskedScatterSDNode *N = cast<MaskedScatterSDNode>(Op.getNode());
SDValue Src = N->getValue();
MVT VT = Src.getSimpleValueType();
assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
SDLoc dl(Op);
- SDValue NewScatter;
SDValue Index = N->getIndex();
SDValue Mask = N->getMask();
SDValue Chain = N->getChain();
// The mask is killed by scatter, add it to the values
SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other);
SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index};
- NewScatter = DAG.getMaskedScatter(VTs, N->getMemoryVT(), dl, Ops,
- N->getMemOperand());
+ SDValue NewScatter = DAG.getTargetMemSDNode<X86MaskedScatterSDNode>(
+ VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand());
DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
return SDValue(NewScatter.getNode(), 1);
}
case X86ISD::CVTS2UI_RND: return "X86ISD::CVTS2UI_RND";
case X86ISD::LWPINS: return "X86ISD::LWPINS";
case X86ISD::MGATHER: return "X86ISD::MGATHER";
+ case X86ISD::MSCATTER: return "X86ISD::MSCATTER";
case X86ISD::VPDPBUSD: return "X86ISD::VPDPBUSD";
case X86ISD::VPDPBUSDS: return "X86ISD::VPDPBUSDS";
case X86ISD::VPDPWSSD: return "X86ISD::VPDPWSSD";
// Vector truncating masked store with unsigned/signed saturation
VMTRUNCSTOREUS, VMTRUNCSTORES,
- // X86 specific gather
- MGATHER
+ // X86 specific gather and scatter
+ MGATHER, MSCATTER,
// WARNING: Do not add anything in the end unless you want the node to
// have memop! In fact, starting from FIRST_TARGET_MEMORY_OPCODE all
}
};
- // X86 specific Gather node.
- // The class has the same order of operands as MaskedGatherSDNode for
+ // X86 specific Gather/Scatter nodes.
+ // The class has the same order of operands as MaskedGatherScatterSDNode for
// convenience.
- class X86MaskedGatherSDNode : public MemSDNode {
+ class X86MaskedGatherScatterSDNode : public MemSDNode {
public:
- X86MaskedGatherSDNode(unsigned Order,
- const DebugLoc &dl, SDVTList VTs, EVT MemVT,
- MachineMemOperand *MMO)
- : MemSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT, MMO)
- {}
+ X86MaskedGatherScatterSDNode(unsigned Opc, unsigned Order,
+ const DebugLoc &dl, SDVTList VTs, EVT MemVT,
+ MachineMemOperand *MMO)
+ : MemSDNode(Opc, Order, dl, VTs, MemVT, MMO) {}
const SDValue &getBasePtr() const { return getOperand(3); }
const SDValue &getIndex() const { return getOperand(4); }
const SDValue &getMask() const { return getOperand(2); }
const SDValue &getValue() const { return getOperand(1); }
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == X86ISD::MGATHER ||
+ N->getOpcode() == X86ISD::MSCATTER;
+ }
+ };
+
+ class X86MaskedGatherSDNode : public X86MaskedGatherScatterSDNode {
+ public:
+ X86MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+ EVT MemVT, MachineMemOperand *MMO)
+ : X86MaskedGatherScatterSDNode(X86ISD::MGATHER, Order, dl, VTs, MemVT,
+ MMO) {}
+
static bool classof(const SDNode *N) {
return N->getOpcode() == X86ISD::MGATHER;
}
};
+ class X86MaskedScatterSDNode : public X86MaskedGatherScatterSDNode {
+ public:
+ X86MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+ EVT MemVT, MachineMemOperand *MMO)
+ : X86MaskedGatherScatterSDNode(X86ISD::MSCATTER, Order, dl, VTs, MemVT,
+ MMO) {}
+
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == X86ISD::MSCATTER;
+ }
+ };
+
/// Generate unpacklo/unpackhi shuffle mask.
template <typename T = int>
void createUnpackShuffleMask(MVT VT, SmallVectorImpl<T> &Mask, bool Lo,
SDTCisPtrTy<4>]>,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def X86masked_scatter : SDNode<"X86ISD::MSCATTER",
+ SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisVec<1>,
+ SDTCisSameAs<0, 2>,
+ SDTCVecEltisVT<0, i1>,
+ SDTCisPtrTy<3>]>,
+ [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
def mgatherv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
(X86masked_gather node:$src1, node:$src2, node:$src3) , [{
X86MaskedGatherSDNode *Mgt = cast<X86MaskedGatherSDNode>(N);
}]>;
def mscatterv2i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
- (masked_scatter node:$src1, node:$src2, node:$src3) , [{
- MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+ (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+ X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
return Sc->getIndex().getValueType() == MVT::v2i64;
}]>;
def mscatterv4i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
- (masked_scatter node:$src1, node:$src2, node:$src3) , [{
- MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+ (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+ X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
return Sc->getIndex().getValueType() == MVT::v4i32;
}]>;
def mscatterv4i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
- (masked_scatter node:$src1, node:$src2, node:$src3) , [{
- MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+ (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+ X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
return Sc->getIndex().getValueType() == MVT::v4i64;
}]>;
def mscatterv8i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
- (masked_scatter node:$src1, node:$src2, node:$src3) , [{
- MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+ (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+ X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
return Sc->getIndex().getValueType() == MVT::v8i32;
}]>;
def mscatterv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
- (masked_scatter node:$src1, node:$src2, node:$src3) , [{
- MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+ (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+ X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
return Sc->getIndex().getValueType() == MVT::v8i64;
}]>;
def mscatterv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3),
- (masked_scatter node:$src1, node:$src2, node:$src3) , [{
- MaskedScatterSDNode *Sc = cast<MaskedScatterSDNode>(N);
+ (X86masked_scatter node:$src1, node:$src2, node:$src3) , [{
+ X86MaskedScatterSDNode *Sc = cast<X86MaskedScatterSDNode>(N);
return Sc->getIndex().getValueType() == MVT::v16i32;
}]>;