]> granicus.if.org Git - llvm/commitdiff
[X86] Make IFMA instructions during isel so we can fold broadcast loads.
authorCraig Topper <craig.topper@intel.com>
Sun, 24 Sep 2017 19:30:55 +0000 (19:30 +0000)
committerCraig Topper <craig.topper@intel.com>
Sun, 24 Sep 2017 19:30:55 +0000 (19:30 +0000)
This required changing the ISD opcode for these instructions to have the commutable operands first and the addend last. This way tablegen can autogenerate the additional patterns for us.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@314083 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86ISelLowering.h
lib/Target/X86/X86InstrAVX512.td
lib/Target/X86/X86InstrFragmentsSIMD.td
lib/Target/X86/X86IntrinsicsInfo.h
test/CodeGen/X86/avx512ifma-intrinsics.ll

index c77d6f29810d8d261880b8d8b8ccf46001ff496a..0d2b84b134759e1d4ec202f1ec32987da22887f9 100644 (file)
@@ -19580,6 +19580,26 @@ static SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, const X86Subtarget &Subtarget
                                               Src3, Rnd),
                                   Mask, PassThru, Subtarget, DAG);
     }
+    case IFMA_OP_MASKZ:
+    case IFMA_OP_MASK: {
+      SDValue Src1 = Op.getOperand(1);
+      SDValue Src2 = Op.getOperand(2);
+      SDValue Src3 = Op.getOperand(3);
+      SDValue Mask = Op.getOperand(4);
+      MVT VT = Op.getSimpleValueType();
+      SDValue PassThru = Src1;
+
+      // set PassThru element
+      if (IntrData->Type == IFMA_OP_MASKZ)
+        PassThru = getZeroVector(VT, Subtarget, DAG, dl);
+
+      // Node we need to swizzle the operands to pass the multiply operands
+      // first.
+      return getVectorMaskingNode(DAG.getNode(IntrData->Opc0,
+                                              dl, Op.getValueType(),
+                                              Src2, Src3, Src1),
+                                  Mask, PassThru, Subtarget, DAG);
+    }
     case TERLOG_OP_MASK:
     case TERLOG_OP_MASKZ: {
       SDValue Src1 = Op.getOperand(1);
index a1ae1613c7379f64cfacc0346a97adcf4092c38b..a0fc4c8a29c66dd32f38b1ccd42d50304c185c45 100644 (file)
@@ -467,6 +467,10 @@ namespace llvm {
 
       // Multiply and Add Packed Integers.
       VPMADDUBSW, VPMADDWD,
+
+      // AVX512IFMA multiply and add.
+      // NOTE: These are different than the instruction and perform
+      // op0 x op1 + op2.
       VPMADD52L, VPMADD52H,
 
       // FMA nodes.
index a8b7c80cdab95e1a5154a1c973518661c86eb8b4..7ae4352683b76dd4e167a10b60563fcc406a4321 100644 (file)
@@ -6480,30 +6480,30 @@ defm VFNMSUB : avx512_fma3s<0xAF, 0xBF, 0x9F, "vfnmsub", X86Fnmsub,
 let Constraints = "$src1 = $dst" in {
 multiclass avx512_pmadd52_rm<bits<8> opc, string OpcodeStr, SDNode OpNode,
                                                             X86VectorVTInfo _> {
+  // NOTE: The SDNode have the multiply operands first with the add last.
+  // This enables commuted load patterns to be autogenerated by tablegen.
   let ExeDomain = _.ExeDomain in {
   defm r: AVX512_maskable_3src<opc, MRMSrcReg, _, (outs _.RC:$dst),
           (ins _.RC:$src2, _.RC:$src3),
           OpcodeStr, "$src3, $src2", "$src2, $src3",
-          (_.VT (OpNode _.RC:$src1, _.RC:$src2, _.RC:$src3)), 1, 1>,
+          (_.VT (OpNode _.RC:$src2, _.RC:$src3, _.RC:$src1)), 1, 1>,
          AVX512FMA3Base;
 
   defm m: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
           (ins _.RC:$src2, _.MemOp:$src3),
           OpcodeStr, "$src3, $src2", "$src2, $src3",
-          (_.VT (OpNode _.RC:$src1, _.RC:$src2, (_.LdFrag addr:$src3)))>,
+          (_.VT (OpNode _.RC:$src2, (_.LdFrag addr:$src3), _.RC:$src1))>,
           AVX512FMA3Base;
 
   defm mb: AVX512_maskable_3src<opc, MRMSrcMem, _, (outs _.RC:$dst),
             (ins _.RC:$src2, _.ScalarMemOp:$src3),
             OpcodeStr,   !strconcat("${src3}", _.BroadcastStr,", $src2"),
             !strconcat("$src2, ${src3}", _.BroadcastStr ),
-            (OpNode _.RC:$src1,
-             _.RC:$src2,(_.VT (X86VBroadcast (_.ScalarLdFrag addr:$src3))))>,
+            (OpNode _.RC:$src2,
+                    (_.VT (X86VBroadcast (_.ScalarLdFrag addr:$src3))),
+                    _.RC:$src1)>,
             AVX512FMA3Base, EVEX_B;
   }
-
-  // TODO: Should be able to match a memory op in operand 2.
-  // TODO: These instructions should be marked Commutable on operand 2 and 3.
 }
 } // Constraints = "$src1 = $dst"
 
index 38a2e4cacbc4a97bcbb0e8b32a311b9e62770799..c01c8af6d4ac5207c6dcb7f6b31c911fdcecf4e7 100644 (file)
@@ -498,8 +498,8 @@ def X86FnmsubRnds3  : SDNode<"X86ISD::FNMSUBS3_RND",    SDTFmaRound, [SDNPCommut
 
 def SDTIFma : SDTypeProfile<1, 3, [SDTCisInt<0>, SDTCisSameAs<0,1>,
                            SDTCisSameAs<1,2>, SDTCisSameAs<1,3>]>;
-def x86vpmadd52l     : SDNode<"X86ISD::VPMADD52L",     SDTIFma>;
-def x86vpmadd52h     : SDNode<"X86ISD::VPMADD52H",     SDTIFma>;
+def x86vpmadd52l     : SDNode<"X86ISD::VPMADD52L",     SDTIFma, [SDNPCommutative]>;
+def x86vpmadd52h     : SDNode<"X86ISD::VPMADD52H",     SDTIFma, [SDNPCommutative]>;
 
 def X86rsqrt28   : SDNode<"X86ISD::RSQRT28",  SDTFPUnaryOpRound>;
 def X86rcp28     : SDNode<"X86ISD::RCP28",    SDTFPUnaryOpRound>;
index d9d0b06c96077a428d6e42397d39d7d08a652f02..40a04d6f2f6a952c7bc4569812ac9e3a4cfb1662 100644 (file)
@@ -30,6 +30,7 @@ enum IntrinsicType : uint16_t {
   INTR_TYPE_3OP_MASK, INTR_TYPE_3OP_MASK_RM, INTR_TYPE_3OP_IMM8_MASK,
   FMA_OP_MASK, FMA_OP_MASKZ, FMA_OP_MASK3,
   FMA_OP_SCALAR_MASK, FMA_OP_SCALAR_MASKZ, FMA_OP_SCALAR_MASK3,
+  IFMA_OP_MASK, IFMA_OP_MASKZ,
   VPERM_2OP_MASK, VPERM_3OP_MASK, VPERM_3OP_MASKZ, INTR_TYPE_SCALAR_MASK,
   INTR_TYPE_SCALAR_MASK_RM, INTR_TYPE_3OP_SCALAR_MASK_RM,
   COMPRESS_EXPAND_IN_REG, COMPRESS_TO_MEM, BRCST32x2_TO_VEC,
@@ -1208,17 +1209,17 @@ static const IntrinsicData  IntrinsicsWithoutChain[] = {
                     X86ISD::VPERMV3, 0),
   X86_INTRINSIC_DATA(avx512_mask_vpermt2var_qi_512, VPERM_3OP_MASK,
                     X86ISD::VPERMV3, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_128 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_128 , IFMA_OP_MASK,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_256 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_256 , IFMA_OP_MASK,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_512 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52h_uq_512 , IFMA_OP_MASK,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_128 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_128 , IFMA_OP_MASK,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_256 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_256 , IFMA_OP_MASK,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_512 , FMA_OP_MASK,
+  X86_INTRINSIC_DATA(avx512_mask_vpmadd52l_uq_512 , IFMA_OP_MASK,
                      X86ISD::VPMADD52L, 0),
   X86_INTRINSIC_DATA(avx512_mask3_vfmadd_pd_128, FMA_OP_MASK3, ISD::FMA, 0),
   X86_INTRINSIC_DATA(avx512_mask3_vfmadd_pd_256, FMA_OP_MASK3, ISD::FMA, 0),
@@ -1354,17 +1355,17 @@ static const IntrinsicData  IntrinsicsWithoutChain[] = {
                      X86ISD::VPERMV3, 0),
   X86_INTRINSIC_DATA(avx512_maskz_vpermt2var_qi_512, VPERM_3OP_MASKZ,
                      X86ISD::VPERMV3, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_128, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_128, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_256, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_256, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_512, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52h_uq_512, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52H, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_128, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_128, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_256, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_256, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52L, 0),
-  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_512, FMA_OP_MASKZ,
+  X86_INTRINSIC_DATA(avx512_maskz_vpmadd52l_uq_512, IFMA_OP_MASKZ,
                      X86ISD::VPMADD52L, 0),
   X86_INTRINSIC_DATA(avx512_packssdw_512, INTR_TYPE_2OP, X86ISD::PACKSS, 0),
   X86_INTRINSIC_DATA(avx512_packsswb_512, INTR_TYPE_2OP, X86ISD::PACKSS, 0),
index 5fd7f0f2f70be4813039b57883600e9d5a014567..8a0f8d9df6216f2f40d75e9dfe29b0165811a2b2 100644 (file)
@@ -151,8 +151,7 @@ define <8 x i64>@test_int_x86_avx512_vpmadd52h_uq_512_load_commute(<8 x i64> %x0
 define <8 x i64>@test_int_x86_avx512_vpmadd52h_uq_512_load_commute_bcast(<8 x i64> %x0, i64* %x1ptr, <8 x i64> %x2) {
 ; CHECK-LABEL: test_int_x86_avx512_vpmadd52h_uq_512_load_commute_bcast:
 ; CHECK:       ## BB#0:
-; CHECK-NEXT:    vpbroadcastq (%rdi), %zmm2
-; CHECK-NEXT:    vpmadd52huq %zmm1, %zmm2, %zmm0
+; CHECK-NEXT:    vpmadd52huq (%rdi){1to8}, %zmm1, %zmm0
 ; CHECK-NEXT:    retq
 
   %x1load = load i64, i64* %x1ptr
@@ -204,8 +203,7 @@ define <8 x i64>@test_int_x86_avx512_mask_vpmadd52h_uq_512_load_commute_bcast(<8
 ; CHECK-LABEL: test_int_x86_avx512_mask_vpmadd52h_uq_512_load_commute_bcast:
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    kmovw %esi, %k1
-; CHECK-NEXT:    vpbroadcastq (%rdi), %zmm2
-; CHECK-NEXT:    vpmadd52huq %zmm1, %zmm2, %zmm0 {%k1}
+; CHECK-NEXT:    vpmadd52huq (%rdi){1to8}, %zmm1, %zmm0 {%k1}
 ; CHECK-NEXT:    retq
 
   %x1load = load i64, i64* %x1ptr
@@ -257,8 +255,7 @@ define <8 x i64>@test_int_x86_avx512_maskz_vpmadd52h_uq_512_load_commute_bcast(<
 ; CHECK-LABEL: test_int_x86_avx512_maskz_vpmadd52h_uq_512_load_commute_bcast:
 ; CHECK:       ## BB#0:
 ; CHECK-NEXT:    kmovw %esi, %k1
-; CHECK-NEXT:    vpbroadcastq (%rdi), %zmm2
-; CHECK-NEXT:    vpmadd52huq %zmm1, %zmm2, %zmm0 {%k1} {z}
+; CHECK-NEXT:    vpmadd52huq (%rdi){1to8}, %zmm1, %zmm0 {%k1} {z}
 ; CHECK-NEXT:    retq
 
   %x1load = load i64, i64* %x1ptr