]> granicus.if.org Git - llvm/commitdiff
[X86] Regcall - Adding support for mask types
authorOren Ben Simhon <oren.ben.simhon@intel.com>
Sun, 11 Dec 2016 14:10:52 +0000 (14:10 +0000)
committerOren Ben Simhon <oren.ben.simhon@intel.com>
Sun, 11 Dec 2016 14:10:52 +0000 (14:10 +0000)
Regcall calling convention passes mask types arguments in x86 GPR registers.
The review includes the changes required in order to support v32i1, v16i1 and v8i1.

Differential Revision: https://reviews.llvm.org/D27148

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@289383 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86CallingConv.h
lib/Target/X86/X86CallingConv.td
lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/avx512-regcall-Mask.ll

index 41fbd2e4474cfc0d8eb82f873d57d822461eac27..2e93ec9c78cab94cda53b37e749075ff89cd0663 100644 (file)
@@ -41,7 +41,6 @@ inline bool CC_X86_32_VectorCallIndirect(unsigned &ValNo, MVT &ValVT,
   return false; // Continue the search, but now for i32.
 }
 
-
 inline bool CC_X86_AnyReg_Error(unsigned &, MVT &, MVT &,
                                 CCValAssign::LocInfo &, ISD::ArgFlagsTy &,
                                 CCState &) {
@@ -51,13 +50,6 @@ inline bool CC_X86_AnyReg_Error(unsigned &, MVT &, MVT &,
   return false;
 }
 
-inline bool CC_X86_RegCall_Error(unsigned &, MVT &, MVT &,
-                                 CCValAssign::LocInfo &, ISD::ArgFlagsTy &,
-                                 CCState &) {
-  report_fatal_error("LLVM x86 RegCall calling convention implementation" \
-    " doesn't support long double and mask types yet.");
-}
-
 inline bool CC_X86_32_MCUInReg(unsigned &ValNo, MVT &ValVT,
                                          MVT &LocVT,
                                          CCValAssign::LocInfo &LocInfo,
index eab11196c826b7e3a8ed6e25860e6239d096a769..a0c822ff0ab4b3e3f2559d01f2bcb68910e8d239 100644 (file)
@@ -76,6 +76,9 @@ def CC_#NAME : CallingConv<[
     // Promote i1/i8/i16 arguments to i32.
     CCIfType<[i1, i8, i16], CCPromoteToType<i32>>,
 
+    // Promote v8i1/v16i1/v32i1 arguments to i32.
+    CCIfType<[v8i1, v16i1, v32i1], CCPromoteToType<i32>>,
+
     // bool, char, int, enum, long, pointer --> GPR
     CCIfType<[i32], CCAssignToReg<RC.GPR_32>>,
 
@@ -89,9 +92,6 @@ def CC_#NAME : CallingConv<[
     CCIfSubtarget<"is32Bit()", CCIfType<[i64], 
       CCCustom<"CC_X86_32_RegCall_Assign2Regs">>>,
 
-    // TODO: Handle the case of mask types (v*i1)
-    CCIfType<[v8i1, v16i1, v32i1], CCCustom<"CC_X86_RegCall_Error">>,
-
     // float, double, float128 --> XMM
     // In the case of SSE disabled --> save to stack
     CCIfType<[f32, f64, f128], 
@@ -146,8 +146,14 @@ def CC_#NAME : CallingConv<[
 ]>;
 
 def RetCC_#NAME : CallingConv<[
-    // Promote i1 arguments to i8.
-    CCIfType<[i1], CCPromoteToType<i8>>,
+    // Promote i1, v8i1 arguments to i8.
+    CCIfType<[i1, v8i1], CCPromoteToType<i8>>,
+
+    // Promote v16i1 arguments to i16.
+    CCIfType<[v16i1], CCPromoteToType<i16>>,
+
+    // Promote v32i1 arguments to i32.
+    CCIfType<[v32i1], CCPromoteToType<i32>>,
 
     // bool, char, int, enum, long, pointer --> GPR
     CCIfType<[i8], CCAssignToReg<RC.GPR_8>>,
@@ -164,9 +170,6 @@ def RetCC_#NAME : CallingConv<[
     CCIfSubtarget<"is32Bit()", CCIfType<[i64], 
       CCCustom<"CC_X86_32_RegCall_Assign2Regs">>>,
 
-    // TODO: Handle the case of mask types (v*i1)
-    CCIfType<[v8i1, v16i1, v32i1], CCCustom<"CC_X86_RegCall_Error">>,
-
     // long double --> FP
     CCIfType<[f80], CCAssignToReg<RC.FP_RET>>,
 
index 3e3da8b76d071f4f0abda6affa8c0e31208e3a3c..94d418c827543f338d36a3be153860fb070112e0 100644 (file)
@@ -2100,14 +2100,26 @@ const MCPhysReg *X86TargetLowering::getScratchRegisters(CallingConv::ID) const {
 }
 
 /// Lowers masks values (v*i1) to the local register values
+/// \returns DAG node after lowering to register type
 static SDValue lowerMasksToReg(const SDValue &ValArg, const EVT &ValLoc,
                                const SDLoc &Dl, SelectionDAG &DAG) {
   EVT ValVT = ValArg.getValueType();
 
-  if (ValVT == MVT::v64i1 && ValLoc == MVT::i64) {
+  if ((ValVT == MVT::v8i1 && (ValLoc == MVT::i8 || ValLoc == MVT::i32)) ||
+      (ValVT == MVT::v16i1 && (ValLoc == MVT::i16 || ValLoc == MVT::i32))) {
+    // Two stage lowering might be required
+    // bitcast:   v8i1 -> i8 / v16i1 -> i16
+    // anyextend: i8   -> i32 / i16   -> i32
+    EVT TempValLoc = ValVT == MVT::v8i1 ? MVT::i8 : MVT::i16;
+    SDValue ValToCopy = DAG.getBitcast(TempValLoc, ValArg);
+    if (ValLoc == MVT::i32)
+      ValToCopy = DAG.getNode(ISD::ANY_EXTEND, Dl, ValLoc, ValToCopy);
+    return ValToCopy;
+  } else if ((ValVT == MVT::v32i1 && ValLoc == MVT::i32) ||
+             (ValVT == MVT::v64i1 && ValLoc == MVT::i64)) {
     // One stage lowering is required
-    // bitcast:   v64i1 -> i64
-    return DAG.getBitcast(MVT::i64, ValArg);
+    // bitcast:   v32i1 -> i32 / v64i1 -> i64
+    return DAG.getBitcast(ValLoc, ValArg);
   } else
     return DAG.getNode(ISD::SIGN_EXTEND, Dl, ValLoc, ValArg);
 }
@@ -2379,14 +2391,14 @@ EVT X86TargetLowering::getTypeForExtReturn(LLVMContext &Context, EVT VT,
 }
 
 /// Reads two 32 bit registers and creates a 64 bit mask value.
-/// @param VA The current 32 bit value that need to be assigned.
-/// @param NextVA The next 32 bit value that need to be assigned.
-/// @param Root The parent DAG node.
-/// @param [in,out] InFlag Represents SDvalue in the parent DAG node for
+/// \param VA The current 32 bit value that need to be assigned.
+/// \param NextVA The next 32 bit value that need to be assigned.
+/// \param Root The parent DAG node.
+/// \param [in,out] InFlag Represents SDvalue in the parent DAG node for
 ///                        glue purposes. In the case the DAG is already using
 ///                        physical register instead of virtual, we should glue
 ///                        our new SDValue to InFlag SDvalue.
-/// @return a new SDvalue of size 64bit.
+/// \return a new SDvalue of size 64bit.
 static SDValue getv64i1Argument(CCValAssign &VA, CCValAssign &NextVA,
                                 SDValue &Root, SelectionDAG &DAG,
                                 const SDLoc &Dl, const X86Subtarget &Subtarget,
@@ -2436,23 +2448,38 @@ static SDValue getv64i1Argument(CCValAssign &VA, CCValAssign &NextVA,
   return DAG.getNode(ISD::CONCAT_VECTORS, Dl, MVT::v64i1, Lo, Hi);
 }
 
+/// The function will lower a register of various sizes (8/16/32/64)
+/// to a mask value of the expected size (v8i1/v16i1/v32i1/v64i1)
+/// \returns a DAG node contains the operand after lowering to mask type.
 static SDValue lowerRegToMasks(const SDValue &ValArg, const EVT &ValVT,
                                const EVT &ValLoc, const SDLoc &Dl,
                                SelectionDAG &DAG) {
-  assert((ValLoc == MVT::i64 || ValLoc == MVT::i32) &&
-         "Expecting register location of size 32/64 bit");
+  SDValue ValReturned = ValArg;
 
-  // Currently not referenced - will be used in other mask lowering
-  (void)Dl;
+  if (ValVT == MVT::v64i1) {
+    // In 32 bit machine, this case is handled by getv64i1Argument
+    assert(ValLoc == MVT::i64 && "Expecting only i64 locations");
+    // In 64 bit machine, There is no need to truncate the value only bitcast
+  } else {
+    MVT maskLen;
+    switch (ValVT.getSimpleVT().SimpleTy) {
+    case MVT::v8i1:
+      maskLen = MVT::i8;
+      break;
+    case MVT::v16i1:
+      maskLen = MVT::i16;
+      break;
+    case MVT::v32i1:
+      maskLen = MVT::i32;
+      break;
+    default:
+      llvm_unreachable("Expecting a vector of i1 types");
+    }
 
-  // In the case of v64i1 no special handling is required due to two reasons:
-  // In 32 bit machine, this case is handled by getv64i1Argument
-  // In 64 bit machine, There is no need to truncate the value only bitcast
-  if (ValVT == MVT::v64i1 && ValLoc == MVT::i32) {
-    llvm_unreachable("Expecting only i64 locations");
+    ValReturned = DAG.getNode(ISD::TRUNCATE, Dl, maskLen, ValReturned);
   }
 
-  return DAG.getBitcast(ValVT, ValArg);
+  return DAG.getBitcast(ValVT, ValReturned);
 }
 
 /// Lower the result values of a call into the
@@ -2513,8 +2540,9 @@ SDValue X86TargetLowering::LowerCallResult(
 
     if (VA.isExtInLoc() && (VA.getValVT().getScalarType() == MVT::i1)) {
       if (VA.getValVT().isVector() &&
-          (VA.getLocVT() == MVT::i32 || VA.getLocVT() == MVT::i64)) {
-        // promoting a mask type (v*i1) into a register of type i64/i32
+          ((VA.getLocVT() == MVT::i64) || (VA.getLocVT() == MVT::i32) ||
+           (VA.getLocVT() == MVT::i16) || (VA.getLocVT() == MVT::i8))) {
+        // promoting a mask type (v*i1) into a register of type i64/i32/i16/i8
         Val = lowerRegToMasks(Val, VA.getValVT(), VA.getLocVT(), dl, DAG);
       } else
         Val = DAG.getNode(ISD::TRUNCATE, dl, VA.getValVT(), Val);
@@ -2867,8 +2895,9 @@ SDValue X86TargetLowering::LowerFormalArguments(
           ArgValue = DAG.getNode(X86ISD::MOVDQ2Q, dl, VA.getValVT(), ArgValue);
         else if (VA.getValVT().isVector() &&
                  VA.getValVT().getScalarType() == MVT::i1 &&
-                 ((RegVT == MVT::i32) || (RegVT == MVT::i64))) {
-          // Promoting a mask type (v*i1) into a register of type i64/i32
+                 ((VA.getLocVT() == MVT::i64) || (VA.getLocVT() == MVT::i32) ||
+                  (VA.getLocVT() == MVT::i16) || (VA.getLocVT() == MVT::i8))) {
+          // Promoting a mask type (v*i1) into a register of type i64/i32/i16/i8
           ArgValue = lowerRegToMasks(ArgValue, VA.getValVT(), RegVT, dl, DAG);
         } else
           ArgValue = DAG.getNode(ISD::TRUNCATE, dl, VA.getValVT(), ArgValue);
index 6de9118f7bedf22e84e154c11dd45398369cfd27..325097ee9510581862fdf8276b41a17791c252ef 100644 (file)
@@ -1,6 +1,6 @@
-; RUN: llc < %s -mtriple=i386-pc-win32       -mattr=+avx512bw  | FileCheck --check-prefix=X32 %s
-; RUN: llc < %s -mtriple=x86_64-win32        -mattr=+avx512bw  | FileCheck --check-prefix=WIN64 %s
-; RUN: llc < %s -mtriple=x86_64-linux-gnu    -mattr=+avx512bw  | FileCheck --check-prefix=LINUXOSX64 %s
+; RUN: llc < %s -mtriple=i386-pc-win32       -mattr=+avx512bw  | FileCheck --check-prefix=CHECK --check-prefix=X32 %s
+; RUN: llc < %s -mtriple=x86_64-win32        -mattr=+avx512bw  | FileCheck --check-prefix=CHECK --check-prefix=CHECK64 --check-prefix=WIN64 %s
+; RUN: llc < %s -mtriple=x86_64-linux-gnu    -mattr=+avx512bw  | FileCheck --check-prefix=CHECK --check-prefix=CHECK64 --check-prefix=LINUXOSX64 %s
 
 ; X32-LABEL:  test_argv64i1:
 ; X32:        kmovd   %edx, %k0
@@ -155,7 +155,7 @@ define x86_regcallcc i64 @test_argv64i1(<64 x i1> %x0, <64 x i1> %x1, <64 x i1>
 ; LINUXOSX64:       call{{.*}}   test_argv64i1
 
 ; Test regcall when passing arguments of v64i1 type
-define x86_regcallcc i64 @caller_argv64i1() #0 {
+define i64 @caller_argv64i1() #0 {
 entry:
   %v0 = bitcast i64 4294967298 to <64 x i1>
   %call = call x86_regcallcc i64 @test_argv64i1(<64 x i1> %v0, <64 x i1> %v0, <64 x i1> %v0,
@@ -171,9 +171,9 @@ entry:
 ; X32:       mov{{.*}}    $1, %ecx
 ; X32:       ret{{.*}}
 
-; WIN64-LABEL: test_retv64i1:
-; WIN64:       mov{{.*}} $4294967298, %rax
-; WIN64:       ret{{.*}}
+; CHECK64-LABEL: test_retv64i1:
+; CHECK64:       mov{{.*}} $4294967298, %rax
+; CHECK64:       ret{{.*}}
 
 ; Test regcall when returning v64i1 type
 define x86_regcallcc <64 x i1> @test_retv64i1()  {
@@ -187,9 +187,164 @@ define x86_regcallcc <64 x i1> @test_retv64i1()  {
 ; X32:       kmov{{.*}}   %ecx, %k1
 ; X32:       kunpckdq     %k0, %k1, %k0
 
+; CHECK64-LABEL: caller_retv64i1:
+; CHECK64:       call{{.*}}   {{_*}}test_retv64i1
+; CHECK64:       kmovq %rax, %k0
+; CHECK64:       ret{{.*}}
+
 ; Test regcall when processing result of v64i1 type
-define x86_regcallcc <64 x i1> @caller_retv64i1() #0 {
+define <64 x i1> @caller_retv64i1() #0 {
 entry:
   %call = call x86_regcallcc <64 x i1> @test_retv64i1()
   ret <64 x i1> %call
 }
+
+; CHECK-LABEL:  test_argv32i1:
+; CHECK:        kmovd    %edx, %k{{[0-9]+}}
+; CHECK:        kmovd    %ecx, %k{{[0-9]+}}
+; CHECK:        kmovd    %eax, %k{{[0-9]+}}
+; CHECK:        ret{{l|q}}
+
+; Test regcall when receiving arguments of v32i1 type
+declare i32 @test_argv32i1helper(<32 x i1> %x0, <32 x i1> %x1, <32 x i1> %x2)
+define x86_regcallcc i32 @test_argv32i1(<32 x i1> %x0, <32 x i1> %x1, <32 x i1> %x2)  {
+entry:
+  %res = call i32 @test_argv32i1helper(<32 x i1> %x0, <32 x i1> %x1, <32 x i1> %x2)
+  ret i32 %res
+}
+
+; CHECK-LABEL:  caller_argv32i1:
+; CHECK:        mov{{.*}}    $1, %eax
+; CHECK:        mov{{.*}}    $1, %ecx
+; CHECK:        mov{{.*}}    $1, %edx
+; CHECK:        call{{.*}}   {{_*}}test_argv32i1
+
+; Test regcall when passing arguments of v32i1 type
+define i32 @caller_argv32i1() #0 {
+entry:
+  %v0 = bitcast i32 1 to <32 x i1>
+  %call = call x86_regcallcc i32 @test_argv32i1(<32 x i1> %v0, <32 x i1> %v0, <32 x i1> %v0)
+  ret i32 %call
+}
+
+; CHECK-LABEL: test_retv32i1:
+; CHECK:       movl    $1, %eax
+; CHECK:       ret{{l|q}}
+
+; Test regcall when returning v32i1 type
+define x86_regcallcc <32 x i1> @test_retv32i1()  {
+  %a = bitcast i32 1 to <32 x i1>
+  ret <32 x i1> %a
+}
+
+; CHECK-LABEL: caller_retv32i1:
+; CHECK:       call{{.*}}   {{_*}}test_retv32i1
+; CHECK:       incl %eax
+
+; Test regcall when processing result of v32i1 type
+define i32 @caller_retv32i1() #0 {
+entry:
+  %call = call x86_regcallcc <32 x i1> @test_retv32i1()
+  %c = bitcast <32 x i1> %call to i32
+  %add = add i32 %c, 1
+  ret i32 %add
+}
+
+; CHECK-LABEL:  test_argv16i1:
+; CHECK:        kmovw    %edx, %k{{[0-9]+}}
+; CHECK:        kmovw    %ecx, %k{{[0-9]+}}
+; CHECK:        kmovw    %eax, %k{{[0-9]+}}
+; CHECK:        ret{{l|q}}
+
+; Test regcall when receiving arguments of v16i1 type
+declare i16 @test_argv16i1helper(<16 x i1> %x0, <16 x i1> %x1, <16 x i1> %x2)
+define x86_regcallcc i16 @test_argv16i1(<16 x i1> %x0, <16 x i1> %x1, <16 x i1> %x2)  {
+  %res = call i16 @test_argv16i1helper(<16 x i1> %x0, <16 x i1> %x1, <16 x i1> %x2)
+  ret i16 %res
+}
+
+; CHECK-LABEL:  caller_argv16i1:
+; CHECK:        movl    $1, %eax
+; CHECK:        movl    $1, %ecx
+; CHECK:        movl    $1, %edx
+; CHECK:        call{{l|q}}   {{_*}}test_argv16i1
+
+; Test regcall when passing arguments of v16i1 type
+define i16 @caller_argv16i1() #0 {
+entry:
+  %v0 = bitcast i16 1 to <16 x i1>
+  %call = call x86_regcallcc i16 @test_argv16i1(<16 x i1> %v0, <16 x i1> %v0, <16 x i1> %v0)
+  ret i16 %call
+}
+
+; CHECK-LABEL: test_retv16i1:
+; CHECK:       movw    $1, %ax
+; CHECK:       ret{{l|q}}
+
+; Test regcall when returning v16i1 type
+define x86_regcallcc <16 x i1> @test_retv16i1()  {
+  %a = bitcast i16 1 to <16 x i1>
+  ret <16 x i1> %a
+}
+
+; CHECK-LABEL: caller_retv16i1:
+; CHECK:       call{{l|q}}   {{_*}}test_retv16i1
+; CHECK:       incl   %eax
+
+; Test regcall when processing result of v16i1 type
+define i16 @caller_retv16i1() #0 {
+entry:
+  %call = call x86_regcallcc <16 x i1> @test_retv16i1()
+  %c = bitcast <16 x i1> %call to i16
+  %add = add i16 %c, 1
+  ret i16 %add
+}
+
+; CHECK-LABEL:  test_argv8i1:
+; CHECK:        kmovw    %edx, %k{{[0-9]+}}
+; CHECK:        kmovw    %ecx, %k{{[0-9]+}}
+; CHECK:        kmovw    %eax, %k{{[0-9]+}}
+; CHECK:        ret{{l|q}}
+
+; Test regcall when receiving arguments of v8i1 type
+declare i8 @test_argv8i1helper(<8 x i1> %x0, <8 x i1> %x1, <8 x i1> %x2)
+define x86_regcallcc i8 @test_argv8i1(<8 x i1> %x0, <8 x i1> %x1, <8 x i1> %x2)  {
+  %res = call i8 @test_argv8i1helper(<8 x i1> %x0, <8 x i1> %x1, <8 x i1> %x2)
+  ret i8 %res
+}
+
+; CHECK-LABEL:  caller_argv8i1:
+; CHECK:        movl    $1, %eax
+; CHECK:        movl    $1, %ecx
+; CHECK:        movl    $1, %edx
+; CHECK:        call{{l|q}}   {{_*}}test_argv8i1
+
+; Test regcall when passing arguments of v8i1 type
+define i8 @caller_argv8i1() #0 {
+entry:
+  %v0 = bitcast i8 1 to <8 x i1>
+  %call = call x86_regcallcc i8 @test_argv8i1(<8 x i1> %v0, <8 x i1> %v0, <8 x i1> %v0)
+  ret i8 %call
+}
+
+; CHECK-LABEL: test_retv8i1:
+; CHECK:       movb    $1, %al
+; CHECK:       ret{{q|l}}
+
+; Test regcall when returning v8i1 type
+define x86_regcallcc <8 x i1> @test_retv8i1()  {
+  %a = bitcast i8 1 to <8 x i1>
+  ret <8 x i1> %a
+}
+
+; CHECK-LABEL: caller_retv8i1:
+; CHECK:       call{{l|q}}   {{_*}}test_retv8i1
+; CHECK:       kmovw %eax, %k{{[0-9]+}}
+; CHECK:       ret{{l|q}}
+
+; Test regcall when processing result of v8i1 type
+define <8 x i1> @caller_retv8i1() #0 {
+entry:
+  %call = call x86_regcallcc <8 x i1> @test_retv8i1()
+  ret <8 x i1> %call
+}