From c69190afafc22ac1c8ced5b42987f308d2e6c9f1 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Wed, 19 Jun 2019 12:25:29 +0000 Subject: [PATCH] [DAGCombiner] Support (shl (ext (shl x, c1)), c2) -> 0 non-uniform folds. Use matchBinaryPredicate instead of isConstOrConstSplat to let us handle non-uniform shift cases. This requires us to tweak matchBinaryPredicate to allow it to (optionally) handle constants with different type widths. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@363792 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/CodeGen/SelectionDAGNodes.h | 3 ++- lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 30 +++++++++++++++++------ lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 8 +++--- test/CodeGen/X86/combine-shl.ll | 29 +++++----------------- 4 files changed, 35 insertions(+), 35 deletions(-) diff --git a/include/llvm/CodeGen/SelectionDAGNodes.h b/include/llvm/CodeGen/SelectionDAGNodes.h index c26b1a86c93..5aab9643e09 100644 --- a/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/include/llvm/CodeGen/SelectionDAGNodes.h @@ -2617,10 +2617,11 @@ namespace ISD { /// Attempt to match a binary predicate against a pair of scalar/splat /// constants or every element of a pair of constant BUILD_VECTORs. /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match. + /// If AllowTypeMismatch is true then RetType + ArgTypes don't need to match. bool matchBinaryPredicate( SDValue LHS, SDValue RHS, std::function Match, - bool AllowUndefs = false); + bool AllowUndefs = false, bool AllowTypeMismatch = false); } // end namespace ISD } // end namespace llvm diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 676cfbc859e..ae555ac0711 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -7210,19 +7210,35 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // that are shifted out by the inner shift in the first form. This means // the outer shift size must be >= the number of bits added by the ext. // As a corollary, we don't care what kind of ext it is. - if (N1C && (N0.getOpcode() == ISD::ZERO_EXTEND || - N0.getOpcode() == ISD::ANY_EXTEND || - N0.getOpcode() == ISD::SIGN_EXTEND) && + if ((N0.getOpcode() == ISD::ZERO_EXTEND || + N0.getOpcode() == ISD::ANY_EXTEND || + N0.getOpcode() == ISD::SIGN_EXTEND) && N0.getOperand(0).getOpcode() == ISD::SHL) { SDValue N0Op0 = N0.getOperand(0); - if (ConstantSDNode *N0Op0C1 = isConstOrConstSplat(N0Op0.getOperand(1))) { + SDValue InnerShiftAmt = N0Op0.getOperand(1); + EVT InnerVT = N0Op0.getValueType(); + uint64_t InnerBitwidth = InnerVT.getScalarSizeInBits(); + + auto MatchOutOfRange = [OpSizeInBits, InnerBitwidth](ConstantSDNode *LHS, + ConstantSDNode *RHS) { + APInt c1 = LHS->getAPIntValue(); + APInt c2 = RHS->getAPIntValue(); + zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); + return c2.uge(OpSizeInBits - InnerBitwidth) && + (c1 + c2).uge(OpSizeInBits); + }; + if (ISD::matchBinaryPredicate(InnerShiftAmt, N1, MatchOutOfRange, + /*AllowUndefs*/ false, + /*AllowTypeMismatch*/ true)) + return DAG.getConstant(0, SDLoc(N), VT); + + ConstantSDNode *N0Op0C1 = isConstOrConstSplat(InnerShiftAmt); + if (N1C && N0Op0C1) { APInt c1 = N0Op0C1->getAPIntValue(); APInt c2 = N1C->getAPIntValue(); zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */); - EVT InnerShiftVT = N0Op0.getValueType(); - uint64_t InnerShiftSize = InnerShiftVT.getScalarSizeInBits(); - if (c2.uge(OpSizeInBits - InnerShiftSize)) { + if (c2.uge(OpSizeInBits - InnerBitwidth)) { SDLoc DL(N0); APInt Sum = c1 + c2; if (Sum.uge(OpSizeInBits)) diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index fea8d7ad894..a2eca91c67e 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -294,8 +294,8 @@ bool ISD::matchUnaryPredicate(SDValue Op, bool ISD::matchBinaryPredicate( SDValue LHS, SDValue RHS, std::function Match, - bool AllowUndefs) { - if (LHS.getValueType() != RHS.getValueType()) + bool AllowUndefs, bool AllowTypeMismatch) { + if (!AllowTypeMismatch && LHS.getValueType() != RHS.getValueType()) return false; // TODO: Add support for scalar UNDEF cases? @@ -318,8 +318,8 @@ bool ISD::matchBinaryPredicate( auto *RHSCst = dyn_cast(RHSOp); if ((!LHSCst && !LHSUndef) || (!RHSCst && !RHSUndef)) return false; - if (LHSOp.getValueType() != SVT || - LHSOp.getValueType() != RHSOp.getValueType()) + if (!AllowTypeMismatch && (LHSOp.getValueType() != SVT || + LHSOp.getValueType() != RHSOp.getValueType())) return false; if (!Match(LHSCst, RHSCst)) return false; diff --git a/test/CodeGen/X86/combine-shl.ll b/test/CodeGen/X86/combine-shl.ll index 8d48f180c14..0dd6426543c 100644 --- a/test/CodeGen/X86/combine-shl.ll +++ b/test/CodeGen/X86/combine-shl.ll @@ -264,33 +264,16 @@ define <8 x i32> @combine_vec_shl_ext_shl0(<8 x i16> %x) { ret <8 x i32> %3 } -; TODO - this should fold to ZERO. define <8 x i32> @combine_vec_shl_ext_shl1(<8 x i16> %x) { -; SSE2-LABEL: combine_vec_shl_ext_shl1: -; SSE2: # %bb.0: -; SSE2-NEXT: pmullw {{.*}}(%rip), %xmm0 -; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0,0,1,1,2,2,3,3] -; SSE2-NEXT: pslld $30, %xmm0 -; SSE2-NEXT: xorpd %xmm1, %xmm1 -; SSE2-NEXT: movsd {{.*#+}} xmm0 = xmm1[0],xmm0[1] -; SSE2-NEXT: movsd {{.*#+}} xmm1 = xmm1[0,1] -; SSE2-NEXT: retq -; -; SSE41-LABEL: combine_vec_shl_ext_shl1: -; SSE41: # %bb.0: -; SSE41-NEXT: pmullw {{.*}}(%rip), %xmm0 -; SSE41-NEXT: pmovsxwd %xmm0, %xmm0 -; SSE41-NEXT: pslld $30, %xmm0 -; SSE41-NEXT: pxor %xmm1, %xmm1 -; SSE41-NEXT: pblendw {{.*#+}} xmm0 = xmm1[0,1,2,3],xmm0[4,5,6,7] -; SSE41-NEXT: pxor %xmm1, %xmm1 -; SSE41-NEXT: retq +; SSE-LABEL: combine_vec_shl_ext_shl1: +; SSE: # %bb.0: +; SSE-NEXT: xorps %xmm0, %xmm0 +; SSE-NEXT: xorps %xmm1, %xmm1 +; SSE-NEXT: retq ; ; AVX-LABEL: combine_vec_shl_ext_shl1: ; AVX: # %bb.0: -; AVX-NEXT: vpmullw {{.*}}(%rip), %xmm0, %xmm0 -; AVX-NEXT: vpmovsxwd %xmm0, %ymm0 -; AVX-NEXT: vpsllvd {{.*}}(%rip), %ymm0, %ymm0 +; AVX-NEXT: vxorps %xmm0, %xmm0, %xmm0 ; AVX-NEXT: retq %1 = shl <8 x i16> %x, %2 = sext <8 x i16> %1 to <8 x i32> -- 2.50.1