]> granicus.if.org Git - llvm/commitdiff
[X86] Stop using UpdateNodeOperands in combineGatherScatter. Create new nodes like...
authorCraig Topper <craig.topper@intel.com>
Sat, 28 Sep 2019 01:08:46 +0000 (01:08 +0000)
committerCraig Topper <craig.topper@intel.com>
Sat, 28 Sep 2019 01:08:46 +0000 (01:08 +0000)
Creating new nodes is what we usually do. Have to explicitly
check that we don't update to an existing node and having
to manually manage the worklist is unusual.

We can probably add a helper function to reduce the duplication
of having to check if we should create a gather or scatter, but
I wanted to just get the simple thing done.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@373137 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp

index 5741b80d1a7e9411242a5b17a063a738d0bbad29..521fc3cd37b82c40c0fde4bbd65021114b421a5a 100644 (file)
@@ -43381,26 +43381,36 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
+  auto *GorS = cast<MaskedGatherScatterSDNode>(N);
+  SDValue Chain = GorS->getChain();
+  SDValue Index = GorS->getIndex();
+  SDValue Mask = GorS->getMask();
+  SDValue Base = GorS->getBasePtr();
+  SDValue Scale = GorS->getScale();
 
   if (DCI.isBeforeLegalizeOps()) {
-    SDValue Index = N->getOperand(4);
     // Remove any sign extends from 32 or smaller to larger than 32.
     // Only do this before LegalizeOps in case we need the sign extend for
     // legalization.
-    if (Index.getOpcode() == ISD::SIGN_EXTEND) {
-      if (Index.getScalarValueSizeInBits() > 32 &&
-          Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
-        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-        NewOps[4] = Index.getOperand(0);
-        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-        if (Res == N) {
-          // The original sign extend has less users, add back to worklist in
-          // case it needs to be removed
-          DCI.AddToWorklist(Index.getNode());
-          DCI.AddToWorklist(N);
-        }
-        return SDValue(Res, 0);
-      }
+    if (Index.getOpcode() == ISD::SIGN_EXTEND &&
+        Index.getScalarValueSizeInBits() > 32 &&
+        Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
+      Index = Index.getOperand(0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
+      }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
 
     // Make sure the index is either i32 or i64
@@ -43410,36 +43420,49 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
                                    Index.getValueType().getVectorNumElements());
       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
-      SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-      NewOps[4] = Index;
-      SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-      if (Res == N)
-        DCI.AddToWorklist(N);
-      return SDValue(Res, 0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
+      }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
 
     // Try to remove zero extends from 32->64 if we know the sign bit of
     // the input is zero.
     if (Index.getOpcode() == ISD::ZERO_EXTEND &&
         Index.getScalarValueSizeInBits() == 64 &&
-        Index.getOperand(0).getScalarValueSizeInBits() == 32) {
-      if (DAG.SignBitIsZero(Index.getOperand(0))) {
-        SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
-        NewOps[4] = Index.getOperand(0);
-        SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
-        if (Res == N) {
-          // The original sign extend has less users, add back to worklist in
-          // case it needs to be removed
-          DCI.AddToWorklist(Index.getNode());
-          DCI.AddToWorklist(N);
-        }
-        return SDValue(Res, 0);
-      }
+        Index.getOperand(0).getScalarValueSizeInBits() == 32 &&
+        DAG.SignBitIsZero(Index.getOperand(0))) {
+      Index = Index.getOperand(0);
+      if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+        SDValue Ops[] = { Chain, Gather->getPassThru(),
+                          Mask, Base, Index, Scale } ;
+        return DAG.getMaskedGather(Gather->getVTList(),
+                                   Gather->getMemoryVT(), DL, Ops,
+                                   Gather->getMemOperand(),
+                                   Gather->getIndexType());
+      }
+      auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+      SDValue Ops[] = { Chain, Scatter->getValue(),
+                        Mask, Base, Index, Scale };
+      return DAG.getMaskedScatter(Scatter->getVTList(),
+                                  Scatter->getMemoryVT(), DL,
+                                  Ops, Scatter->getMemOperand(),
+                                  Scatter->getIndexType());
     }
   }
 
   // With vector masks we only demand the upper bit of the mask.
-  SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));