From 3d16775dd15f1ebfb98cf620c04b99e7c89ac1dd Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Sat, 24 Aug 2019 23:14:57 +0000 Subject: [PATCH] [X86] Add isel patterns to match vpdpwssd avx512vnni instruction from add+pmaddwd nodes. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@369859 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/X86/X86InstrAVX512.td | 29 +++++ test/CodeGen/X86/avx512vnni.ll | 198 +++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 test/CodeGen/X86/avx512vnni.ll diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td index 3bdc610ecf6..d6615535fbb 100644 --- a/lib/Target/X86/X86InstrAVX512.td +++ b/lib/Target/X86/X86InstrAVX512.td @@ -11934,6 +11934,35 @@ defm VPDPBUSDS : VNNI_common<0x51, "vpdpbusds", X86Vpdpbusds, SchedWriteVecIMul defm VPDPWSSD : VNNI_common<0x52, "vpdpwssd", X86Vpdpwssd, SchedWriteVecIMul, 1>; defm VPDPWSSDS : VNNI_common<0x53, "vpdpwssds", X86Vpdpwssds, SchedWriteVecIMul, 1>; +def X86vpmaddwd_su : PatFrag<(ops node:$lhs, node:$rhs), + (X86vpmaddwd node:$lhs, node:$rhs), [{ + return N->hasOneUse(); +}]>; + +// Patterns to match VPDPWSSD from existing instructions/intrinsics. +let Predicates = [HasVNNI] in { + def : Pat<(v16i32 (add VR512:$src1, + (X86vpmaddwd_su VR512:$src2, VR512:$src3))), + (VPDPWSSDZr VR512:$src1, VR512:$src2, VR512:$src3)>; + def : Pat<(v16i32 (add VR512:$src1, + (X86vpmaddwd_su VR512:$src2, (load addr:$src3)))), + (VPDPWSSDZm VR512:$src1, VR512:$src2, addr:$src3)>; +} +let Predicates = [HasVNNI,HasVLX] in { + def : Pat<(v8i32 (add VR256X:$src1, + (X86vpmaddwd_su VR256X:$src2, VR256X:$src3))), + (VPDPWSSDZ256r VR256X:$src1, VR256X:$src2, VR256X:$src3)>; + def : Pat<(v8i32 (add VR256X:$src1, + (X86vpmaddwd_su VR256X:$src2, (load addr:$src3)))), + (VPDPWSSDZ256m VR256X:$src1, VR256X:$src2, addr:$src3)>; + def : Pat<(v4i32 (add VR128X:$src1, + (X86vpmaddwd_su VR128X:$src2, VR128X:$src3))), + (VPDPWSSDZ128r VR128X:$src1, VR128X:$src2, VR128X:$src3)>; + def : Pat<(v4i32 (add VR128X:$src1, + (X86vpmaddwd_su VR128X:$src2, (load addr:$src3)))), + (VPDPWSSDZ128m VR128X:$src1, VR128X:$src2, addr:$src3)>; +} + //===----------------------------------------------------------------------===// // Bit Algorithms //===----------------------------------------------------------------------===// diff --git a/test/CodeGen/X86/avx512vnni.ll b/test/CodeGen/X86/avx512vnni.ll new file mode 100644 index 00000000000..2464a3e93ac --- /dev/null +++ b/test/CodeGen/X86/avx512vnni.ll @@ -0,0 +1,198 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -disable-peephole -mtriple=x86_64-unknown-unknown -mattr=+avx512vnni,+avx512vl,+avx512bw | FileCheck %s --check-prefixes=CHECK + +define <4 x i32> @test_pmaddwd_v8i16_add_v4i32(<4 x i32> %a0, <8 x i16> %a1, <8 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v8i16_add_v4i32: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd %xmm2, %xmm1, %xmm0 +; CHECK-NEXT: retq + %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a1, <8 x i16> %a2) + %2 = add <4 x i32> %1, %a0 + ret <4 x i32> %2 +} + +define <4 x i32> @test_pmaddwd_v8i16_add_v4i32_commute(<4 x i32> %a0, <8 x i16> %a1, <8 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v8i16_add_v4i32_commute: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd %xmm2, %xmm1, %xmm0 +; CHECK-NEXT: retq + %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a1, <8 x i16> %a2) + %2 = add <4 x i32> %a0, %1 + ret <4 x i32> %2 +} + +define <4 x i32> @test_pmaddwd_v8i16_add_v4i32_load1(<4 x i32> %a0, <8 x i16>* %p1, <8 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v8i16_add_v4i32_load1: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %xmm1, %xmm0 +; CHECK-NEXT: retq + %a1 = load <8 x i16>, <8 x i16>* %p1 + %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a1, <8 x i16> %a2) + %2 = add <4 x i32> %1, %a0 + ret <4 x i32> %2 +} + +define <4 x i32> @test_pmaddwd_v8i16_add_v4i32_load2(<4 x i32> %a0, <8 x i16> %a1, <8 x i16>* %p2) { +; CHECK-LABEL: test_pmaddwd_v8i16_add_v4i32_load2: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %xmm1, %xmm0 +; CHECK-NEXT: retq + %a2 = load <8 x i16>, <8 x i16>* %p2 + %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a1, <8 x i16> %a2) + %2 = add <4 x i32> %1, %a0 + ret <4 x i32> %2 +} + +define <4 x i32> @test_pmaddwd_v8i16_add_v4i32_commute_load1(<4 x i32> %a0, <8 x i16>* %p1, <8 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v8i16_add_v4i32_commute_load1: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %xmm1, %xmm0 +; CHECK-NEXT: retq + %a1 = load <8 x i16>, <8 x i16>* %p1 + %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a1, <8 x i16> %a2) + %2 = add <4 x i32> %a0, %1 + ret <4 x i32> %2 +} + +define <4 x i32> @test_pmaddwd_v8i16_add_v4i32_commute_load2(<4 x i32> %a0, <8 x i16> %a1, <8 x i16>* %p2) { +; CHECK-LABEL: test_pmaddwd_v8i16_add_v4i32_commute_load2: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %xmm1, %xmm0 +; CHECK-NEXT: retq + %a2 = load <8 x i16>, <8 x i16>* %p2 + %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a1, <8 x i16> %a2) + %2 = add <4 x i32> %a0, %1 + ret <4 x i32> %2 +} + +define <8 x i32> @test_pmaddwd_v16i16_add_v8i32(<8 x i32> %a0, <16 x i16> %a1, <16 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v16i16_add_v8i32: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd %ymm2, %ymm1, %ymm0 +; CHECK-NEXT: retq + %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a1, <16 x i16> %a2) + %2 = add <8 x i32> %1, %a0 + ret <8 x i32> %2 +} + +define <8 x i32> @test_pmaddwd_v16i16_add_v8i32_commute(<8 x i32> %a0, <16 x i16> %a1, <16 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v16i16_add_v8i32_commute: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd %ymm2, %ymm1, %ymm0 +; CHECK-NEXT: retq + %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a1, <16 x i16> %a2) + %2 = add <8 x i32> %a0, %1 + ret <8 x i32> %2 +} + +define <8 x i32> @test_pmaddwd_v16i16_add_v8i32_load1(<8 x i32> %a0, <16 x i16>* %p1, <16 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v16i16_add_v8i32_load1: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %ymm1, %ymm0 +; CHECK-NEXT: retq + %a1 = load <16 x i16>, <16 x i16>* %p1 + %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a1, <16 x i16> %a2) + %2 = add <8 x i32> %1, %a0 + ret <8 x i32> %2 +} + +define <8 x i32> @test_pmaddwd_v16i16_add_v8i32_load2(<8 x i32> %a0, <16 x i16> %a1, <16 x i16>* %p2) { +; CHECK-LABEL: test_pmaddwd_v16i16_add_v8i32_load2: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %ymm1, %ymm0 +; CHECK-NEXT: retq + %a2 = load <16 x i16>, <16 x i16>* %p2 + %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a1, <16 x i16> %a2) + %2 = add <8 x i32> %1, %a0 + ret <8 x i32> %2 +} + +define <8 x i32> @test_pmaddwd_v16i16_add_v8i32_commute_load1(<8 x i32> %a0, <16 x i16>* %p1, <16 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v16i16_add_v8i32_commute_load1: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %ymm1, %ymm0 +; CHECK-NEXT: retq + %a1 = load <16 x i16>, <16 x i16>* %p1 + %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a1, <16 x i16> %a2) + %2 = add <8 x i32> %a0, %1 + ret <8 x i32> %2 +} + +define <8 x i32> @test_pmaddwd_v16i16_add_v8i32_commute_load2(<8 x i32> %a0, <16 x i16> %a1, <16 x i16>* %p2) { +; CHECK-LABEL: test_pmaddwd_v16i16_add_v8i32_commute_load2: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %ymm1, %ymm0 +; CHECK-NEXT: retq + %a2 = load <16 x i16>, <16 x i16>* %p2 + %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a1, <16 x i16> %a2) + %2 = add <8 x i32> %a0, %1 + ret <8 x i32> %2 +} + +define <16 x i32> @test_pmaddwd_v32i16_add_v16i32(<16 x i32> %a0, <32 x i16> %a1, <32 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v32i16_add_v16i32: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd %zmm2, %zmm1, %zmm0 +; CHECK-NEXT: retq + %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a1, <32 x i16> %a2) + %2 = add <16 x i32> %1, %a0 + ret <16 x i32> %2 +} + +define <16 x i32> @test_pmaddwd_v32i16_add_v16i32_commute(<16 x i32> %a0, <32 x i16> %a1, <32 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v32i16_add_v16i32_commute: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd %zmm2, %zmm1, %zmm0 +; CHECK-NEXT: retq + %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a1, <32 x i16> %a2) + %2 = add <16 x i32> %a0, %1 + ret <16 x i32> %2 +} + +define <16 x i32> @test_pmaddwd_v32i16_add_v16i32_load1(<16 x i32> %a0, <32 x i16>* %p1, <32 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v32i16_add_v16i32_load1: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %zmm1, %zmm0 +; CHECK-NEXT: retq + %a1 = load <32 x i16>, <32 x i16>* %p1 + %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a1, <32 x i16> %a2) + %2 = add <16 x i32> %1, %a0 + ret <16 x i32> %2 +} + +define <16 x i32> @test_pmaddwd_v32i16_add_v16i32_load2(<16 x i32> %a0, <32 x i16> %a1, <32 x i16>* %p2) { +; CHECK-LABEL: test_pmaddwd_v32i16_add_v16i32_load2: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %zmm1, %zmm0 +; CHECK-NEXT: retq + %a2 = load <32 x i16>, <32 x i16>* %p2 + %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a1, <32 x i16> %a2) + %2 = add <16 x i32> %1, %a0 + ret <16 x i32> %2 +} + +define <16 x i32> @test_pmaddwd_v32i16_add_v16i32_commute_load1(<16 x i32> %a0, <32 x i16>* %p1, <32 x i16> %a2) { +; CHECK-LABEL: test_pmaddwd_v32i16_add_v16i32_commute_load1: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %zmm1, %zmm0 +; CHECK-NEXT: retq + %a1 = load <32 x i16>, <32 x i16>* %p1 + %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a1, <32 x i16> %a2) + %2 = add <16 x i32> %a0, %1 + ret <16 x i32> %2 +} + +define <16 x i32> @test_pmaddwd_v32i16_add_v16i32_commute_load2(<16 x i32> %a0, <32 x i16> %a1, <32 x i16>* %p2) { +; CHECK-LABEL: test_pmaddwd_v32i16_add_v16i32_commute_load2: +; CHECK: # %bb.0: +; CHECK-NEXT: vpdpwssd (%rdi), %zmm1, %zmm0 +; CHECK-NEXT: retq + %a2 = load <32 x i16>, <32 x i16>* %p2 + %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a1, <32 x i16> %a2) + %2 = add <16 x i32> %a0, %1 + ret <16 x i32> %2 +} + +declare <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16>, <8 x i16>) +declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) +declare <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16>, <32 x i16>) -- 2.40.0