From: Craig Topper
Date: Sun, 19 Feb 2017 19:36:58 +0000 (+0000)
Subject: [AVX-512] Add patterns to recognize masked vpternlog when the passthrough operand...
X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=07a02364017a8de22885fa56f6d662b32f879d52;p=llvm

[AVX-512] Add patterns to recognize masked vpternlog when the passthrough operand is not operand 0.

This uses a SDNodeXForm to swizzle the appropriate immediate bits to allow this to be matched.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@295612 91177308-0d34-0410-b5e6-96231b3b80d8
---

diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td
index e7c9e87f956..dea8b2b49b5 100644
--- a/lib/Target/X86/X86InstrAVX512.td
+++ b/lib/Target/X86/X86InstrAVX512.td
@@ -8891,6 +8891,31 @@ multiclass avx512_psadbw_packed_all<bits<8> opc, SDNode OpNode,
 defm VPSADBW : avx512_psadbw_packed_all<0xf6, X86psadbw, "vpsadbw",
                                        HasBWI>, EVEX_4V;
 
+// Transforms to swizzle an immediate to enable better matching when
+// memory operand isn't in the right place.
+def VPTERNLOG321_imm8 : SDNodeXForm<imm, [{
+  // Convert a VPTERNLOG immediate by swapping the first and third operands.
+  uint8_t Imm = N->getZExtValue();
+  // Swap bits 1/4 and 3/6 (0xa5 keeps the bits that do not move).
+  uint8_t NewImm = Imm & 0xa5;
+  if (Imm & 0x02) NewImm |= 0x10;
+  if (Imm & 0x10) NewImm |= 0x02;
+  if (Imm & 0x08) NewImm |= 0x40;
+  if (Imm & 0x40) NewImm |= 0x08;
+  return getI8Imm(NewImm, SDLoc(N));
+}]>;
+def VPTERNLOG213_imm8 : SDNodeXForm<imm, [{
+  // Convert a VPTERNLOG immediate by swapping the first and second operands.
+  uint8_t Imm = N->getZExtValue();
+  // Swap bits 2/4 and 3/5 (0xc3 keeps the bits that do not move).
+  uint8_t NewImm = Imm & 0xc3;
+  if (Imm & 0x04) NewImm |= 0x10;
+  if (Imm & 0x10) NewImm |= 0x04;
+  if (Imm & 0x08) NewImm |= 0x20;
+  if (Imm & 0x20) NewImm |= 0x08;
+  return getI8Imm(NewImm, SDLoc(N));
+}]>;
+
 multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode,
                           X86VectorVTInfo _>{
   let Constraints = "$src1 = $dst", ExeDomain = _.ExeDomain in {
@@ -8919,6 +8944,20 @@ multiclass avx512_ternlog<bits<8> opc, string OpcodeStr, SDNode OpNode,
                       (i8 imm:$src4)), 1, 0>, EVEX_B,
                       AVX512AIi8Base, EVEX_4V, EVEX_CD8<_.EltSize, CD8VF>;
   }// Constraints = "$src1 = $dst"
+
+  // Additional patterns for matching passthru operand in other positions.
+  let AddedComplexity = 20 in {
+  def : Pat<(_.VT (vselect _.KRCWM:$mask,
+                   (OpNode _.RC:$src3, _.RC:$src2, _.RC:$src1, (i8 imm:$src4)),
+                   _.RC:$src1)),
+            (!cast<Instruction>(NAME#_.ZSuffix#rrik) _.RC:$src1, _.KRCWM:$mask,
+             _.RC:$src2, _.RC:$src3, (VPTERNLOG321_imm8 imm:$src4))>;
+  def : Pat<(_.VT (vselect _.KRCWM:$mask,
+                   (OpNode _.RC:$src2, _.RC:$src1, _.RC:$src3, (i8 imm:$src4)),
+                   _.RC:$src1)),
+            (!cast<Instruction>(NAME#_.ZSuffix#rrik) _.RC:$src1, _.KRCWM:$mask,
+             _.RC:$src2, _.RC:$src3, (VPTERNLOG213_imm8 imm:$src4))>;
+  }
 }
 
 multiclass avx512_common_ternlog<string OpcodeStr, AVX512VLVectorVTInfo _>{
diff --git a/test/CodeGen/X86/avx512-vpternlog-commute.ll b/test/CodeGen/X86/avx512-vpternlog-commute.ll
index 648de3582a3..1f203936bcc 100644
--- a/test/CodeGen/X86/avx512-vpternlog-commute.ll
+++ b/test/CodeGen/X86/avx512-vpternlog-commute.ll
@@ -189,9 +189,9 @@ define <16 x i32> @vpternlog_v16i32_210_mask(<16 x i32> %x0, <16 x i32> %x1, <16
 define <16 x i32> @vpternlog_v16i32_012_mask1(<16 x i32> %x0, <16 x i32> %x1, <16 x i32> %x2, i16 %mask) {
 ; CHECK-LABEL: vpternlog_v16i32_012_mask1:
 ; CHECK:       ## BB#0:
-; CHECK-NEXT:    vpternlogd $33, %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    kmovw %edi, %k1
-; CHECK-NEXT:    vpblendmd %zmm0, %zmm1, %zmm0 {%k1}
+; CHECK-NEXT:    vpternlogd $9, %zmm2, %zmm0, %zmm1 {%k1}
+; CHECK-NEXT:    vmovdqa64 %zmm1, %zmm0
 ; CHECK-NEXT:    retq
   %res = call <16 x i32> @llvm.x86.avx512.mask.pternlog.d.512(<16 x i32> %x0, <16 x i32> %x1, <16 x i32> %x2, i32 33, i16 -1)
   %mask.cast = bitcast i16 %mask to <16 x i1>
@@ -202,9 +202,9 @@ define <16 x i32> @vpternlog_v16i32_012_mask1(<16 x i32> %x0, <16 x i32> %x1, <1
 define <16 x i32> @vpternlog_v16i32_012_mask2(<16 x i32> %x0, <16 x i32> %x1, <16 x i32> %x2, i16 %mask) {
 ; CHECK-LABEL: vpternlog_v16i32_012_mask2:
 ; CHECK:       ## BB#0:
-; CHECK-NEXT:    vpternlogd $33, %zmm2, %zmm1, %zmm0
 ; CHECK-NEXT:    kmovw %edi, %k1
-; CHECK-NEXT:    vpblendmd %zmm0, %zmm2, %zmm0 {%k1}
+; CHECK-NEXT:    vpternlogd $33, %zmm0, %zmm1, %zmm2 {%k1}
+; CHECK-NEXT:    vmovdqa64 %zmm2, %zmm0
 ; CHECK-NEXT:    retq
   %res = call <16 x i32> @llvm.x86.avx512.mask.pternlog.d.512(<16 x i32> %x0, <16 x i32> %x1, <16 x i32> %x2, i32 33, i16 -1)
   %mask.cast = bitcast i16 %mask to <16 x i1>