]> granicus.if.org Git - llvm/commitdiff
[TargetLowering] Add buildLegalVectorShuffle facility to help build legal shuffles
authorAmaury Sechet <deadalnix@gmail.com>
Wed, 28 Aug 2019 12:00:06 +0000 (12:00 +0000)
committerAmaury Sechet <deadalnix@gmail.com>
Wed, 28 Aug 2019 12:00:06 +0000 (12:00 +0000)
Summary: There are at least 2 ways to express the same shuffle. Various pieces of code explicit check for both option, but other places do not when they would benefit from doing it. This patches refactor the codebase to use buildLegalVectorShuffle in order to make that behavior more consistent.

Reviewers: craig.topper, efriedma, RKSimon, lebedev.ri

Subscribers: javed.absar, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D66804

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@370190 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/TargetLowering.h
lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/CodeGen/SelectionDAG/TargetLowering.cpp
lib/Target/ARM/ARMISelLowering.cpp

index 292b5a38ad37b488997963aa54a8410747176768..87145a7d78f6d2366ab7e8447f57af9394b079c3 100644 (file)
@@ -3192,6 +3192,14 @@ public:
       SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
       SelectionDAG &DAG, unsigned Depth) const;
 
+  /// Tries to build a legal vector shuffle using the provided parameters
+  /// or equivalent variations. The Mask argument maybe be modified as the
+  /// function tries different variations.
+  /// Returns an empty SDValue if the operation fails.
+  SDValue buildLegalVectorShuffle(EVT VT, const SDLoc &DL, SDValue N0,
+                                  SDValue N1, MutableArrayRef<int> Mask,
+                                  SelectionDAG &DAG) const;
+
   /// This method returns the constant pool value that will be loaded by LD.
   /// NOTE: You must check for implicit extensions of the constant by LD.
   virtual const Constant *getTargetConstantFromLoad(LoadSDNode *LD) const;
index 959493170e7c21b8e804e3fc3fd0f2d5499a8e84..62368501d9ac7f41702401a9342e94b4dd8124a0 100644 (file)
@@ -5904,15 +5904,11 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
           SDValue NewLHS = ZeroN00 ? N0.getOperand(1) : N0.getOperand(0);
           SDValue NewRHS = ZeroN10 ? N1.getOperand(1) : N1.getOperand(0);
 
-          bool LegalMask = TLI.isShuffleMaskLegal(Mask, VT);
-          if (!LegalMask) {
-            std::swap(NewLHS, NewRHS);
-            ShuffleVectorSDNode::commuteMask(Mask);
-            LegalMask = TLI.isShuffleMaskLegal(Mask, VT);
-          }
-
-          if (LegalMask)
-            return DAG.getVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS, Mask);
+          SDValue LegalShuffle =
+              TLI.buildLegalVectorShuffle(VT, SDLoc(N), NewLHS, NewRHS,
+                                          Mask, DAG);
+          if (LegalShuffle)
+            return LegalShuffle;
         }
       }
     }
@@ -11120,15 +11116,10 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
       for (int i = 0; i != MaskScale; ++i)
         NewMask.push_back(M < 0 ? -1 : M * MaskScale + i);
 
-    bool LegalMask = TLI.isShuffleMaskLegal(NewMask, VT);
-    if (!LegalMask) {
-      std::swap(SV0, SV1);
-      ShuffleVectorSDNode::commuteMask(NewMask);
-      LegalMask = TLI.isShuffleMaskLegal(NewMask, VT);
-    }
-
-    if (LegalMask)
-      return DAG.getVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask);
+    SDValue LegalShuffle =
+        TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask, DAG);
+    if (LegalShuffle)
+      return LegalShuffle;
   }
 
   return SDValue();
@@ -17252,17 +17243,16 @@ static SDValue reduceBuildVecToShuffleWithZero(SDNode *BV, SelectionDAG &DAG) {
     // the shuffle mask with -1.
   }
 
-  // Turn this into a shuffle with zero if that's legal.
-  EVT VecVT = Extract.getOperand(0).getValueType();
-  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(ShufMask, VecVT))
-    return SDValue();
-
   // buildvec undef, ..., (zext (extractelt V, IndexC)), undef... -->
   // bitcast (shuffle V, ZeroVec, VectorMask)
   SDLoc DL(BV);
+  EVT VecVT = Extract.getOperand(0).getValueType();
   SDValue ZeroVec = DAG.getConstant(0, DL, VecVT);
-  SDValue Shuf = DAG.getVectorShuffle(VecVT, DL, Extract.getOperand(0), ZeroVec,
-                                      ShufMask);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  SDValue Shuf = TLI.buildLegalVectorShuffle(VecVT, DL, Extract.getOperand(0),
+                                             ZeroVec, ShufMask, DAG);
+  if (!Shuf)
+    return SDValue();
   return DAG.getBitcast(VT, Shuf);
 }
 
@@ -17737,11 +17727,9 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) {
     }
   }
 
-  if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(Mask, VT))
-    return SDValue();
-
-  return DAG.getVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
-                              DAG.getBitcast(VT, SV1), Mask);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  return TLI.buildLegalVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0),
+                                     DAG.getBitcast(VT, SV1), Mask, DAG);
 }
 
 SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
@@ -19077,22 +19065,13 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
       SV1 = DAG.getUNDEF(VT);
 
     // Avoid introducing shuffles with illegal mask.
-    if (!TLI.isShuffleMaskLegal(Mask, VT)) {
-      ShuffleVectorSDNode::commuteMask(Mask);
-
-      if (!TLI.isShuffleMaskLegal(Mask, VT))
-        return SDValue();
-
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
-      //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
-      std::swap(SV0, SV1);
-    }
-
     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, B, M2)
     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(A, C, M2)
     //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, C, M2)
-    return DAG.getVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask);
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(B, A, M2)
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, A, M2)
+    //   shuffle(shuffle(A, B, M0), C, M1) -> shuffle(C, B, M2)
+    return TLI.buildLegalVectorShuffle(VT, SDLoc(N), SV0, SV1, Mask, DAG);
   }
 
   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
@@ -19115,35 +19094,35 @@ SDValue DAGCombiner::visitSCALAR_TO_VECTOR(SDNode *N) {
       SmallVector<int, 8> NewMask(InVecT.getVectorNumElements(), -1);
       int Elt = C0->getZExtValue();
       NewMask[0] = Elt;
-      SDValue Val;
       // If we have an implict truncate do truncate here as long as it's legal.
       // if it's not legal, this should
       if (VT.getScalarType() != InVal.getValueType() &&
           InVal.getValueType().isScalarInteger() &&
           isTypeLegal(VT.getScalarType())) {
-        Val =
+        SDValue Val =
             DAG.getNode(ISD::TRUNCATE, SDLoc(InVal), VT.getScalarType(), InVal);
         return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), VT, Val);
       }
       if (VT.getScalarType() == InVecT.getScalarType() &&
-          VT.getVectorNumElements() <= InVecT.getVectorNumElements() &&
-          TLI.isShuffleMaskLegal(NewMask, VT)) {
-        Val = DAG.getVectorShuffle(InVecT, SDLoc(N), InVec,
-                                   DAG.getUNDEF(InVecT), NewMask);
-        // If the initial vector is the correct size this shuffle is a
-        // valid result.
-        if (VT == InVecT)
-          return Val;
-        // If not we must truncate the vector.
-        if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
-          MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
-          SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy);
-          EVT SubVT =
-              EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(),
-                               VT.getVectorNumElements());
-          Val = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT, Val,
-                            ZeroIdx);
-          return Val;
+          VT.getVectorNumElements() <= InVecT.getVectorNumElements()) {
+        SDValue LegalShuffle =
+          TLI.buildLegalVectorShuffle(InVecT, SDLoc(N), InVec,
+                                      DAG.getUNDEF(InVecT), NewMask, DAG);
+        if (LegalShuffle) {
+          // If the initial vector is the correct size this shuffle is a
+          // valid result.
+          if (VT == InVecT)
+            return LegalShuffle;
+          // If not we must truncate the vector.
+          if (VT.getVectorNumElements() != InVecT.getVectorNumElements()) {
+            MVT IdxTy = TLI.getVectorIdxTy(DAG.getDataLayout());
+            SDValue ZeroIdx = DAG.getConstant(0, SDLoc(N), IdxTy);
+            EVT SubVT =
+                EVT::getVectorVT(*DAG.getContext(), InVecT.getVectorElementType(),
+                                 VT.getVectorNumElements());
+            return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), SubVT,
+                               LegalShuffle, ZeroIdx);
+          }
         }
       }
     }
index f4438d41f65355f5d90515b606dde22345ff0af4..6fd90ef3301db53e8fe5342f075ea202e0f67e30 100644 (file)
@@ -2382,11 +2382,13 @@ bool TargetLowering::SimplifyDemandedVectorElts(
 
     // Update legal shuffle masks based on demanded elements if it won't reduce
     // to Identity which can cause premature removal of the shuffle mask.
-    if (Updated && !IdentityLHS && !IdentityRHS && !TLO.LegalOps &&
-        isShuffleMaskLegal(NewMask, VT))
-      return TLO.CombineTo(Op,
-                           TLO.DAG.getVectorShuffle(VT, DL, Op.getOperand(0),
-                                                    Op.getOperand(1), NewMask));
+    if (Updated && !IdentityLHS && !IdentityRHS && !TLO.LegalOps) {
+      SDValue LegalShuffle =
+          buildLegalVectorShuffle(VT, DL, Op.getOperand(0), Op.getOperand(1),
+                                  NewMask, TLO.DAG);
+      if (LegalShuffle)
+        return TLO.CombineTo(Op, LegalShuffle);
+    }
 
     // Propagate undef/zero elements from LHS/RHS.
     for (unsigned i = 0; i != NumElts; ++i) {
@@ -2624,6 +2626,23 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
   return SDValue();
 }
 
+SDValue
+TargetLowering::buildLegalVectorShuffle(EVT VT, const SDLoc &DL, SDValue N0,
+                                        SDValue N1, MutableArrayRef<int> Mask,
+                                        SelectionDAG &DAG) const {
+  bool LegalMask = isShuffleMaskLegal(Mask, VT);
+  if (!LegalMask) {
+    std::swap(N0, N1);
+    ShuffleVectorSDNode::commuteMask(Mask);
+    LegalMask = isShuffleMaskLegal(Mask, VT);
+  }
+
+  if (!LegalMask)
+    return SDValue();
+
+  return DAG.getVectorShuffle(VT, DL, N0, N1, Mask);
+}
+
 const Constant *TargetLowering::getTargetConstantFromLoad(LoadSDNode*) const {
   return nullptr;
 }
index 275859a6b912f302f6c282d1ff57dcb360603fb0..a8c8f2123391ba392a8149e6db8fc2e7a4220ea5 100644 (file)
@@ -7323,9 +7323,6 @@ SDValue ARMTargetLowering::ReconstructShuffle(SDValue Op,
       LaneMask[j] = ExtractBase + j;
   }
 
-  // Final check before we try to produce nonsense...
-  if (!isShuffleMaskLegal(Mask, ShuffleVT))
-    return SDValue();
 
   // We can't handle more than two sources. This should have already
   // been checked before this point.
@@ -7335,8 +7332,10 @@ SDValue ARMTargetLowering::ReconstructShuffle(SDValue Op,
   for (unsigned i = 0; i < Sources.size(); ++i)
     ShuffleOps[i] = Sources[i].ShuffleVec;
 
-  SDValue Shuffle = DAG.getVectorShuffle(ShuffleVT, dl, ShuffleOps[0],
-                                         ShuffleOps[1], Mask);
+  SDValue Shuffle = buildLegalVectorShuffle(ShuffleVT, dl, ShuffleOps[0],
+                                            ShuffleOps[1], Mask, DAG);
+  if (!Shuffle)
+    return SDValue();
   return DAG.getNode(ISD::BITCAST, dl, VT, Shuffle);
 }