]> granicus.if.org Git - llvm/commitdiff
[X86][SSE] Improve lowering of vXi64 multiply with known zero 32-bit halves
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 17 Nov 2016 12:14:49 +0000 (12:14 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Thu, 17 Nov 2016 12:14:49 +0000 (12:14 +0000)
vXi64 multiplication is lowered into 3 calls of vpmuludq with the upper/lower 32-bit halves.

If any of these halves are zero then we can remove individual calls. Although there was isBuildVectorAllZeros code to do this I don't think it ever worked (maybe just for constant folded cases that don't seem to be tested for any longer).

This requires additional X86ISD support for computeKnownBitsForTargetNode, so far I've just added support for X86ISD::VZEXT (VPMOVZX* - helping the AVX2+ cases).

Partial fix for PR30845

Differential Revision: https://reviews.llvm.org/D26590

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@287223 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/pmul.ll

index 5d81debb5f6f11318777072185cf978572e43a4c..838e75b4b36edc8084ef7a6927d0ac911be166a3 100644 (file)
@@ -20144,33 +20144,43 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
   //  AloBhi = psllqi(AloBhi, 32);
   //  AhiBlo = psllqi(AhiBlo, 32);
   //  return AloBlo + AloBhi + AhiBlo;
+  APInt LowerBitsMask = APInt::getLowBitsSet(64, 32);
+  bool ALoiIsZero = DAG.MaskedValueIsZero(A, LowerBitsMask);
+  bool BLoiIsZero = DAG.MaskedValueIsZero(B, LowerBitsMask);
 
-  SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
-  SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
+  APInt UpperBitsMask = APInt::getHighBitsSet(64, 32);
+  bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask);
+  bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask);
 
-  SDValue AhiBlo = Ahi;
-  SDValue AloBhi = Bhi;
   // Bit cast to 32-bit vectors for MULUDQ
   MVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 :
                                   (VT == MVT::v4i64) ? MVT::v8i32 : MVT::v16i32;
-  A = DAG.getBitcast(MulVT, A);
-  B = DAG.getBitcast(MulVT, B);
-  Ahi = DAG.getBitcast(MulVT, Ahi);
-  Bhi = DAG.getBitcast(MulVT, Bhi);
-
-  SDValue AloBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, B);
-  // After shifting right const values the result may be all-zero.
-  if (!ISD::isBuildVectorAllZeros(Ahi.getNode())) {
-    AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
-    AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
-  }
-  if (!ISD::isBuildVectorAllZeros(Bhi.getNode())) {
-    AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
+  SDValue Alo = DAG.getBitcast(MulVT, A);
+  SDValue Blo = DAG.getBitcast(MulVT, B);
+
+  SDValue Res;
+
+  // Only multiply lo/hi halves that aren't known to be zero.
+  if (!ALoiIsZero && !BLoiIsZero)
+    Res = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Blo);
+
+  if (!ALoiIsZero && !BHiIsZero) {
+    SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
+    Bhi = DAG.getBitcast(MulVT, Bhi);
+    SDValue AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Alo, Bhi);
     AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG);
+    Res = (Res.getNode() ? DAG.getNode(ISD::ADD, dl, VT, Res, AloBhi) : AloBhi);
   }
 
-  SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi);
-  return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo);
+  if (!AHiIsZero && !BLoiIsZero) {
+    SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
+    Ahi = DAG.getBitcast(MulVT, Ahi);
+    SDValue AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, Blo);
+    AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
+    Res = (Res.getNode() ? DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo) : AhiBlo);
+  }
+
+  return (Res.getNode() ? Res : getZeroVector(VT, Subtarget, DAG, dl));
 }
 
 static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget,
@@ -25256,6 +25266,20 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - NumLoBits);
     break;
   }
+  case X86ISD::VZEXT: {
+    SDValue N0 = Op.getOperand(0);
+    unsigned NumElts = Op.getValueType().getVectorNumElements();
+    unsigned InNumElts = N0.getValueType().getVectorNumElements();
+    unsigned InBitWidth = N0.getValueType().getScalarSizeInBits();
+
+    KnownZero = KnownOne = APInt(InBitWidth, 0);
+    APInt DemandedElts = APInt::getLowBitsSet(InNumElts, NumElts);
+    DAG.computeKnownBits(N0, KnownZero, KnownOne, DemandedElts, Depth + 1);
+    KnownOne = KnownOne.zext(BitWidth);
+    KnownZero = KnownZero.zext(BitWidth);
+    KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - InBitWidth);
+    break;
+  }
   }
 }
 
index 2117c9de5ff2e8bb04b2c145907388fdf8a9acdf..b72f6cf6328dba2d02e48ee15c3ce9efcfa15a69 100644 (file)
@@ -1225,15 +1225,7 @@ define <4 x i32> @mul_v4i64_zero_upper(<4 x i32> %val1, <4 x i32> %val2) {
 ; AVX2:       # BB#0: # %entry
 ; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
-; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX2-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX2-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX2-NEXT:    vpsrlq $32, %ymm0, %ymm0
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
-; AVX2-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
@@ -1245,15 +1237,7 @@ define <4 x i32> @mul_v4i64_zero_upper(<4 x i32> %val1, <4 x i32> %val2) {
 ; AVX512:       # BB#0: # %entry
 ; AVX512-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX512-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
-; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX512-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX512-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX512-NEXT:    vpsrlq $32, %ymm0, %ymm0
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
-; AVX512-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
@@ -1338,13 +1322,9 @@ define <4 x i32> @mul_v4i64_zero_upper_left(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX2:       # BB#0: # %entry
 ; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX2-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX2-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX2-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX2-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
@@ -1357,13 +1337,9 @@ define <4 x i32> @mul_v4i64_zero_upper_left(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX512:       # BB#0: # %entry
 ; AVX512-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX512-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX512-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX512-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX512-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
@@ -1455,13 +1431,9 @@ define <4 x i32> @mul_v4i64_zero_lower(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX2-NEXT:    vpxor %ymm2, %ymm2, %ymm2
 ; AVX2-NEXT:    vpblendd {{.*#+}} ymm1 = ymm2[0],ymm1[1],ymm2[2],ymm1[3],ymm2[4],ymm1[5],ymm2[6],ymm1[7]
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX2-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX2-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX2-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX2-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX2-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX2-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX2-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX2-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]
@@ -1476,13 +1448,9 @@ define <4 x i32> @mul_v4i64_zero_lower(<4 x i32> %val1, <4 x i64> %val2) {
 ; AVX512-NEXT:    vpxor %ymm2, %ymm2, %ymm2
 ; AVX512-NEXT:    vpblendd {{.*#+}} ymm1 = ymm2[0],ymm1[1],ymm2[2],ymm1[3],ymm2[4],ymm1[5],ymm2[6],ymm1[7]
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm2
-; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm3
-; AVX512-NEXT:    vpmuludq %ymm3, %ymm0, %ymm3
-; AVX512-NEXT:    vpsllq $32, %ymm3, %ymm3
-; AVX512-NEXT:    vpsrlq $32, %ymm0, %ymm0
+; AVX512-NEXT:    vpsrlq $32, %ymm1, %ymm1
 ; AVX512-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX512-NEXT:    vpsllq $32, %ymm0, %ymm0
-; AVX512-NEXT:    vpaddq %ymm0, %ymm3, %ymm0
 ; AVX512-NEXT:    vpaddq %ymm0, %ymm2, %ymm0
 ; AVX512-NEXT:    vextracti128 $1, %ymm0, %xmm1
 ; AVX512-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,1,3]