/// Returns the SDNode if it is a constant splat BuildVector or constant int.
ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false);
+/// Returns the SDNode if it is a demanded constant splat BuildVector or
+/// constant int.
+ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
+ bool AllowUndefs = false);
+
/// Returns the SDNode if it is a constant splat BuildVector or constant float.
ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false);
+/// Returns the SDNode if it is a demanded constant splat BuildVector or
+/// constant float.
+ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, const APInt &DemandedElts,
+ bool AllowUndefs = false);
+
/// Return true if the value is a constant 0 integer or a splatted vector of
/// a constant 0 integer (with no undefs by default).
/// Build vector implicit truncation is not an issue for null values.
unsigned MinSplatBits = 0,
bool isBigEndian = false) const;
+ /// Returns the demanded splatted value or a null value if this is not a
+ /// splat.
+ ///
+ /// The DemandedElts mask indicates the elements that must be in the splat.
+ /// If passed a non-null UndefElements bitvector, it will resize it to match
+ /// the vector width and set the bits where elements are undef.
+ SDValue getSplatValue(const APInt &DemandedElts,
+ BitVector *UndefElements = nullptr) const;
+
/// Returns the splatted value or a null value if this is not a splat.
///
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where elements are undef.
SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
+ /// Returns the demanded splatted constant or null if this is not a constant
+ /// splat.
+ ///
+ /// The DemandedElts mask indicates the elements that must be in the splat.
+ /// If passed a non-null UndefElements bitvector, it will resize it to match
+ /// the vector width and set the bits where elements are undef.
+ ConstantSDNode *
+ getConstantSplatNode(const APInt &DemandedElts,
+ BitVector *UndefElements = nullptr) const;
+
/// Returns the splatted constant or null if this is not a constant
/// splat.
///
ConstantSDNode *
getConstantSplatNode(BitVector *UndefElements = nullptr) const;
+ /// Returns the demanded splatted constant FP or null if this is not a
+ /// constant FP splat.
+ ///
+ /// The DemandedElts mask indicates the elements that must be in the splat.
+ /// If passed a non-null UndefElements bitvector, it will resize it to match
+ /// the vector width and set the bits where elements are undef.
+ ConstantFPSDNode *
+ getConstantFPSplatNode(const APInt &DemandedElts,
+ BitVector *UndefElements = nullptr) const;
+
/// Returns the splatted constant FP or null if this is not a constant
/// FP splat.
///
(AllowUndefs || !UndefElts);
}
-/// Helper function that checks to see if a node is a constant or a
-/// build vector of splat constants at least within the demanded elts.
-static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N,
- const APInt &DemandedElts) {
- if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
- return CN;
- if (N.getOpcode() != ISD::BUILD_VECTOR)
- return nullptr;
- EVT VT = N.getValueType();
- ConstantSDNode *Cst = nullptr;
- unsigned NumElts = VT.getVectorNumElements();
- assert(DemandedElts.getBitWidth() == NumElts && "Unexpected vector size");
- for (unsigned i = 0; i != NumElts; ++i) {
- if (!DemandedElts[i])
- continue;
- ConstantSDNode *C = dyn_cast<ConstantSDNode>(N.getOperand(i));
- if (!C || (Cst && Cst->getAPIntValue() != C->getAPIntValue()) ||
- C->getValueType(0) != VT.getScalarType())
- return nullptr;
- Cst = C;
- }
- return Cst;
-}
-
/// If a SHL/SRA/SRL node has a constant or splat constant shift amount that
/// is less than the element bit-width of the shift node, return it.
static const APInt *getValidShiftAmountConstant(SDValue V) {
break;
case ISD::FSHL:
case ISD::FSHR:
- if (ConstantSDNode *C =
- isConstOrDemandedConstSplat(Op.getOperand(2), DemandedElts)) {
+ if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) {
unsigned Amt = C->getAPIntValue().urem(BitWidth);
// For fshl, 0-shift returns the 1st arg.
// the minimum of the clamp min/max range.
bool IsMax = (Opcode == ISD::SMAX);
ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
- if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
+ if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
- CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
- DemandedElts);
+ CstHigh =
+ isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
if (CstLow && CstHigh) {
if (!IsMax)
std::swap(CstLow, CstHigh);
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
// SRA X, C -> adds C sign bits.
if (ConstantSDNode *C =
- isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
+ isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
APInt ShiftVal = C->getAPIntValue();
ShiftVal += Tmp;
Tmp = ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue();
return Tmp;
case ISD::SHL:
if (ConstantSDNode *C =
- isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
+ isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
// shl destroys sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
if (C->getAPIntValue().uge(VTBits) || // Bad shift.
// the minimum of the clamp min/max range.
bool IsMax = (Opcode == ISD::SMAX);
ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
- if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
+ if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
- CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
- DemandedElts);
+ CstHigh =
+ isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
if (CstLow && CstHigh) {
if (!IsMax)
std::swap(CstLow, CstHigh);
return nullptr;
}
+ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
+ bool AllowUndefs) {
+ if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(N))
+ return CN;
+
+ if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
+ BitVector UndefElements;
+ ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements);
+
+ // BuildVectors can truncate their operands. Ignore that case here.
+ if (CN && (UndefElements.none() || AllowUndefs) &&
+ CN->getValueType(0) == N.getValueType().getScalarType())
+ return CN;
+ }
+
+ return nullptr;
+}
+
ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) {
if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
return CN;
return nullptr;
}
+ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N,
+ const APInt &DemandedElts,
+ bool AllowUndefs) {
+ if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
+ return CN;
+
+ if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
+ BitVector UndefElements;
+ ConstantFPSDNode *CN =
+ BV->getConstantFPSplatNode(DemandedElts, &UndefElements);
+ if (CN && (UndefElements.none() || AllowUndefs))
+ return CN;
+ }
+
+ return nullptr;
+}
+
bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) {
// TODO: may want to use peekThroughBitcast() here.
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
return true;
}
-SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
+SDValue BuildVectorSDNode::getSplatValue(const APInt &DemandedElts,
+ BitVector *UndefElements) const {
if (UndefElements) {
UndefElements->clear();
UndefElements->resize(getNumOperands());
}
+ assert(getNumOperands() == DemandedElts.getBitWidth() &&
+ "Unexpected vector size");
+ if (!DemandedElts)
+ return SDValue();
SDValue Splatted;
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
SDValue Op = getOperand(i);
if (Op.isUndef()) {
if (UndefElements)
}
if (!Splatted) {
- assert(getOperand(0).isUndef() &&
+ unsigned FirstDemandedIdx = DemandedElts.countTrailingZeros();
+ assert(getOperand(FirstDemandedIdx).isUndef() &&
"Can only have a splat without a constant for all undefs.");
- return getOperand(0);
+ return getOperand(FirstDemandedIdx);
}
return Splatted;
}
+SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
+ APInt DemandedElts = APInt::getAllOnesValue(getNumOperands());
+ return getSplatValue(DemandedElts, UndefElements);
+}
+
+ConstantSDNode *
+BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
+ BitVector *UndefElements) const {
+ return dyn_cast_or_null<ConstantSDNode>(
+ getSplatValue(DemandedElts, UndefElements));
+}
+
ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(BitVector *UndefElements) const {
return dyn_cast_or_null<ConstantSDNode>(getSplatValue(UndefElements));
}
+ConstantFPSDNode *
+BuildVectorSDNode::getConstantFPSplatNode(const APInt &DemandedElts,
+ BitVector *UndefElements) const {
+ return dyn_cast_or_null<ConstantFPSDNode>(
+ getSplatValue(DemandedElts, UndefElements));
+}
+
ConstantFPSDNode *
BuildVectorSDNode::getConstantFPSplatNode(BitVector *UndefElements) const {
return dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements));