]> granicus.if.org Git - llvm/commitdiff
[AVX-512] Support ADD/SUB/MUL of mask vectors
authorCraig Topper <craig.topper@gmail.com>
Thu, 19 Jan 2017 07:12:35 +0000 (07:12 +0000)
committerCraig Topper <craig.topper@gmail.com>
Thu, 19 Jan 2017 07:12:35 +0000 (07:12 +0000)
Summary:
Currently we expand and scalarize these operations, but I think we should be able to implement ADD/SUB with KXOR and MUL with KAND.

We already do this for scalar i1 operations so I just extended it to vectors of i1.

Reviewers: zvi, delena

Reviewed By: delena

Subscribers: guyblank, llvm-commits

Differential Revision: https://reviews.llvm.org/D28888

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@292474 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/avx512-mask-op.ll
test/CodeGen/X86/avx512bw-mask-op.ll

index e5c5975b75795b0de3207becc989124b74615d45..513ea3813bd6ee3b4168644aa6723a427c2d7aa4 100644 (file)
@@ -1357,12 +1357,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     setOperationAction(ISD::UMIN,               MVT::v16i32, Legal);
     setOperationAction(ISD::UMIN,               MVT::v8i64, Legal);
 
-    setOperationAction(ISD::ADD,                MVT::v8i1,  Expand);
-    setOperationAction(ISD::ADD,                MVT::v16i1, Expand);
-    setOperationAction(ISD::SUB,                MVT::v8i1,  Expand);
-    setOperationAction(ISD::SUB,                MVT::v16i1, Expand);
-    setOperationAction(ISD::MUL,                MVT::v8i1,  Expand);
-    setOperationAction(ISD::MUL,                MVT::v16i1, Expand);
+    setOperationAction(ISD::ADD,                MVT::v8i1,  Custom);
+    setOperationAction(ISD::ADD,                MVT::v16i1, Custom);
+    setOperationAction(ISD::SUB,                MVT::v8i1,  Custom);
+    setOperationAction(ISD::SUB,                MVT::v16i1, Custom);
+    setOperationAction(ISD::MUL,                MVT::v8i1,  Custom);
+    setOperationAction(ISD::MUL,                MVT::v16i1, Custom);
 
     setOperationAction(ISD::MUL,                MVT::v16i32, Legal);
 
@@ -1460,12 +1460,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     addRegisterClass(MVT::v32i1,  &X86::VK32RegClass);
     addRegisterClass(MVT::v64i1,  &X86::VK64RegClass);
 
-    setOperationAction(ISD::ADD,                MVT::v32i1, Expand);
-    setOperationAction(ISD::ADD,                MVT::v64i1, Expand);
-    setOperationAction(ISD::SUB,                MVT::v32i1, Expand);
-    setOperationAction(ISD::SUB,                MVT::v64i1, Expand);
-    setOperationAction(ISD::MUL,                MVT::v32i1, Expand);
-    setOperationAction(ISD::MUL,                MVT::v64i1, Expand);
+    setOperationAction(ISD::ADD,                MVT::v32i1, Custom);
+    setOperationAction(ISD::ADD,                MVT::v64i1, Custom);
+    setOperationAction(ISD::SUB,                MVT::v32i1, Custom);
+    setOperationAction(ISD::SUB,                MVT::v64i1, Custom);
+    setOperationAction(ISD::MUL,                MVT::v32i1, Custom);
+    setOperationAction(ISD::MUL,                MVT::v64i1, Custom);
 
     setOperationAction(ISD::SETCC,              MVT::v32i1, Custom);
     setOperationAction(ISD::SETCC,              MVT::v64i1, Custom);
@@ -1574,9 +1574,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     addRegisterClass(MVT::v2i1,   &X86::VK2RegClass);
 
     for (auto VT : { MVT::v2i1, MVT::v4i1 }) {
-      setOperationAction(ISD::ADD,                VT, Expand);
-      setOperationAction(ISD::SUB,                VT, Expand);
-      setOperationAction(ISD::MUL,                VT, Expand);
+      setOperationAction(ISD::ADD,                VT, Custom);
+      setOperationAction(ISD::SUB,                VT, Custom);
+      setOperationAction(ISD::MUL,                VT, Custom);
       setOperationAction(ISD::VSELECT,            VT, Expand);
 
       setOperationAction(ISD::TRUNCATE,           VT, Custom);
@@ -20847,8 +20847,9 @@ static SDValue Lower512IntArith(SDValue Op, SelectionDAG &DAG) {
 }
 
 static SDValue LowerADD_SUB(SDValue Op, SelectionDAG &DAG) {
-  if (Op.getValueType() == MVT::i1)
-    return DAG.getNode(ISD::XOR, SDLoc(Op), Op.getValueType(),
+  MVT VT = Op.getSimpleValueType();
+  if (VT.getScalarType() == MVT::i1)
+    return DAG.getNode(ISD::XOR, SDLoc(Op), VT,
                        Op.getOperand(0), Op.getOperand(1));
   assert(Op.getSimpleValueType().is256BitVector() &&
          Op.getSimpleValueType().isInteger() &&
@@ -20868,7 +20869,7 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget &Subtarget,
   SDLoc dl(Op);
   MVT VT = Op.getSimpleValueType();
 
-  if (VT == MVT::i1)
+  if (VT.getScalarType() == MVT::i1)
     return DAG.getNode(ISD::AND, dl, VT, Op.getOperand(0), Op.getOperand(1));
 
   // Decompose 256-bit ops into smaller 128-bit ops.
index 89bd1980e5231dfff057f5c40f3f4a9b33809caa..41fb19f38e09a727386747bba5956cd2aa3e3bdb 100644 (file)
@@ -1986,3 +1986,117 @@ define i32 @test_bitcast_v16i1_zext(<16 x i32> %a) {
    %val1 = add i32 %val, %val
    ret i32 %val1
 }
+
+define i16 @test_v16i1_add(i16 %x, i16 %y) {
+; CHECK-LABEL: test_v16i1_add:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovw %edi, %k0
+; CHECK-NEXT:    kmovw %esi, %k1
+; CHECK-NEXT:    kxorw %k1, %k0, %k0
+; CHECK-NEXT:    kmovw %k0, %eax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i16 %x to <16 x i1>
+  %m1 = bitcast i16 %y to <16 x i1>
+  %m2 = add <16 x i1> %m0,  %m1
+  %ret = bitcast <16 x i1> %m2 to i16
+  ret i16 %ret
+}
+
+define i16 @test_v16i1_sub(i16 %x, i16 %y) {
+; CHECK-LABEL: test_v16i1_sub:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovw %edi, %k0
+; CHECK-NEXT:    kmovw %esi, %k1
+; CHECK-NEXT:    kxorw %k1, %k0, %k0
+; CHECK-NEXT:    kmovw %k0, %eax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i16 %x to <16 x i1>
+  %m1 = bitcast i16 %y to <16 x i1>
+  %m2 = sub <16 x i1> %m0,  %m1
+  %ret = bitcast <16 x i1> %m2 to i16
+  ret i16 %ret
+}
+
+define i16 @test_v16i1_mul(i16 %x, i16 %y) {
+; CHECK-LABEL: test_v16i1_mul:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovw %edi, %k0
+; CHECK-NEXT:    kmovw %esi, %k1
+; CHECK-NEXT:    kandw %k1, %k0, %k0
+; CHECK-NEXT:    kmovw %k0, %eax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i16 %x to <16 x i1>
+  %m1 = bitcast i16 %y to <16 x i1>
+  %m2 = mul <16 x i1> %m0,  %m1
+  %ret = bitcast <16 x i1> %m2 to i16
+  ret i16 %ret
+}
+
+define i8 @test_v8i1_add(i8 %x, i8 %y) {
+; KNL-LABEL: test_v8i1_add:
+; KNL:       ## BB#0:
+; KNL-NEXT:    kmovw %edi, %k0
+; KNL-NEXT:    kmovw %esi, %k1
+; KNL-NEXT:    kxorw %k1, %k0, %k0
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: test_v8i1_add:
+; SKX:       ## BB#0:
+; SKX-NEXT:    kmovb %edi, %k0
+; SKX-NEXT:    kmovb %esi, %k1
+; SKX-NEXT:    kxorb %k1, %k0, %k0
+; SKX-NEXT:    kmovb %k0, %eax
+; SKX-NEXT:    retq
+  %m0 = bitcast i8 %x to <8 x i1>
+  %m1 = bitcast i8 %y to <8 x i1>
+  %m2 = add <8 x i1> %m0,  %m1
+  %ret = bitcast <8 x i1> %m2 to i8
+  ret i8 %ret
+}
+
+define i8 @test_v8i1_sub(i8 %x, i8 %y) {
+; KNL-LABEL: test_v8i1_sub:
+; KNL:       ## BB#0:
+; KNL-NEXT:    kmovw %edi, %k0
+; KNL-NEXT:    kmovw %esi, %k1
+; KNL-NEXT:    kxorw %k1, %k0, %k0
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: test_v8i1_sub:
+; SKX:       ## BB#0:
+; SKX-NEXT:    kmovb %edi, %k0
+; SKX-NEXT:    kmovb %esi, %k1
+; SKX-NEXT:    kxorb %k1, %k0, %k0
+; SKX-NEXT:    kmovb %k0, %eax
+; SKX-NEXT:    retq
+  %m0 = bitcast i8 %x to <8 x i1>
+  %m1 = bitcast i8 %y to <8 x i1>
+  %m2 = sub <8 x i1> %m0,  %m1
+  %ret = bitcast <8 x i1> %m2 to i8
+  ret i8 %ret
+}
+
+define i8 @test_v8i1_mul(i8 %x, i8 %y) {
+; KNL-LABEL: test_v8i1_mul:
+; KNL:       ## BB#0:
+; KNL-NEXT:    kmovw %edi, %k0
+; KNL-NEXT:    kmovw %esi, %k1
+; KNL-NEXT:    kandw %k1, %k0, %k0
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: test_v8i1_mul:
+; SKX:       ## BB#0:
+; SKX-NEXT:    kmovb %edi, %k0
+; SKX-NEXT:    kmovb %esi, %k1
+; SKX-NEXT:    kandb %k1, %k0, %k0
+; SKX-NEXT:    kmovb %k0, %eax
+; SKX-NEXT:    retq
+  %m0 = bitcast i8 %x to <8 x i1>
+  %m1 = bitcast i8 %y to <8 x i1>
+  %m2 = mul <8 x i1> %m0,  %m1
+  %ret = bitcast <8 x i1> %m2 to i8
+  ret i8 %ret
+}
index 619c42494e2d057c00b532f189b69abca5dbcb59..e000ef4068f64febf668e301722beac45af16522 100644 (file)
@@ -150,3 +150,93 @@ define i64 @mand64_mem(<64 x i1>* %x, <64 x i1>* %y) {
   %ret = bitcast <64 x i1> %me to i64
   ret i64 %ret
 }
+
+define i32 @test_v32i1_add(i32 %x, i32 %y) {
+; CHECK-LABEL: test_v32i1_add:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovd %edi, %k0
+; CHECK-NEXT:    kmovd %esi, %k1
+; CHECK-NEXT:    kxord %k1, %k0, %k0
+; CHECK-NEXT:    kmovd %k0, %eax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i32 %x to <32 x i1>
+  %m1 = bitcast i32 %y to <32 x i1>
+  %m2 = add <32 x i1> %m0,  %m1
+  %ret = bitcast <32 x i1> %m2 to i32
+  ret i32 %ret
+}
+
+define i32 @test_v32i1_sub(i32 %x, i32 %y) {
+; CHECK-LABEL: test_v32i1_sub:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovd %edi, %k0
+; CHECK-NEXT:    kmovd %esi, %k1
+; CHECK-NEXT:    kxord %k1, %k0, %k0
+; CHECK-NEXT:    kmovd %k0, %eax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i32 %x to <32 x i1>
+  %m1 = bitcast i32 %y to <32 x i1>
+  %m2 = sub <32 x i1> %m0,  %m1
+  %ret = bitcast <32 x i1> %m2 to i32
+  ret i32 %ret
+}
+
+define i32 @test_v32i1_mul(i32 %x, i32 %y) {
+; CHECK-LABEL: test_v32i1_mul:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovd %edi, %k0
+; CHECK-NEXT:    kmovd %esi, %k1
+; CHECK-NEXT:    kandd %k1, %k0, %k0
+; CHECK-NEXT:    kmovd %k0, %eax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i32 %x to <32 x i1>
+  %m1 = bitcast i32 %y to <32 x i1>
+  %m2 = mul <32 x i1> %m0,  %m1
+  %ret = bitcast <32 x i1> %m2 to i32
+  ret i32 %ret
+}
+
+define i64 @test_v64i1_add(i64 %x, i64 %y) {
+; CHECK-LABEL: test_v64i1_add:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovq %rdi, %k0
+; CHECK-NEXT:    kmovq %rsi, %k1
+; CHECK-NEXT:    kxorq %k1, %k0, %k0
+; CHECK-NEXT:    kmovq %k0, %rax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i64 %x to <64 x i1>
+  %m1 = bitcast i64 %y to <64 x i1>
+  %m2 = add <64 x i1> %m0,  %m1
+  %ret = bitcast <64 x i1> %m2 to i64
+  ret i64 %ret
+}
+
+define i64 @test_v64i1_sub(i64 %x, i64 %y) {
+; CHECK-LABEL: test_v64i1_sub:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovq %rdi, %k0
+; CHECK-NEXT:    kmovq %rsi, %k1
+; CHECK-NEXT:    kxorq %k1, %k0, %k0
+; CHECK-NEXT:    kmovq %k0, %rax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i64 %x to <64 x i1>
+  %m1 = bitcast i64 %y to <64 x i1>
+  %m2 = sub <64 x i1> %m0,  %m1
+  %ret = bitcast <64 x i1> %m2 to i64
+  ret i64 %ret
+}
+
+define i64 @test_v64i1_mul(i64 %x, i64 %y) {
+; CHECK-LABEL: test_v64i1_mul:
+; CHECK:       ## BB#0:
+; CHECK-NEXT:    kmovq %rdi, %k0
+; CHECK-NEXT:    kmovq %rsi, %k1
+; CHECK-NEXT:    kandq %k1, %k0, %k0
+; CHECK-NEXT:    kmovq %k0, %rax
+; CHECK-NEXT:    retq
+  %m0 = bitcast i64 %x to <64 x i1>
+  %m1 = bitcast i64 %y to <64 x i1>
+  %m2 = mul <64 x i1> %m0,  %m1
+  %ret = bitcast <64 x i1> %m2 to i64
+  ret i64 %ret
+}