From: Simon Pilgrim
Date: Thu, 20 Jun 2019 14:42:27 +0000 (+0000)
Subject: [DAGCombiner] Support (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C)) non...
X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=702c4f09d38b53b2ea22e959e09ec2306ca54ff3;p=llvm

[DAGCombiner] Support (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C)) non-uniform folds.

Use matchBinaryPredicate instead of isConstOrConstSplat to let us handle non-uniform shift cases.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@363929 91177308-0d34-0410-b5e6-96231b3b80d8
---

diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ce568e331cf..4b50ea4785c 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7259,25 +7259,27 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
   // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
   // Only fold this if the inner zext has no other uses to avoid increasing
   // the total number of instructions.
-  // TODO - support non-uniform vector shift amounts.
-  if (N1C && N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
+  if (N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
       N0.getOperand(0).getOpcode() == ISD::SRL) {
     SDValue N0Op0 = N0.getOperand(0);
-    if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) {
-      if (N0Op0C1->getAPIntValue().ult(VT.getScalarSizeInBits())) {
-        uint64_t c1 = N0Op0C1->getZExtValue();
-        uint64_t c2 = N1C->getZExtValue();
-        if (c1 == c2) {
-          SDValue NewOp0 = N0.getOperand(0);
-          EVT CountVT = NewOp0.getOperand(1).getValueType();
-          SDLoc DL(N);
-          SDValue NewSHL = DAG.getNode(ISD::SHL, DL, NewOp0.getValueType(),
-                                       NewOp0,
-                                       DAG.getConstant(c2, DL, CountVT));
-          AddToWorklist(NewSHL.getNode());
-          return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
-        }
-      }
+    SDValue InnerShiftAmt = N0Op0.getOperand(1);
+
+    auto MatchEqual = [VT](ConstantSDNode *LHS, ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
+      zeroExtendToMatch(c1, c2);
+      return c1.ult(VT.getScalarSizeInBits()) && (c1 == c2);
+    };
+    if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchEqual,
+                                  /*AllowUndefs*/ false,
+                                  /*AllowTypeMismatch*/ true)) {
+      SDLoc DL(N);
+      EVT InnerShiftAmtVT = N0Op0.getOperand(1).getValueType();
+      SDValue NewSHL = DAG.getZExtOrTrunc(N1, DL, InnerShiftAmtVT);
+      NewSHL = DAG.getNode(ISD::SHL, DL, N0Op0.getValueType(), N0Op0, NewSHL);
+      AddToWorklist(NewSHL.getNode());
+      return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
     }
   }
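The change hinges on ISD::matchBinaryPredicate, which walks the two shift-amount operands lane by lane rather than requiring isConstOrConstSplat to produce one uniform constant. Below is a minimal sketch of that lane-wise walk; the name matchBinaryPredicateSketch and the simplifications are illustrative only, since the real helper also accepts scalar constants, undef lanes (AllowUndefs), and differing shift-amount types (AllowTypeMismatch), both of which the call above exercises.

#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"

using namespace llvm;

// Sketch of the lane-wise matching that ISD::matchBinaryPredicate performs,
// restricted to the vector case: both shift-amount operands must be
// BUILD_VECTORs of constants with the same lane count.
static bool matchBinaryPredicateSketch(
    SDValue LHS, SDValue RHS,
    function_ref<bool(ConstantSDNode *, ConstantSDNode *)> Match) {
  if (LHS.getOpcode() != ISD::BUILD_VECTOR ||
      RHS.getOpcode() != ISD::BUILD_VECTOR ||
      LHS.getNumOperands() != RHS.getNumOperands())
    return false;
  // Every lane pair must be constant and satisfy the predicate; a single
  // non-constant or mismatched lane rejects the whole match.
  for (unsigned I = 0, E = LHS.getNumOperands(); I != E; ++I) {
    auto *L = dyn_cast<ConstantSDNode>(LHS.getOperand(I));
    auto *R = dyn_cast<ConstantSDNode>(RHS.getOperand(I));
    if (!L || !R || !Match(L, R))
      return false;
  }
  return true;
}

Rejecting the fold on any bad lane keeps the transform conservative: the zext is only pushed through the shl when every lane's srl amount provably equals its shl amount, which is exactly what MatchEqual checks.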
diff --git a/test/CodeGen/X86/combine-shl.ll b/test/CodeGen/X86/combine-shl.ll
index 44eea8c9946..55ad952c359 100644
--- a/test/CodeGen/X86/combine-shl.ll
+++ b/test/CodeGen/X86/combine-shl.ll
@@ -363,44 +363,30 @@ define <8 x i32> @combine_vec_shl_zext_lshr0(<8 x i16> %x) {
 define <8 x i32> @combine_vec_shl_zext_lshr1(<8 x i16> %x) {
 ; SSE2-LABEL: combine_vec_shl_zext_lshr1:
 ; SSE2:       # %bb.0:
-; SSE2-NEXT:    pmulhuw {{.*}}(%rip), %xmm0
-; SSE2-NEXT:    pxor %xmm1, %xmm1
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    punpckhwd {{.*#+}} xmm2 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7]
-; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; SSE2-NEXT:    movdqa {{.*#+}} xmm1 = [2,4,8,16]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
-; SSE2-NEXT:    pmuludq %xmm1, %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[1,1,3,3]
-; SSE2-NEXT:    pmuludq %xmm3, %xmm1
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
-; SSE2-NEXT:    movdqa {{.*#+}} xmm3 = [32,64,128,256]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm4 = xmm2[1,1,3,3]
-; SSE2-NEXT:    pmuludq %xmm3, %xmm2
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm2[0,2,2,3]
-; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm3[1,1,3,3]
-; SSE2-NEXT:    pmuludq %xmm4, %xmm2
-; SSE2-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
-; SSE2-NEXT:    punpckldq {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1]
+; SSE2-NEXT:    movdqa %xmm0, %xmm1
+; SSE2-NEXT:    pmulhuw {{.*}}(%rip), %xmm1
+; SSE2-NEXT:    pxor %xmm2, %xmm2
+; SSE2-NEXT:    pmullw {{.*}}(%rip), %xmm1
+; SSE2-NEXT:    movdqa %xmm1, %xmm0
+; SSE2-NEXT:    punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
+; SSE2-NEXT:    punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7]
 ; SSE2-NEXT:    retq
 ;
 ; SSE41-LABEL: combine_vec_shl_zext_lshr1:
 ; SSE41:       # %bb.0:
 ; SSE41-NEXT:    pmulhuw {{.*}}(%rip), %xmm0
-; SSE41-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
-; SSE41-NEXT:    pmovzxwd {{.*#+}} xmm1 = xmm1[0],zero,xmm1[1],zero,xmm1[2],zero,xmm1[3],zero
-; SSE41-NEXT:    pmovzxwd {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
-; SSE41-NEXT:    pmulld {{.*}}(%rip), %xmm0
-; SSE41-NEXT:    pmulld {{.*}}(%rip), %xmm1
+; SSE41-NEXT:    pmullw {{.*}}(%rip), %xmm0
+; SSE41-NEXT:    pmovzxwd {{.*#+}} xmm2 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; SSE41-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,0,1]
+; SSE41-NEXT:    pmovzxwd {{.*#+}} xmm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; SSE41-NEXT:    movdqa %xmm2, %xmm0
 ; SSE41-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_shl_zext_lshr1:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vpmulhuw {{.*}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    vpmullw {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
-; AVX-NEXT:    vpsllvd {{.*}}(%rip), %ymm0, %ymm0
 ; AVX-NEXT:    retq
   %1 = lshr <8 x i16> %x, <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>
   %2 = zext <8 x i16> %1 to <8 x i32>
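A subtlety in MatchEqual above: the inner srl amounts are i16 lane constants while the outer shl amounts are i32 lane constants, so the raw APInts have different bit widths. That is why the lambda widens them with zeroExtendToMatch before comparing, and why matchBinaryPredicate is invoked with AllowTypeMismatch set. A simplified sketch of the behavior zeroExtendToMatch is assumed to have (the sketch name and body are illustrative, not the DAGCombiner.cpp implementation):

#include "llvm/ADT/APInt.h"

using llvm::APInt;

// Illustrative stand-in for the zeroExtendToMatch helper used by MatchEqual:
// zero-extend the narrower constant so shift amounts of different scalar
// types (e.g. an i16 srl amount vs. an i32 shl amount) compare lane-for-lane.
static void zeroExtendToMatchSketch(APInt &LHS, APInt &RHS) {
  if (LHS.getBitWidth() < RHS.getBitWidth())
    LHS = LHS.zext(RHS.getBitWidth());
  else if (RHS.getBitWidth() < LHS.getBitWidth())
    RHS = RHS.zext(LHS.getBitWidth());
}

Once both constants share a width, c1 == c2 together with c1.ult(VT.getScalarSizeInBits()) guarantees each lane shifts by the same in-range amount, which is the condition under which (shl (zext (srl x, C)), C) can safely become (zext (shl (srl x, C), C)). The test above shows the payoff: the per-lane amounts now fold into a single narrow pmullw/vpmullw instead of widened pmulld or vpsllvd sequences.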