]> granicus.if.org Git - llvm/commitdiff
[DAG] isBitwiseNot / isConstOrConstSplat - add support for build vector undefs +...
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 2 Jun 2019 11:56:39 +0000 (11:56 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 2 Jun 2019 11:56:39 +0000 (11:56 +0000)
Add (opt-in) support for implicit truncation to isConstOrConstSplat, which allows us to match truncated 'all ones' cases in isBitwiseNot.

PR41020 compares against using ISD::isBuildVectorAllOnes() instead, but that predicate silently accepts any UNDEF elements in the build vector which might not be what we want in isBitwiseNot - so I've added an opt-in 'AllowUndefs' flag that is set to false by default but will allow us to enable it on individual cases where its safe.

Differential Revision: https://reviews.llvm.org/D62783

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@362323 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/SelectionDAGNodes.h
lib/CodeGen/SelectionDAG/SelectionDAG.cpp
test/CodeGen/AArch64/sat-add.ll

index 4e4c0e57d63259590f78a04b69542156518e31e3..370c3a438d12b017a06367cbdd3e0dde05e63769 100644 (file)
@@ -1649,15 +1649,17 @@ SDValue peekThroughExtractSubvectors(SDValue V);
 
 /// Returns true if \p V is a bitwise not operation. Assumes that an all ones
 /// constant is canonicalized to be operand 1.
-bool isBitwiseNot(SDValue V);
+bool isBitwiseNot(SDValue V, bool AllowUndefs = false);
 
 /// Returns the SDNode if it is a constant splat BuildVector or constant int.
-ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false);
+ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false,
+                                    bool AllowTruncation = false);
 
 /// Returns the SDNode if it is a demanded constant splat BuildVector or
 /// constant int.
 ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
-                                    bool AllowUndefs = false);
+                                    bool AllowUndefs = false,
+                                    bool AllowTruncation = false);
 
 /// Returns the SDNode if it is a constant splat BuildVector or constant float.
 ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false);
index 7cb7e17d55a2b42b503d3f93b880f6786bc37219..d6d8cf54cb01d61a8928379d3a8769312f5f4836 100644 (file)
@@ -8640,14 +8640,18 @@ SDValue llvm::peekThroughExtractSubvectors(SDValue V) {
   return V;
 }
 
-bool llvm::isBitwiseNot(SDValue V) {
+bool llvm::isBitwiseNot(SDValue V, bool AllowUndefs) {
   if (V.getOpcode() != ISD::XOR)
     return false;
-  ConstantSDNode *C = isConstOrConstSplat(peekThroughBitcasts(V.getOperand(1)));
-  return C && C->isAllOnesValue();
+  V = peekThroughBitcasts(V.getOperand(1));
+  unsigned NumBits = V.getScalarValueSizeInBits();
+  ConstantSDNode *C =
+      isConstOrConstSplat(V, AllowUndefs, /*AllowTruncation*/ true);
+  return C && (C->getAPIntValue().countTrailingOnes() >= NumBits);
 }
 
-ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) {
+ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs,
+                                          bool AllowTruncation) {
   if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
     return CN;
 
@@ -8655,17 +8659,23 @@ ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) {
     BitVector UndefElements;
     ConstantSDNode *CN = BV->getConstantSplatNode(&UndefElements);
 
-    // BuildVectors can truncate their operands. Ignore that case here.
-    if (CN && (UndefElements.none() || AllowUndefs) &&
-        CN->getValueType(0) == N.getValueType().getScalarType())
-      return CN;
+    // BuildVectors can truncate their operands. Ignore that case here unless
+    // AllowTruncation is set.
+    if (CN && (UndefElements.none() || AllowUndefs)) {
+      EVT CVT = CN->getValueType(0);
+      EVT NSVT = N.getValueType().getScalarType();
+      assert(CVT.bitsGE(NSVT) && "Illegal build vector element extension");
+      if (AllowTruncation || (CVT == NSVT))
+        return CN;
+    }
   }
 
   return nullptr;
 }
 
 ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
-                                          bool AllowUndefs) {
+                                          bool AllowUndefs,
+                                          bool AllowTruncation) {
   if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
     return CN;
 
@@ -8673,10 +8683,15 @@ ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
     BitVector UndefElements;
     ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements);
 
-    // BuildVectors can truncate their operands. Ignore that case here.
-    if (CN && (UndefElements.none() || AllowUndefs) &&
-        CN->getValueType(0) == N.getValueType().getScalarType())
-      return CN;
+    // BuildVectors can truncate their operands. Ignore that case here unless
+    // AllowTruncation is set.
+    if (CN && (UndefElements.none() || AllowUndefs)) {
+      EVT CVT = CN->getValueType(0);
+      EVT NSVT = N.getValueType().getScalarType();
+      assert(CVT.bitsGE(NSVT) && "Illegal build vector element extension");
+      if (AllowTruncation || (CVT == NSVT))
+        return CN;
+    }
   }
 
   return nullptr;
index 36e63f3594b4a0c3f50887d4a405af0d0ae0cf4f..8e54d916627758261a7660bf28459998ec2e0e67 100644 (file)
@@ -364,8 +364,7 @@ define <16 x i8> @unsigned_sat_constant_v16i8_using_cmp_sum(<16 x i8> %x) {
 ; CHECK-NEXT:    movi v1.16b, #42
 ; CHECK-NEXT:    add v1.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    cmhi v0.16b, v0.16b, v1.16b
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %a = add <16 x i8> %x, <i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42>
   %c = icmp ugt <16 x i8> %x, %a
@@ -380,8 +379,7 @@ define <16 x i8> @unsigned_sat_constant_v16i8_using_cmp_notval(<16 x i8> %x) {
 ; CHECK-NEXT:    movi v2.16b, #213
 ; CHECK-NEXT:    add v1.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    cmhi v0.16b, v0.16b, v2.16b
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %a = add <16 x i8> %x, <i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42, i8 42>
   %c = icmp ugt <16 x i8> %x, <i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43, i8 -43>
@@ -409,8 +407,7 @@ define <8 x i16> @unsigned_sat_constant_v8i16_using_cmp_sum(<8 x i16> %x) {
 ; CHECK-NEXT:    movi v1.8h, #42
 ; CHECK-NEXT:    add v1.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    cmhi v0.8h, v0.8h, v1.8h
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %a = add <8 x i16> %x, <i16 42, i16 42, i16 42, i16 42, i16 42, i16 42, i16 42, i16 42>
   %c = icmp ugt <8 x i16> %x, %a
@@ -425,8 +422,7 @@ define <8 x i16> @unsigned_sat_constant_v8i16_using_cmp_notval(<8 x i16> %x) {
 ; CHECK-NEXT:    mvni v2.8h, #42
 ; CHECK-NEXT:    add v1.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    cmhi v0.8h, v0.8h, v2.8h
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %a = add <8 x i16> %x, <i16 42, i16 42, i16 42, i16 42, i16 42, i16 42, i16 42, i16 42>
   %c = icmp ugt <8 x i16> %x, <i16 -43, i16 -43, i16 -43, i16 -43, i16 -43, i16 -43, i16 -43, i16 -43>
@@ -545,8 +541,7 @@ define <16 x i8> @unsigned_sat_variable_v16i8_using_cmp_sum(<16 x i8> %x, <16 x
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    add v1.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    cmhi v0.16b, v0.16b, v1.16b
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %a = add <16 x i8> %x, %y
   %c = icmp ugt <16 x i8> %x, %a
@@ -560,8 +555,7 @@ define <16 x i8> @unsigned_sat_variable_v16i8_using_cmp_notval(<16 x i8> %x, <16
 ; CHECK-NEXT:    mvn v2.16b, v1.16b
 ; CHECK-NEXT:    add v1.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    cmhi v0.16b, v0.16b, v2.16b
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %noty = xor <16 x i8> %y, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %a = add <16 x i8> %x, %y
@@ -589,8 +583,7 @@ define <8 x i16> @unsigned_sat_variable_v8i16_using_cmp_sum(<8 x i16> %x, <8 x i
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    add v1.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    cmhi v0.8h, v0.8h, v1.8h
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %a = add <8 x i16> %x, %y
   %c = icmp ugt <8 x i16> %x, %a
@@ -604,8 +597,7 @@ define <8 x i16> @unsigned_sat_variable_v8i16_using_cmp_notval(<8 x i16> %x, <8
 ; CHECK-NEXT:    mvn v2.16b, v1.16b
 ; CHECK-NEXT:    add v1.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    cmhi v0.8h, v0.8h, v2.8h
-; CHECK-NEXT:    bic v1.16b, v1.16b, v0.16b
-; CHECK-NEXT:    orr v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    orr v0.16b, v1.16b, v0.16b
 ; CHECK-NEXT:    ret
   %noty = xor <8 x i16> %y, <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>
   %a = add <8 x i16> %x, %y