From c2a0046960128f0f09a43b95546bb4f2e623619b Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev@redking.me.uk>
Date: Thu, 27 Jun 2019 13:48:43 +0000
Subject: [PATCH] [TargetLowering] SimplifyDemandedBits - use DemandedElts to
 better identify partial splat shift amounts

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@364541 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/CodeGen/SelectionDAG/TargetLowering.cpp | 32 ++++++++++++++-------
 test/CodeGen/X86/combine-sdiv.ll            |  6 ++--
 test/CodeGen/X86/known-signbits-vector.ll   |  5 +---
 3 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 74ab96afe5a..50cd8cded62 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1040,20 +1040,23 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
 
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) {
+    if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
       // If the shift count is an invalid immediate, don't do anything.
       if (SA->getAPIntValue().uge(BitWidth))
         break;
 
       unsigned ShAmt = SA->getZExtValue();
+      if (ShAmt == 0)
+        return TLO.CombineTo(Op, Op0);
 
       // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
       // single shift. We can do this if the bottom bits (which are shifted
       // out) are never demanded.
+      // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SRL) {
-        if (ShAmt &&
-            (DemandedBits & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) {
-          if (ConstantSDNode *SA2 = isConstOrConstSplat(Op0.getOperand(1))) {
+        if ((DemandedBits & APInt::getLowBitsSet(BitWidth, ShAmt)) == 0) {
+          if (ConstantSDNode *SA2 =
+                  isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
             if (SA2->getAPIntValue().ult(BitWidth)) {
               unsigned C1 = SA2->getZExtValue();
               unsigned Opc = ISD::SHL;
@@ -1134,13 +1137,16 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
 
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) {
+    if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
       // If the shift count is an invalid immediate, don't do anything.
       if (SA->getAPIntValue().uge(BitWidth))
         break;
 
-      EVT ShiftVT = Op1.getValueType();
       unsigned ShAmt = SA->getZExtValue();
+      if (ShAmt == 0)
+        return TLO.CombineTo(Op, Op0);
+
+      EVT ShiftVT = Op1.getValueType();
       APInt InDemandedMask = (DemandedBits << ShAmt);
 
       // If the shift is exact, then it does demand the low bits (and knows that
@@ -1151,10 +1157,11 @@ bool TargetLowering::SimplifyDemandedBits(
       // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
      // single shift. We can do this if the top bits (which are shifted out)
       // are never demanded.
+      // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SHL) {
-        if (ConstantSDNode *SA2 = isConstOrConstSplat(Op0.getOperand(1))) {
-          if (ShAmt &&
-              (DemandedBits & APInt::getHighBitsSet(BitWidth, ShAmt)) == 0) {
+        if (ConstantSDNode *SA2 =
+                isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
+          if ((DemandedBits & APInt::getHighBitsSet(BitWidth, ShAmt)) == 0) {
             if (SA2->getAPIntValue().ult(BitWidth)) {
               unsigned C1 = SA2->getZExtValue();
               unsigned Opc = ISD::SRL;
@@ -1195,12 +1202,15 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isOneValue())
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
 
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op1)) {
+    if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
      // If the shift count is an invalid immediate, don't do anything.
       if (SA->getAPIntValue().uge(BitWidth))
         break;
 
       unsigned ShAmt = SA->getZExtValue();
+      if (ShAmt == 0)
+        return TLO.CombineTo(Op, Op0);
+
       APInt InDemandedMask = (DemandedBits << ShAmt);
 
       // If the shift is exact, then it does demand the low bits (and knows that
@@ -1251,7 +1261,7 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op2 = Op.getOperand(2);
     bool IsFSHL = (Op.getOpcode() == ISD::FSHL);
 
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op2)) {
+    if (ConstantSDNode *SA = isConstOrConstSplat(Op2, DemandedElts)) {
       unsigned Amt = SA->getAPIntValue().urem(BitWidth);
 
       // For fshl, 0-shift returns the 1st arg.
diff --git a/test/CodeGen/X86/combine-sdiv.ll b/test/CodeGen/X86/combine-sdiv.ll
index ad7a28ed4ad..3d78569256a 100644
--- a/test/CodeGen/X86/combine-sdiv.ll
+++ b/test/CodeGen/X86/combine-sdiv.ll
@@ -2393,8 +2393,7 @@ define <4 x i32> @non_splat_minus_one_divisor_2(<4 x i32> %A) {
 ;
 ; AVX2ORLATER-LABEL: non_splat_minus_one_divisor_2:
 ; AVX2ORLATER:       # %bb.0:
-; AVX2ORLATER-NEXT:    vpsrad $31, %xmm0, %xmm1
-; AVX2ORLATER-NEXT:    vpsrlvd {{.*}}(%rip), %xmm1, %xmm1
+; AVX2ORLATER-NEXT:    vpsrlvd {{.*}}(%rip), %xmm0, %xmm1
 ; AVX2ORLATER-NEXT:    vpaddd %xmm1, %xmm0, %xmm1
 ; AVX2ORLATER-NEXT:    vpsravd {{.*}}(%rip), %xmm1, %xmm1
 ; AVX2ORLATER-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3]
@@ -2405,8 +2404,7 @@ define <4 x i32> @non_splat_minus_one_divisor_2(<4 x i32> %A) {
 ;
 ; XOP-LABEL: non_splat_minus_one_divisor_2:
 ; XOP:       # %bb.0:
-; XOP-NEXT:    vpsrad $31, %xmm0, %xmm1
-; XOP-NEXT:    vpshld {{.*}}(%rip), %xmm1, %xmm1
+; XOP-NEXT:    vpshld {{.*}}(%rip), %xmm0, %xmm1
 ; XOP-NEXT:    vpaddd %xmm1, %xmm0, %xmm1
 ; XOP-NEXT:    vpshad {{.*}}(%rip), %xmm1, %xmm1
 ; XOP-NEXT:    vpblendw {{.*#+}} xmm0 = xmm0[0,1,2,3],xmm1[4,5,6,7]
diff --git a/test/CodeGen/X86/known-signbits-vector.ll b/test/CodeGen/X86/known-signbits-vector.ll
index eacdd24404f..43cf147c53d 100644
--- a/test/CodeGen/X86/known-signbits-vector.ll
+++ b/test/CodeGen/X86/known-signbits-vector.ll
@@ -81,10 +81,7 @@ define float @signbits_ashr_extract_sitofp_1(<2 x i64> %a0) nounwind {
 ; X32-LABEL: signbits_ashr_extract_sitofp_1:
 ; X32:       # %bb.0:
 ; X32-NEXT:    pushl %eax
-; X32-NEXT:    vpsrlq $32, %xmm0, %xmm0
-; X32-NEXT:    vmovdqa {{.*#+}} xmm1 = [2147483648,0,1,0]
-; X32-NEXT:    vpxor %xmm1, %xmm0, %xmm0
-; X32-NEXT:    vpsubq %xmm1, %xmm0, %xmm0
+; X32-NEXT:    vpermilps {{.*#+}} xmm0 = xmm0[1,1,2,3]
 ; X32-NEXT:    vcvtdq2ps %xmm0, %xmm0
 ; X32-NEXT:    vmovss %xmm0, (%esp)
 ; X32-NEXT:    flds (%esp)
-- 
2.40.0
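
For context, a minimal IR sketch of the "partial splat" shift-amount case this patch targets. It is an illustrative reduction modeled on the affected signbits_ashr_extract_sitofp_1 test; the function name below is hypothetical and the snippet is not part of the patch. Only lane 0 of the ashr result is demanded by the extractelement, so with DemandedElts the non-uniform <i64 32, i64 63> amount acts as a constant splat of 32 over the demanded lane:

; Illustrative only - not part of the patch.
define float @partial_splat_ashr_sketch(<2 x i64> %a0) {
  %shr = ashr <2 x i64> %a0, <i64 32, i64 63>   ; non-uniform amount, splat over the demanded lane
  %elt = extractelement <2 x i64> %shr, i32 0   ; only element 0 is demanded
  %cvt = sitofp i64 %elt to float
  ret float %cvt
}

Previously isConstOrConstSplat(Op1) returned null for such non-splat amounts; passing DemandedElts lets it ignore the undemanded lane, which is what allows the simpler codegen seen in the updated X32 check lines above.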