From: Simon Pilgrim Date: Sun, 2 Jun 2019 11:56:39 +0000 (+0000) Subject: [DAG] isBitwiseNot / isConstOrConstSplat - add support for build vector undefs +... X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=40263ea2de82abd70c41f5821448797f498d4d75;p=llvm [DAG] isBitwiseNot / isConstOrConstSplat - add support for build vector undefs + truncation (PR41020) Add (opt-in) support for implicit truncation to isConstOrConstSplat, which allows us to match truncated 'all ones' cases in isBitwiseNot. PR41020 compares against using ISD::isBuildVectorAllOnes() instead, but that predicate silently accepts any UNDEF elements in the build vector which might not be what we want in isBitwiseNot - so I've added an opt-in 'AllowUndefs' flag that is set to false by default but will allow us to enable it on individual cases where its safe. Differential Revision: https://reviews.llvm.org/D62783 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@362323 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/CodeGen/SelectionDAGNodes.h b/include/llvm/CodeGen/SelectionDAGNodes.h index 4e4c0e57d63..370c3a438d1 100644 --- a/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1649,15 +1649,17 @@ SDValue peekThroughExtractSubvectors(SDValue V); /// Returns true if \p V is a bitwise not operation. Assumes that an all ones /// constant is canonicalized to be operand 1. -bool isBitwiseNot(SDValue V); +bool isBitwiseNot(SDValue V, bool AllowUndefs = false); /// Returns the SDNode if it is a constant splat BuildVector or constant int. -ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false); +ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false, + bool AllowTruncation = false); /// Returns the SDNode if it is a demanded constant splat BuildVector or /// constant int. ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts, - bool AllowUndefs = false); + bool AllowUndefs = false, + bool AllowTruncation = false); /// Returns the SDNode if it is a constant splat BuildVector or constant float. ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false); diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 7cb7e17d55a..d6d8cf54cb0 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8640,14 +8640,18 @@ SDValue llvm::peekThroughExtractSubvectors(SDValue V) { return V; } -bool llvm::isBitwiseNot(SDValue V) { +bool llvm::isBitwiseNot(SDValue V, bool AllowUndefs) { if (V.getOpcode() != ISD::XOR) return false; - ConstantSDNode *C = isConstOrConstSplat(peekThroughBitcasts(V.getOperand(1))); - return C && C->isAllOnesValue(); + V = peekThroughBitcasts(V.getOperand(1)); + unsigned NumBits = V.getScalarValueSizeInBits(); + ConstantSDNode *C = + isConstOrConstSplat(V, AllowUndefs, /*AllowTruncation*/ true); + return C && (C->getAPIntValue().countTrailingOnes() >= NumBits); } -ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) { +ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs, + bool AllowTruncation) { if (ConstantSDNode *CN = dyn_cast(N)) return CN; @@ -8655,17 +8659,23 @@ ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) { BitVector UndefElements; ConstantSDNode *CN = BV->getConstantSplatNode(&UndefElements); - // BuildVectors can truncate their operands. Ignore that case here. - if (CN && (UndefElements.none() || AllowUndefs) && - CN->getValueType(0) == N.getValueType().getScalarType()) - return CN; + // BuildVectors can truncate their operands. Ignore that case here unless + // AllowTruncation is set. + if (CN && (UndefElements.none() || AllowUndefs)) { + EVT CVT = CN->getValueType(0); + EVT NSVT = N.getValueType().getScalarType(); + assert(CVT.bitsGE(NSVT) && "Illegal build vector element extension"); + if (AllowTruncation || (CVT == NSVT)) + return CN; + } } return nullptr; } ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts, - bool AllowUndefs) { + bool AllowUndefs, + bool AllowTruncation) { if (ConstantSDNode *CN = dyn_cast(N)) return CN; @@ -8673,10 +8683,15 @@ ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts, BitVector UndefElements; ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements); - // BuildVectors can truncate their operands. Ignore that case here. - if (CN && (UndefElements.none() || AllowUndefs) && - CN->getValueType(0) == N.getValueType().getScalarType()) - return CN; + // BuildVectors can truncate their operands. Ignore that case here unless + // AllowTruncation is set. + if (CN && (UndefElements.none() || AllowUndefs)) { + EVT CVT = CN->getValueType(0); + EVT NSVT = N.getValueType().getScalarType(); + assert(CVT.bitsGE(NSVT) && "Illegal build vector element extension"); + if (AllowTruncation || (CVT == NSVT)) + return CN; + } } return nullptr; diff --git a/test/CodeGen/AArch64/sat-add.ll b/test/CodeGen/AArch64/sat-add.ll index 36e63f3594b..8e54d916627 100644 --- a/test/CodeGen/AArch64/sat-add.ll +++ b/test/CodeGen/AArch64/sat-add.ll @@ -364,8 +364,7 @@ define <16 x i8> @unsigned_sat_constant_v16i8_using_cmp_sum(<16 x i8> %x) { ; CHECK-NEXT: movi v1.16b, #42 ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v1.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <16 x i8> %x, %c = icmp ugt <16 x i8> %x, %a @@ -380,8 +379,7 @@ define <16 x i8> @unsigned_sat_constant_v16i8_using_cmp_notval(<16 x i8> %x) { ; CHECK-NEXT: movi v2.16b, #213 ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v2.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <16 x i8> %x, %c = icmp ugt <16 x i8> %x, @@ -409,8 +407,7 @@ define <8 x i16> @unsigned_sat_constant_v8i16_using_cmp_sum(<8 x i16> %x) { ; CHECK-NEXT: movi v1.8h, #42 ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v1.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <8 x i16> %x, %c = icmp ugt <8 x i16> %x, %a @@ -425,8 +422,7 @@ define <8 x i16> @unsigned_sat_constant_v8i16_using_cmp_notval(<8 x i16> %x) { ; CHECK-NEXT: mvni v2.8h, #42 ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v2.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <8 x i16> %x, %c = icmp ugt <8 x i16> %x, @@ -545,8 +541,7 @@ define <16 x i8> @unsigned_sat_variable_v16i8_using_cmp_sum(<16 x i8> %x, <16 x ; CHECK: // %bb.0: ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v1.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <16 x i8> %x, %y %c = icmp ugt <16 x i8> %x, %a @@ -560,8 +555,7 @@ define <16 x i8> @unsigned_sat_variable_v16i8_using_cmp_notval(<16 x i8> %x, <16 ; CHECK-NEXT: mvn v2.16b, v1.16b ; CHECK-NEXT: add v1.16b, v0.16b, v1.16b ; CHECK-NEXT: cmhi v0.16b, v0.16b, v2.16b -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %noty = xor <16 x i8> %y, %a = add <16 x i8> %x, %y @@ -589,8 +583,7 @@ define <8 x i16> @unsigned_sat_variable_v8i16_using_cmp_sum(<8 x i16> %x, <8 x i ; CHECK: // %bb.0: ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v1.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %a = add <8 x i16> %x, %y %c = icmp ugt <8 x i16> %x, %a @@ -604,8 +597,7 @@ define <8 x i16> @unsigned_sat_variable_v8i16_using_cmp_notval(<8 x i16> %x, <8 ; CHECK-NEXT: mvn v2.16b, v1.16b ; CHECK-NEXT: add v1.8h, v0.8h, v1.8h ; CHECK-NEXT: cmhi v0.8h, v0.8h, v2.8h -; CHECK-NEXT: bic v1.16b, v1.16b, v0.16b -; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b +; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %noty = xor <8 x i16> %y, %a = add <8 x i16> %x, %y