From: Craig Topper Date: Sat, 28 Sep 2019 01:08:46 +0000 (+0000) Subject: [X86] Stop using UpdateNodeOperands in combineGatherScatter. Create new nodes like... X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=f5088332d0b5abce4acb99f883ba3a4f1033e46c;p=llvm [X86] Stop using UpdateNodeOperands in combineGatherScatter. Create new nodes like most other DAG combines. Creating new nodes is what we usually do. Have to explicitly check that we don't update to an existing node and having to manually manage the worklist is unusual. We can probably add a helper function to reduce the duplication of having to check if we should create a gather or scatter, but I wanted to just get the simple thing done. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@373137 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 5741b80d1a7..521fc3cd37b 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -43381,26 +43381,36 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG, static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI) { SDLoc DL(N); + auto *GorS = cast(N); + SDValue Chain = GorS->getChain(); + SDValue Index = GorS->getIndex(); + SDValue Mask = GorS->getMask(); + SDValue Base = GorS->getBasePtr(); + SDValue Scale = GorS->getScale(); if (DCI.isBeforeLegalizeOps()) { - SDValue Index = N->getOperand(4); // Remove any sign extends from 32 or smaller to larger than 32. // Only do this before LegalizeOps in case we need the sign extend for // legalization. - if (Index.getOpcode() == ISD::SIGN_EXTEND) { - if (Index.getScalarValueSizeInBits() > 32 && - Index.getOperand(0).getScalarValueSizeInBits() <= 32) { - SmallVector NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Index.getOperand(0); - SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); - if (Res == N) { - // The original sign extend has less users, add back to worklist in - // case it needs to be removed - DCI.AddToWorklist(Index.getNode()); - DCI.AddToWorklist(N); - } - return SDValue(Res, 0); - } + if (Index.getOpcode() == ISD::SIGN_EXTEND && + Index.getScalarValueSizeInBits() > 32 && + Index.getOperand(0).getScalarValueSizeInBits() <= 32) { + Index = Index.getOperand(0); + if (auto *Gather = dyn_cast(GorS)) { + SDValue Ops[] = { Chain, Gather->getPassThru(), + Mask, Base, Index, Scale } ; + return DAG.getMaskedGather(Gather->getVTList(), + Gather->getMemoryVT(), DL, Ops, + Gather->getMemOperand(), + Gather->getIndexType()); + } + auto *Scatter = cast(GorS); + SDValue Ops[] = { Chain, Scatter->getValue(), + Mask, Base, Index, Scale }; + return DAG.getMaskedScatter(Scatter->getVTList(), + Scatter->getMemoryVT(), DL, + Ops, Scatter->getMemOperand(), + Scatter->getIndexType()); } // Make sure the index is either i32 or i64 @@ -43410,36 +43420,49 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT, Index.getValueType().getVectorNumElements()); Index = DAG.getSExtOrTrunc(Index, DL, IndexVT); - SmallVector NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Index; - SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); - if (Res == N) - DCI.AddToWorklist(N); - return SDValue(Res, 0); + if (auto *Gather = dyn_cast(GorS)) { + SDValue Ops[] = { Chain, Gather->getPassThru(), + Mask, Base, Index, Scale } ; + return DAG.getMaskedGather(Gather->getVTList(), + Gather->getMemoryVT(), DL, Ops, + Gather->getMemOperand(), + Gather->getIndexType()); + } + auto *Scatter = cast(GorS); + SDValue Ops[] = { Chain, Scatter->getValue(), + Mask, Base, Index, Scale }; + return DAG.getMaskedScatter(Scatter->getVTList(), + Scatter->getMemoryVT(), DL, + Ops, Scatter->getMemOperand(), + Scatter->getIndexType()); } // Try to remove zero extends from 32->64 if we know the sign bit of // the input is zero. if (Index.getOpcode() == ISD::ZERO_EXTEND && Index.getScalarValueSizeInBits() == 64 && - Index.getOperand(0).getScalarValueSizeInBits() == 32) { - if (DAG.SignBitIsZero(Index.getOperand(0))) { - SmallVector NewOps(N->op_begin(), N->op_end()); - NewOps[4] = Index.getOperand(0); - SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); - if (Res == N) { - // The original sign extend has less users, add back to worklist in - // case it needs to be removed - DCI.AddToWorklist(Index.getNode()); - DCI.AddToWorklist(N); - } - return SDValue(Res, 0); - } + Index.getOperand(0).getScalarValueSizeInBits() == 32 && + DAG.SignBitIsZero(Index.getOperand(0))) { + Index = Index.getOperand(0); + if (auto *Gather = dyn_cast(GorS)) { + SDValue Ops[] = { Chain, Gather->getPassThru(), + Mask, Base, Index, Scale } ; + return DAG.getMaskedGather(Gather->getVTList(), + Gather->getMemoryVT(), DL, Ops, + Gather->getMemOperand(), + Gather->getIndexType()); + } + auto *Scatter = cast(GorS); + SDValue Ops[] = { Chain, Scatter->getValue(), + Mask, Base, Index, Scale }; + return DAG.getMaskedScatter(Scatter->getVTList(), + Scatter->getMemoryVT(), DL, + Ops, Scatter->getMemOperand(), + Scatter->getIndexType()); } } // With vector masks we only demand the upper bit of the mask. - SDValue Mask = cast(N)->getMask(); if (Mask.getScalarValueSizeInBits() != 1) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));