From 272cd527cc2764f159445bb87d2a50dda68d7742 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim
Date: Thu, 20 Jul 2017 10:43:05 +0000
Subject: [PATCH] [DAGCombiner] Match ISD::SRA non-uniform constant vectors
 patterns using predicates.

Use predicate matchers introduced in D35492 to match more ISD::SRA constant
folds

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@308600 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 41 +++++++++++++------
 test/CodeGen/X86/combine-sra.ll          | 52 +++++-------------------
 2 files changed, 38 insertions(+), 55 deletions(-)

diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b54b1274ebe..99ece601f56 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5624,7 +5624,11 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   if (N0C && N1C && !N1C->isOpaque())
     return DAG.FoldConstantArithmetic(ISD::SRA, SDLoc(N), VT, N0C, N1C);
   // fold (sra x, c >= size(x)) -> undef
-  if (N1C && N1C->getAPIntValue().uge(OpSizeInBits))
+  // NOTE: ALL vector elements must be too big to avoid partial UNDEFs.
+  auto MatchShiftTooBig = [OpSizeInBits](ConstantSDNode *Val) {
+    return Val->getAPIntValue().uge(OpSizeInBits);
+  };
+  if (matchUnaryPredicate(N1, MatchShiftTooBig))
     return DAG.getUNDEF(VT);
   // fold (sra x, 0) -> x
   if (N1C && N1C->isNullValue())
@@ -5648,20 +5652,31 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   }
 
   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
-  if (N1C && N0.getOpcode() == ISD::SRA) {
-    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      SDLoc DL(N);
-      APInt c1 = N0C1->getAPIntValue();
-      APInt c2 = N1C->getAPIntValue();
-      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+  if (N0.getOpcode() == ISD::SRA) {
+    SDLoc DL(N);
+    EVT ShiftVT = N1.getValueType();
 
-      APInt Sum = c1 + c2;
-      if (Sum.uge(OpSizeInBits))
-        Sum = APInt(OpSizeInBits, OpSizeInBits - 1);
+    auto MatchOutOfRange = [OpSizeInBits](ConstantSDNode *LHS,
+                                          ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+      return (c1 + c2).uge(OpSizeInBits);
+    };
+    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchOutOfRange))
+      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
+                         DAG.getConstant(OpSizeInBits - 1, DL, ShiftVT));
 
-      return DAG.getNode(
-          ISD::SRA, DL, VT, N0.getOperand(0),
-          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
+    auto MatchInRange = [OpSizeInBits](ConstantSDNode *LHS,
+                                       ConstantSDNode *RHS) {
+      APInt c1 = LHS->getAPIntValue();
+      APInt c2 = RHS->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+      return (c1 + c2).ult(OpSizeInBits);
+    };
+    if (matchBinaryPredicate(N1, N0.getOperand(1), MatchInRange)) {
+      SDValue Sum = DAG.getNode(ISD::ADD, DL, ShiftVT, N1, N0.getOperand(1));
+      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0), Sum);
     }
   }
 
diff --git a/test/CodeGen/X86/combine-sra.ll b/test/CodeGen/X86/combine-sra.ll
index 49ebce4857e..f9927198978 100644
--- a/test/CodeGen/X86/combine-sra.ll
+++ b/test/CodeGen/X86/combine-sra.ll
@@ -48,12 +48,10 @@ define <4 x i32> @combine_vec_ashr_outofrange0(<4 x i32> %x) {
 define <4 x i32> @combine_vec_ashr_outofrange1(<4 x i32> %x) {
 ; SSE-LABEL: combine_vec_ashr_outofrange1:
 ; SSE:       # BB#0:
-; SSE-NEXT:    psrad $31, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_outofrange1:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr <4 x i32> %x,
   ret <4 x i32> %1
@@ -91,31 +89,21 @@ define <4 x i32> @combine_vec_ashr_ashr0(<4 x i32> %x) {
 define <4 x i32> @combine_vec_ashr_ashr1(<4 x i32> %x) {
 ; SSE-LABEL: combine_vec_ashr_ashr1:
 ; SSE:       # BB#0:
+; SSE-NEXT:    movdqa %xmm0, %xmm1
+; SSE-NEXT:    psrad $10, %xmm1
 ; SSE-NEXT:    movdqa %xmm0, %xmm2
+; SSE-NEXT:    psrad $6, %xmm2
+; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
 ; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $2, %xmm1
-; SSE-NEXT:    pblendw {{.*#+}} xmm1 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    psrad $3, %xmm0
-; SSE-NEXT:    psrad $1, %xmm2
-; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm0[4,5,6,7]
-; SSE-NEXT:    pblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3],xmm1[4,5],xmm2[6,7]
-; SSE-NEXT:    movdqa %xmm1, %xmm0
-; SSE-NEXT:    psrad $7, %xmm0
-; SSE-NEXT:    movdqa %xmm1, %xmm2
-; SSE-NEXT:    psrad $5, %xmm2
-; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm0[4,5,6,7]
-; SSE-NEXT:    movdqa %xmm1, %xmm0
-; SSE-NEXT:    psrad $6, %xmm0
-; SSE-NEXT:    psrad $4, %xmm1
-; SSE-NEXT:    pblendw {{.*#+}} xmm1 = xmm1[0,1,2,3],xmm0[4,5,6,7]
-; SSE-NEXT:    pblendw {{.*#+}} xmm1 = xmm1[0,1],xmm2[2,3],xmm1[4,5],xmm2[6,7]
-; SSE-NEXT:    movdqa %xmm1, %xmm0
+; SSE-NEXT:    psrad $8, %xmm1
+; SSE-NEXT:    psrad $4, %xmm0
+; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
+; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_ashr1:
 ; AVX:       # BB#0:
 ; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr <4 x i32> %x,
   %2 = ashr <4 x i32> %1,
@@ -125,32 +113,12 @@ define <4 x i32> @combine_vec_ashr_ashr1(<4 x i32> %x) {
 define <4 x i32> @combine_vec_ashr_ashr2(<4 x i32> %x) {
 ; SSE-LABEL: combine_vec_ashr_ashr2:
 ; SSE:       # BB#0:
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $20, %xmm1
-; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    psrad $18, %xmm2
-; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $19, %xmm1
-; SSE-NEXT:    psrad $17, %xmm0
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $28, %xmm1
-; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    psrad $26, %xmm2
-; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    movdqa %xmm0, %xmm1
-; SSE-NEXT:    psrad $27, %xmm1
-; SSE-NEXT:    psrad $25, %xmm0
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
-; SSE-NEXT:    pblendw {{.*#+}} xmm0 = xmm0[0,1],xmm2[2,3],xmm0[4,5],xmm2[6,7]
+; SSE-NEXT:    psrad $31, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_ashr_ashr2:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
-; AVX-NEXT:    vpsravd {{.*}}(%rip), %xmm0, %xmm0
+; AVX-NEXT:    vpsrad $31, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %1 = ashr <4 x i32> %x,
   %2 = ashr <4 x i32> %1,
-- 
2.50.1
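Note on the fold being generalized: the DAGCombiner change rewrites (sra (sra x, c1), c2) into a single shift by c1 + c2 when the sum stays below the bit width, and into a shift by OpSizeInBits - 1 when it does not, because an arithmetic shift past that point only replicates the sign bit. The standalone C++ sketch below is not part of the patch and uses no LLVM APIs; it simply models the per-element arithmetic so the clamping rule can be checked in isolation, and the helper names are illustrative.

#include <cassert>
#include <cstdint>

// Arithmetic shift right of one 32-bit element. On mainstream targets (and
// as guaranteed since C++20), >> on a negative signed value is an arithmetic
// shift, matching ISD::SRA per-element semantics.
static int32_t sra(int32_t x, unsigned amt) { return x >> amt; }

// Models the combined fold for one element, assuming each individual shift
// amount is already in range [0, 32):
//   (sra (sra x, c1), c2) -> (sra x, c1 + c2)  if c1 + c2 <  32
//   (sra (sra x, c1), c2) -> (sra x, 31)       if c1 + c2 >= 32
static int32_t foldedDoubleSra(int32_t x, unsigned c1, unsigned c2) {
  unsigned sum = c1 + c2;
  return sra(x, sum < 32 ? sum : 31); // clamp to OpSizeInBits - 1
}

int main() {
  const int32_t values[] = {0, 1, -1, 123456, -123456, INT32_MAX, INT32_MIN};
  for (int32_t x : values)
    for (unsigned c1 = 0; c1 < 32; ++c1)
      for (unsigned c2 = 0; c2 < 32; ++c2)
        assert(sra(sra(x, c1), c2) == foldedDoubleSra(x, c1, c2));
  return 0;
}

In the DAGCombiner code itself the shift amounts are APInts of a target-dependent width, which is why zeroExtendToMatch reserves one extra overflow bit before c1 + c2 is formed: the uge/ult range checks must see the exact sum rather than a wrapped one. The plain unsigned arithmetic above sidesteps that concern because both amounts are below 32.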