From: Simon Pilgrim Date: Wed, 2 Jan 2019 17:05:37 +0000 (+0000) Subject: [X86] Support SHLD/SHRD masked shift-counts (PR34641) X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=440f5b291e26c6d64b229356680e8e00921d8873;p=llvm [X86] Support SHLD/SHRD masked shift-counts (PR34641) Peek through shift modulo masks while matching double shift patterns. I was hoping to delay this until I could remove the X86 code with generic funnel shift matching (PR40081) but this will do for now. Differential Revision: https://reviews.llvm.org/D56199 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@350222 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 6515ea4c982..187b23179ec 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -36514,6 +36514,7 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c) bool OptForSize = DAG.getMachineFunction().getFunction().optForSize(); + unsigned Bits = VT.getScalarSizeInBits(); // SHLD/SHRD instructions have lower register pressure, but on some // platforms they have higher latency than the equivalent @@ -36536,6 +36537,23 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, SDValue ShAmt1 = N1.getOperand(1); if (ShAmt1.getValueType() != MVT::i8) return SDValue(); + + // Peek through any modulo shift masks. 
+ SDValue ShMsk0; + if (ShAmt0.getOpcode() == ISD::AND && + isa<ConstantSDNode>(ShAmt0.getOperand(1)) && + ShAmt0.getConstantOperandVal(1) == (Bits - 1)) { + ShMsk0 = ShAmt0; + ShAmt0 = ShAmt0.getOperand(0); + } + SDValue ShMsk1; + if (ShAmt1.getOpcode() == ISD::AND && + isa<ConstantSDNode>(ShAmt1.getOperand(1)) && + ShAmt1.getConstantOperandVal(1) == (Bits - 1)) { + ShMsk1 = ShAmt1; + ShAmt1 = ShAmt1.getOperand(0); + } + if (ShAmt0.getOpcode() == ISD::TRUNCATE) ShAmt0 = ShAmt0.getOperand(0); if (ShAmt1.getOpcode() == ISD::TRUNCATE) @@ -36550,24 +36568,26 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, Opc = X86ISD::SHRD; std::swap(Op0, Op1); std::swap(ShAmt0, ShAmt1); + std::swap(ShMsk0, ShMsk1); } // OR( SHL( X, C ), SRL( Y, 32 - C ) ) -> SHLD( X, Y, C ) // OR( SRL( X, C ), SHL( Y, 32 - C ) ) -> SHRD( X, Y, C ) // OR( SHL( X, C ), SRL( SRL( Y, 1 ), XOR( C, 31 ) ) ) -> SHLD( X, Y, C ) // OR( SRL( X, C ), SHL( SHL( Y, 1 ), XOR( C, 31 ) ) ) -> SHRD( X, Y, C ) - unsigned Bits = VT.getScalarSizeInBits(); + // OR( SHL( X, AND( C, 31 ) ), SRL( Y, AND( 0 - C, 31 ) ) ) -> SHLD( X, Y, C ) + // OR( SRL( X, AND( C, 31 ) ), SHL( Y, AND( 0 - C, 31 ) ) ) -> SHRD( X, Y, C ) if (ShAmt1.getOpcode() == ISD::SUB) { SDValue Sum = ShAmt1.getOperand(0); if (auto *SumC = dyn_cast<ConstantSDNode>(Sum)) { SDValue ShAmt1Op1 = ShAmt1.getOperand(1); if (ShAmt1Op1.getOpcode() == ISD::TRUNCATE) ShAmt1Op1 = ShAmt1Op1.getOperand(0); - if (SumC->getSExtValue() == Bits && ShAmt1Op1 == ShAmt0) - return DAG.getNode(Opc, DL, VT, - Op0, Op1, - DAG.getNode(ISD::TRUNCATE, DL, - MVT::i8, ShAmt0)); + if ((SumC->getAPIntValue() == Bits || + (SumC->getAPIntValue() == 0 && ShMsk1)) && + ShAmt1Op1 == ShAmt0) + return DAG.getNode(Opc, DL, VT, Op0, Op1, + DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); } } else if (auto *ShAmt1C = dyn_cast<ConstantSDNode>(ShAmt1)) { auto *ShAmt0C = dyn_cast<ConstantSDNode>(ShAmt0); @@ -36583,7 +36603,8 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, SDValue ShAmt1Op0 = ShAmt1.getOperand(0); if (ShAmt1Op0.getOpcode() == ISD::TRUNCATE)
ShAmt1Op0 = ShAmt1Op0.getOperand(0); - if (MaskC->getSExtValue() == (Bits - 1) && ShAmt1Op0 == ShAmt0) { + if (MaskC->getSExtValue() == (Bits - 1) && + (ShAmt1Op0 == ShAmt0 || ShAmt1Op0 == ShMsk0)) { if (Op1.getOpcode() == InnerShift && isa<ConstantSDNode>(Op1.getOperand(1)) && Op1.getConstantOperandVal(1) == 1) { @@ -36594,7 +36615,7 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG, if (InnerShift == ISD::SHL && Op1.getOpcode() == ISD::ADD && Op1.getOperand(0) == Op1.getOperand(1)) { return DAG.getNode(Opc, DL, VT, Op0, Op1.getOperand(0), - DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); + DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, ShAmt0)); } } } diff --git a/test/CodeGen/X86/shift-double.ll b/test/CodeGen/X86/shift-double.ll index 0b22140e3fe..fd555c4aaac 100644 --- a/test/CodeGen/X86/shift-double.ll +++ b/test/CodeGen/X86/shift-double.ll @@ -460,24 +460,18 @@ define i32 @test17(i32 %hi, i32 %lo, i32 %bits) nounwind { define i32 @shld_safe_i32(i32, i32, i32) { ; X86-LABEL: shld_safe_i32: ; X86: # %bb.0: -; X86-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-NEXT: movb {{[0-9]+}}(%esp), %cl ; X86-NEXT: movl {{[0-9]+}}(%esp), %edx -; X86-NEXT: shll %cl, %edx -; X86-NEXT: negb %cl -; X86-NEXT: shrl %cl, %eax -; X86-NEXT: orl %edx, %eax +; X86-NEXT: movl {{[0-9]+}}(%esp), %eax +; X86-NEXT: shldl %cl, %edx, %eax ; X86-NEXT: retl ; ; X64-LABEL: shld_safe_i32: ; X64: # %bb.0: ; X64-NEXT: movl %edx, %ecx -; X64-NEXT: movl %esi, %eax -; X64-NEXT: shll %cl, %edi -; X64-NEXT: negb %cl +; X64-NEXT: movl %edi, %eax ; X64-NEXT: # kill: def $cl killed $cl killed $ecx -; X64-NEXT: shrl %cl, %eax -; X64-NEXT: orl %edi, %eax +; X64-NEXT: shldl %cl, %esi, %eax ; X64-NEXT: retq %4 = and i32 %2, 31 %5 = shl i32 %0, %4 @@ -491,24 +485,18 @@ define i32 @shld_safe_i32(i32, i32, i32) { define i32 @shrd_safe_i32(i32, i32, i32) { ; X86-LABEL: shrd_safe_i32: ; X86: # %bb.0: -; X86-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-NEXT: movb {{[0-9]+}}(%esp), %cl ; X86-NEXT: movl {{[0-9]+}}(%esp), %edx -; X86-NEXT:
shrl %cl, %edx -; X86-NEXT: negb %cl -; X86-NEXT: shll %cl, %eax -; X86-NEXT: orl %edx, %eax +; X86-NEXT: movl {{[0-9]+}}(%esp), %eax +; X86-NEXT: shrdl %cl, %edx, %eax ; X86-NEXT: retl ; ; X64-LABEL: shrd_safe_i32: ; X64: # %bb.0: ; X64-NEXT: movl %edx, %ecx -; X64-NEXT: movl %esi, %eax -; X64-NEXT: shrl %cl, %edi -; X64-NEXT: negb %cl +; X64-NEXT: movl %edi, %eax ; X64-NEXT: # kill: def $cl killed $cl killed $ecx -; X64-NEXT: shll %cl, %eax -; X64-NEXT: orl %edi, %eax +; X64-NEXT: shrdl %cl, %esi, %eax ; X64-NEXT: retq %4 = and i32 %2, 31 %5 = lshr i32 %0, %4