From 2368d9f73de09cd5ff2ac44b70cbd9128893c036 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Sat, 26 Aug 2017 22:24:57 +0000 Subject: [PATCH] [AVX512] Add patterns to match masked extract_subvector with bitcasts between the vselect and the extract_subvector. Remove the late DAG combine. We used to do a late DAG combine to move the bitcasts out of the way, but I'm starting to think that it's better to canonicalize extract_subvector's type to match the type of its input. I've seen some cases where we've formed two different extract_subvector from the same node where one had a bitcast and the other didn't. Add some more test cases to ensure we've also got most of the zero masking covered too. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@311837 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/X86/X86ISelLowering.cpp | 21 ---- lib/Target/X86/X86InstrAVX512.td | 106 ++++++++++++++++++++ test/CodeGen/X86/vector-shuffle-masked.ll | 114 ++++++++++++++++++++++ 3 files changed, 220 insertions(+), 21 deletions(-) diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 1897f6a0c9c..cb78d235fbe 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -30276,27 +30276,6 @@ static bool combineBitcastForMaskedOp(SDValue OrigOp, SelectionDAG &DAG, DAG.getIntPtrConstant(Imm, DL))); return true; } - case ISD::EXTRACT_SUBVECTOR: { - unsigned EltSize = EltVT.getSizeInBits(); - if (EltSize != 32 && EltSize != 64) - return false; - MVT OpEltVT = Op.getSimpleValueType().getVectorElementType(); - // Only change element size, not type. - if (EltVT.isInteger() != OpEltVT.isInteger()) - return false; - uint64_t Imm = cast(Op.getOperand(1))->getZExtValue(); - Imm = (Imm * OpEltVT.getSizeInBits()) / EltSize; - // Op0 needs to be bitcasted to a larger vector with the same element type. - SDValue Op0 = Op.getOperand(0); - MVT Op0VT = MVT::getVectorVT(EltVT, - Op0.getSimpleValueType().getSizeInBits() / EltSize); - Op0 = DAG.getBitcast(Op0VT, Op0); - DCI.AddToWorklist(Op0.getNode()); - DCI.CombineTo(OrigOp.getNode(), - DAG.getNode(Opcode, DL, VT, Op0, - DAG.getIntPtrConstant(Imm, DL))); - return true; - } case X86ISD::SUBV_BROADCAST: { unsigned EltSize = EltVT.getSizeInBits(); if (EltSize != 32 && EltSize != 64) diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td index f73716e556f..f9a1c6ddb6c 100644 --- a/lib/Target/X86/X86InstrAVX512.td +++ b/lib/Target/X86/X86InstrAVX512.td @@ -887,6 +887,112 @@ def : Pat<(v64i8 (insert_subvector undef, (v32i8 VR256X:$src), (iPTR 0))), (INSERT_SUBREG (v64i8 (IMPLICIT_DEF)), VR256X:$src, sub_ymm)>; } +// Additional patterns for handling a bitcast between the vselect and the +// extract_subvector. +multiclass vextract_for_mask_cast p> { +let Predicates = p in { + def : Pat<(Cast.VT (vselect Cast.KRCWM:$mask, + (bitconvert + (To.VT (vextract_extract:$ext + (From.VT From.RC:$src), (iPTR imm)))), + To.RC:$src0)), + (Cast.VT (!cast(InstrStr#"rrk") + Cast.RC:$src0, Cast.KRCWM:$mask, From.RC:$src, + (EXTRACT_get_vextract_imm To.RC:$ext)))>; + + def : Pat<(Cast.VT (vselect Cast.KRCWM:$mask, + (bitconvert + (To.VT (vextract_extract:$ext + (From.VT From.RC:$src), (iPTR imm)))), + Cast.ImmAllZerosV)), + (Cast.VT (!cast(InstrStr#"rrkz") + Cast.KRCWM:$mask, From.RC:$src, + (EXTRACT_get_vextract_imm To.RC:$ext)))>; +} +} + +defm : vextract_for_mask_cast<"VEXTRACTF32x4Z256", v4f64x_info, v2f64x_info, + v4f32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasVLX]>; +defm : vextract_for_mask_cast<"VEXTRACTF64x2Z256", v8f32x_info, v4f32x_info, + v2f64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>; + +defm : vextract_for_mask_cast<"VEXTRACTI32x4Z256", v4i64x_info, v2i64x_info, + v4i32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasVLX]>; +defm : vextract_for_mask_cast<"VEXTRACTI32x4Z256", v16i16x_info, v8i16x_info, + v4i32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasVLX]>; +defm : vextract_for_mask_cast<"VEXTRACTI32x4Z256", v32i8x_info, v16i8x_info, + v4i32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasVLX]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x2Z256", v8i32x_info, v4i32x_info, + v2i64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x2Z256", v16i16x_info, v8i16x_info, + v2i64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x2Z256", v32i8x_info, v16i8x_info, + v2i64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI, HasVLX]>; + +defm : vextract_for_mask_cast<"VEXTRACTF32x4Z", v8f64_info, v2f64x_info, + v4f32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasAVX512]>; +defm : vextract_for_mask_cast<"VEXTRACTF64x2Z", v16f32_info, v4f32x_info, + v2f64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI]>; + +defm : vextract_for_mask_cast<"VEXTRACTI32x4Z", v8i64_info, v2i64x_info, + v4i32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasAVX512]>; +defm : vextract_for_mask_cast<"VEXTRACTI32x4Z", v32i16_info, v8i16x_info, + v4i32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasAVX512]>; +defm : vextract_for_mask_cast<"VEXTRACTI32x4Z", v64i8_info, v16i8x_info, + v4i32x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasAVX512]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x2Z", v16i32_info, v4i32x_info, + v2i64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x2Z", v32i16_info, v8i16x_info, + v2i64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x2Z", v64i8_info, v16i8x_info, + v2i64x_info, vextract128_extract, + EXTRACT_get_vextract128_imm, [HasDQI]>; + +defm : vextract_for_mask_cast<"VEXTRACTF32x8Z", v8f64_info, v4f64x_info, + v8f32x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasDQI]>; +defm : vextract_for_mask_cast<"VEXTRACTF64x4Z", v16f32_info, v8f32x_info, + v4f64x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasAVX512]>; + +defm : vextract_for_mask_cast<"VEXTRACTI32x8Z", v8i64_info, v4i64x_info, + v8i32x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasDQI]>; +defm : vextract_for_mask_cast<"VEXTRACTI32x8Z", v32i16_info, v16i16x_info, + v8i32x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasDQI]>; +defm : vextract_for_mask_cast<"VEXTRACTI32x8Z", v64i8_info, v32i8x_info, + v8i32x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasDQI]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x4Z", v16i32_info, v8i32x_info, + v4i64x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasAVX512]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x4Z", v32i16_info, v16i16x_info, + v4i64x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasAVX512]>; +defm : vextract_for_mask_cast<"VEXTRACTI64x4Z", v64i8_info, v32i8x_info, + v4i64x_info, vextract256_extract, + EXTRACT_get_vextract256_imm, [HasAVX512]>; + // vextractps - extract 32 bits from XMM def VEXTRACTPSZrr : AVX512AIi8<0x17, MRMDestReg, (outs GR32:$dst), (ins VR128X:$src1, u8imm:$src2), diff --git a/test/CodeGen/X86/vector-shuffle-masked.ll b/test/CodeGen/X86/vector-shuffle-masked.ll index cd543cd13f3..3b82d191a0c 100644 --- a/test/CodeGen/X86/vector-shuffle-masked.ll +++ b/test/CodeGen/X86/vector-shuffle-masked.ll @@ -1031,6 +1031,19 @@ define <8 x i32> @mask_cast_extract_v8i64_v8i32_1(<8 x i64> %a, <8 x i32> %passt ret <8 x i32> %res } +define <8 x i32> @mask_cast_extract_v8i64_v8i32_1_z(<8 x i64> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v8i64_v8i32_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextracti32x8 $1, %zmm0, %ymm0 {%k1} {z} +; CHECK-NEXT: retq + %shuffle = shufflevector <8 x i64> %a, <8 x i64> undef, <4 x i32> + %shuffle.cast = bitcast <4 x i64> %shuffle to <8 x i32> + %mask.cast = bitcast i8 %mask to <8 x i1> + %res = select <8 x i1> %mask.cast, <8 x i32> %shuffle.cast, <8 x i32> zeroinitializer + ret <8 x i32> %res +} + define <8 x float> @mask_cast_extract_v8f64_v8f32_1(<8 x double> %a, <8 x float> %passthru, i8 %mask) { ; CHECK-LABEL: mask_cast_extract_v8f64_v8f32_1: ; CHECK: # BB#0: @@ -1045,6 +1058,19 @@ define <8 x float> @mask_cast_extract_v8f64_v8f32_1(<8 x double> %a, <8 x float> ret <8 x float> %res } +define <8 x float> @mask_cast_extract_v8f64_v8f32_1_z(<8 x double> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v8f64_v8f32_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextractf32x8 $1, %zmm0, %ymm0 {%k1} {z} +; CHECK-NEXT: retq + %shuffle = shufflevector <8 x double> %a, <8 x double> undef, <4 x i32> + %shuffle.cast = bitcast <4 x double> %shuffle to <8 x float> + %mask.cast = bitcast i8 %mask to <8 x i1> + %res = select <8 x i1> %mask.cast, <8 x float> %shuffle.cast, <8 x float> zeroinitializer + ret <8 x float> %res +} + define <4 x i32> @mask_cast_extract_v8i64_v4i32_1(<8 x i64> %a, <4 x i32> %passthru, i8 %mask) { ; CHECK-LABEL: mask_cast_extract_v8i64_v4i32_1: ; CHECK: # BB#0: @@ -1061,6 +1087,21 @@ define <4 x i32> @mask_cast_extract_v8i64_v4i32_1(<8 x i64> %a, <4 x i32> %passt ret <4 x i32> %res } +define <4 x i32> @mask_cast_extract_v8i64_v4i32_1_z(<8 x i64> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v8i64_v4i32_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextracti32x4 $1, %zmm0, %xmm0 {%k1} {z} +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + %shuffle = shufflevector <8 x i64> %a, <8 x i64> undef, <2 x i32> + %shuffle.cast = bitcast <2 x i64> %shuffle to <4 x i32> + %mask.cast = bitcast i8 %mask to <8 x i1> + %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> + %res = select <4 x i1> %mask.extract, <4 x i32> %shuffle.cast, <4 x i32> zeroinitializer + ret <4 x i32> %res +} + define <4 x float> @mask_cast_extract_v8f64_v4f32_1(<8 x double> %a, <4 x float> %passthru, i8 %mask) { ; CHECK-LABEL: mask_cast_extract_v8f64_v4f32_1: ; CHECK: # BB#0: @@ -1077,6 +1118,21 @@ define <4 x float> @mask_cast_extract_v8f64_v4f32_1(<8 x double> %a, <4 x float> ret <4 x float> %res } +define <4 x float> @mask_cast_extract_v8f64_v4f32_1_z(<8 x double> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v8f64_v4f32_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextractf32x4 $1, %zmm0, %xmm0 {%k1} {z} +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + %shuffle = shufflevector <8 x double> %a, <8 x double> undef, <2 x i32> + %shuffle.cast = bitcast <2 x double> %shuffle to <4 x float> + %mask.cast = bitcast i8 %mask to <8 x i1> + %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> + %res = select <4 x i1> %mask.extract, <4 x float> %shuffle.cast, <4 x float> zeroinitializer + ret <4 x float> %res +} + define <4 x i64> @mask_cast_extract_v16i32_v4i64_1(<16 x i32> %a, <4 x i64> %passthru, i8 %mask) { ; CHECK-LABEL: mask_cast_extract_v16i32_v4i64_1: ; CHECK: # BB#0: @@ -1092,6 +1148,20 @@ define <4 x i64> @mask_cast_extract_v16i32_v4i64_1(<16 x i32> %a, <4 x i64> %pas ret <4 x i64> %res } +define <4 x i64> @mask_cast_extract_v16i32_v4i64_1_z(<16 x i32> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v16i32_v4i64_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextracti64x4 $1, %zmm0, %ymm0 {%k1} {z} +; CHECK-NEXT: retq + %shuffle = shufflevector <16 x i32> %a, <16 x i32> undef, <8 x i32> + %shuffle.cast = bitcast <8 x i32> %shuffle to <4 x i64> + %mask.cast = bitcast i8 %mask to <8 x i1> + %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> + %res = select <4 x i1> %mask.extract, <4 x i64> %shuffle.cast, <4 x i64> zeroinitializer + ret <4 x i64> %res +} + define <4 x double> @mask_cast_extract_v16f32_v4f64_1(<16 x float> %a, <4 x double> %passthru, i8 %mask) { ; CHECK-LABEL: mask_cast_extract_v16f32_v4f64_1: ; CHECK: # BB#0: @@ -1107,6 +1177,20 @@ define <4 x double> @mask_cast_extract_v16f32_v4f64_1(<16 x float> %a, <4 x doub ret <4 x double> %res } +define <4 x double> @mask_cast_extract_v16f32_v4f64_1_z(<16 x float> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v16f32_v4f64_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextractf64x4 $1, %zmm0, %ymm0 {%k1} {z} +; CHECK-NEXT: retq + %shuffle = shufflevector <16 x float> %a, <16 x float> undef, <8 x i32> + %shuffle.cast = bitcast <8 x float> %shuffle to <4 x double> + %mask.cast = bitcast i8 %mask to <8 x i1> + %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <4 x i32> + %res = select <4 x i1> %mask.extract, <4 x double> %shuffle.cast, <4 x double> zeroinitializer + ret <4 x double> %res +} + define <2 x i64> @mask_cast_extract_v16i32_v2i64_1(<16 x i32> %a, <2 x i64> %passthru, i8 %mask) { ; CHECK-LABEL: mask_cast_extract_v16i32_v2i64_1: ; CHECK: # BB#0: @@ -1123,6 +1207,21 @@ define <2 x i64> @mask_cast_extract_v16i32_v2i64_1(<16 x i32> %a, <2 x i64> %pas ret <2 x i64> %res } +define <2 x i64> @mask_cast_extract_v16i32_v2i64_1_z(<16 x i32> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v16i32_v2i64_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextracti64x2 $1, %zmm0, %xmm0 {%k1} {z} +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + %shuffle = shufflevector <16 x i32> %a, <16 x i32> undef, <4 x i32> + %shuffle.cast = bitcast <4 x i32> %shuffle to <2 x i64> + %mask.cast = bitcast i8 %mask to <8 x i1> + %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <2 x i32> + %res = select <2 x i1> %mask.extract, <2 x i64> %shuffle.cast, <2 x i64> zeroinitializer + ret <2 x i64> %res +} + define <2 x double> @mask_cast_extract_v16f32_v2f64_1(<16 x float> %a, <2 x double> %passthru, i8 %mask) { ; CHECK-LABEL: mask_cast_extract_v16f32_v2f64_1: ; CHECK: # BB#0: @@ -1139,6 +1238,21 @@ define <2 x double> @mask_cast_extract_v16f32_v2f64_1(<16 x float> %a, <2 x doub ret <2 x double> %res } +define <2 x double> @mask_cast_extract_v16f32_v2f64_1_z(<16 x float> %a, i8 %mask) { +; CHECK-LABEL: mask_cast_extract_v16f32_v2f64_1_z: +; CHECK: # BB#0: +; CHECK-NEXT: kmovw %edi, %k1 +; CHECK-NEXT: vextractf64x2 $1, %zmm0, %xmm0 {%k1} {z} +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + %shuffle = shufflevector <16 x float> %a, <16 x float> undef, <4 x i32> + %shuffle.cast = bitcast <4 x float> %shuffle to <2 x double> + %mask.cast = bitcast i8 %mask to <8 x i1> + %mask.extract = shufflevector <8 x i1> %mask.cast, <8 x i1> undef, <2 x i32> + %res = select <2 x i1> %mask.extract, <2 x double> %shuffle.cast, <2 x double> zeroinitializer + ret <2 x double> %res +} + define <2 x double> @broadcast_v4f32_0101_from_v2f32_mask(double* %x, <2 x double> %passthru, i8 %mask) { ; CHECK-LABEL: broadcast_v4f32_0101_from_v2f32_mask: ; CHECK: # BB#0: -- 2.50.1