]> granicus.if.org Git - llvm/commitdiff
[X86][SSE] Generalize INSERTPS/SHUFPS/SHUFPD combines across domains.
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 19 Feb 2017 15:15:40 +0000 (15:15 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 19 Feb 2017 15:15:40 +0000 (15:15 +0000)
Relax the INSERTPS/SHUFPS/SHUFPD combines to support integer inputs if permitted.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@295606 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp

index 1ff304457163c40fdaab43547664c6a557b90540..88651a2d22790679186748192e11c220dbbd51c0 100644 (file)
@@ -12235,7 +12235,7 @@ static bool matchVectorShuffleWithSHUFPD(MVT VT, SDValue &V1, SDValue &V2,
                                          unsigned &ShuffleImm,
                                          ArrayRef<int> Mask) {
   int NumElts = VT.getVectorNumElements();
-  assert(VT.getScalarType() == MVT::f64 &&
+  assert(VT.getScalarSizeInBits() == 64 &&
          (NumElts == 2 || NumElts == 4 || NumElts == 8) &&
          "Unexpected data type for VSHUFPD");
 
@@ -12271,6 +12271,9 @@ static bool matchVectorShuffleWithSHUFPD(MVT VT, SDValue &V1, SDValue &V2,
 static SDValue lowerVectorShuffleWithSHUFPD(const SDLoc &DL, MVT VT,
                                             ArrayRef<int> Mask, SDValue V1,
                                             SDValue V2, SelectionDAG &DAG) {
+  assert((VT == MVT::v2f64 || VT == MVT::v4f64 || VT == MVT::v8f64)&&
+         "Unexpected data type for VSHUFPD");
+
   unsigned Immediate = 0;
   if (!matchVectorShuffleWithSHUFPD(VT, V1, V2, Immediate, Mask))
     return SDValue();
@@ -26691,13 +26694,15 @@ static bool matchBinaryVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
 }
 
 static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
-                                            bool AllowIntDomain, SDValue &V1,
-                                            SDValue &V2, SDLoc &DL,
+                                            bool AllowFloatDomain,
+                                            bool AllowIntDomain,
+                                            SDValue &V1, SDValue &V2, SDLoc &DL,
                                             SelectionDAG &DAG,
                                             const X86Subtarget &Subtarget,
                                             unsigned &Shuffle, MVT &ShuffleVT,
                                             unsigned &PermuteImm) {
   unsigned NumMaskElts = Mask.size();
+  unsigned EltSizeInBits = MaskVT.getScalarSizeInBits();
 
   // Attempt to match against PALIGNR byte rotate.
   if (AllowIntDomain && ((MaskVT.is128BitVector() && Subtarget.hasSSSE3()) ||
@@ -26777,7 +26782,8 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
   }
 
   // Attempt to combine to INSERTPS.
-  if (Subtarget.hasSSE41() && MaskVT == MVT::v4f32) {
+  if (AllowFloatDomain && EltSizeInBits == 32 && Subtarget.hasSSE41() &&
+      MaskVT.is128BitVector()) {
     SmallBitVector Zeroable(4, false);
     for (unsigned i = 0; i != NumMaskElts; ++i)
       if (Mask[i] < 0)
@@ -26792,20 +26798,22 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
   }
 
   // Attempt to combine to SHUFPD.
-  if ((MaskVT == MVT::v2f64 && Subtarget.hasSSE2()) ||
-      (MaskVT == MVT::v4f64 && Subtarget.hasAVX()) ||
-      (MaskVT == MVT::v8f64 && Subtarget.hasAVX512())) {
+  if (AllowFloatDomain && EltSizeInBits == 64 &&
+      ((MaskVT.is128BitVector() && Subtarget.hasSSE2()) ||
+       (MaskVT.is256BitVector() && Subtarget.hasAVX()) ||
+       (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) {
     if (matchVectorShuffleWithSHUFPD(MaskVT, V1, V2, PermuteImm, Mask)) {
       Shuffle = X86ISD::SHUFP;
-      ShuffleVT = MaskVT;
+      ShuffleVT = MVT::getVectorVT(MVT::f64, MaskVT.getSizeInBits() / 64);
       return true;
     }
   }
 
   // Attempt to combine to SHUFPS.
-  if ((MaskVT == MVT::v4f32 && Subtarget.hasSSE1()) ||
-      (MaskVT == MVT::v8f32 && Subtarget.hasAVX()) ||
-      (MaskVT == MVT::v16f32 && Subtarget.hasAVX512())) {
+  if (AllowFloatDomain && EltSizeInBits == 32 &&
+      ((MaskVT.is128BitVector() && Subtarget.hasSSE1()) ||
+       (MaskVT.is256BitVector() && Subtarget.hasAVX()) ||
+       (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) {
     SmallVector<int, 4> RepeatedMask;
     if (isRepeatedTargetShuffleMask(128, MaskVT, Mask, RepeatedMask)) {
       // Match each half of the repeated mask, to determine if its just
@@ -26841,7 +26849,7 @@ static bool matchBinaryPermuteVectorShuffle(MVT MaskVT, ArrayRef<int> Mask,
         V1 = Lo;
         V2 = Hi;
         Shuffle = X86ISD::SHUFP;
-        ShuffleVT = MaskVT;
+        ShuffleVT = MVT::getVectorVT(MVT::f32, MaskVT.getSizeInBits() / 32);
         PermuteImm = getV4X86ShuffleImm(ShufMask);
         return true;
       }
@@ -27035,8 +27043,9 @@ static bool combineX86ShuffleChain(ArrayRef<SDValue> Inputs, SDValue Root,
     return true;
   }
 
-  if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, AllowIntDomain, V1, V2, DL,
-                                      DAG, Subtarget, Shuffle, ShuffleVT,
+  if (matchBinaryPermuteVectorShuffle(MaskVT, Mask, AllowFloatDomain,
+                                      AllowIntDomain, V1, V2, DL, DAG,
+                                      Subtarget, Shuffle, ShuffleVT,
                                       PermuteImm)) {
     if (Depth == 1 && Root.getOpcode() == Shuffle)
       return false; // Nothing to do!