#include "llvm/Target/TargetSubtargetInfo.h"
#include <algorithm>
#include <cmath>
+#include <utility>
using namespace llvm;
return SDValue(N, 0);
}
+/// Constant-fold the two-operand integer operation \p Opcode over the
+/// constants \p C1 (left operand) and \p C2 (right operand).
+///
+/// \returns a pair whose second member says whether folding succeeded. On
+/// success the first member holds the folded value. On failure (an opcode not
+/// listed below, or a division/remainder whose divisor \p C2 is zero) the
+/// first member is the placeholder APInt(1, 0) and must be ignored.
+static std::pair<APInt, bool> FoldValue(unsigned Opcode, const APInt &C1,
+                                        const APInt &C2) {
+  // Division and remainder by zero are never folded; reject them once here
+  // so the switch below only has to produce values.
+  const bool DividesByC2 = Opcode == ISD::UDIV || Opcode == ISD::UREM ||
+                           Opcode == ISD::SDIV || Opcode == ISD::SREM;
+  if (DividesByC2 && !C2.getBoolValue())
+    return std::make_pair(APInt(1, 0), false);
+
+  switch (Opcode) {
+  // Arithmetic.
+  case ISD::ADD:
+    return std::make_pair(C1 + C2, true);
+  case ISD::SUB:
+    return std::make_pair(C1 - C2, true);
+  case ISD::MUL:
+    return std::make_pair(C1 * C2, true);
+  case ISD::UDIV:
+    return std::make_pair(C1.udiv(C2), true);
+  case ISD::UREM:
+    return std::make_pair(C1.urem(C2), true);
+  case ISD::SDIV:
+    return std::make_pair(C1.sdiv(C2), true);
+  case ISD::SREM:
+    return std::make_pair(C1.srem(C2), true);
+  // Bitwise logic.
+  case ISD::AND:
+    return std::make_pair(C1 & C2, true);
+  case ISD::OR:
+    return std::make_pair(C1 | C2, true);
+  case ISD::XOR:
+    return std::make_pair(C1 ^ C2, true);
+  // Shifts and rotates.
+  case ISD::SHL:
+    return std::make_pair(C1 << C2, true);
+  case ISD::SRL:
+    return std::make_pair(C1.lshr(C2), true);
+  case ISD::SRA:
+    return std::make_pair(C1.ashr(C2), true);
+  case ISD::ROTL:
+    return std::make_pair(C1.rotl(C2), true);
+  case ISD::ROTR:
+    return std::make_pair(C1.rotr(C2), true);
+  default:
+    break;
+  }
+
+  // Any other opcode is not folded by this helper.
+  return std::make_pair(APInt(1, 0), false);
+}
+
+/// Fold \p Opcode applied to the two scalar constant nodes \p Cst1 and
+/// \p Cst2 into a single constant of type \p VT.
+///
+/// \returns the folded constant node, or an empty SDValue when either operand
+/// is opaque or FoldValue reports that it could not fold the operation.
+SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT,
+                                             const ConstantSDNode *Cst1,
+                                             const ConstantSDNode *Cst2) {
+  // Only attempt the fold when neither operand is marked opaque.
+  const bool AnyOpaque = Cst1->isOpaque() || Cst2->isOpaque();
+  if (!AnyOpaque) {
+    const APInt &LHS = Cst1->getAPIntValue();
+    const APInt &RHS = Cst2->getAPIntValue();
+
+    std::pair<APInt, bool> Result = FoldValue(Opcode, LHS, RHS);
+    if (Result.second)
+      return getConstant(Result.first, DL, VT);
+  }
+
+  // Nothing was folded.
+  return SDValue();
+}
+
SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT,
SDNode *Cst1, SDNode *Cst2) {
// If the opcode is a target-specific ISD node, there's nothing we can fold; bail.
if (Opcode >= ISD::BUILTIN_OP_END)
return SDValue();
- SmallVector<std::pair<ConstantSDNode *, ConstantSDNode *>, 4> Inputs;
- SmallVector<SDValue, 4> Outputs;
- EVT SVT = VT.getScalarType();
+ // Handle the case of two scalars.
+ if (const ConstantSDNode *Scalar1 = dyn_cast<ConstantSDNode>(Cst1)) {
+ if (const ConstantSDNode *Scalar2 = dyn_cast<ConstantSDNode>(Cst2)) {
+ if (SDValue Folded =
+ FoldConstantArithmetic(Opcode, DL, VT, Scalar1, Scalar2)) {
+ if (!VT.isVector())
+ return Folded;
+ SmallVector<SDValue, 4> Outputs;
+ // We may have a vector type but a scalar result. Create a splat.
+ Outputs.resize(VT.getVectorNumElements(), Outputs.back());
+ // Build a big vector out of the scalar elements we generated.
+ return getNode(ISD::BUILD_VECTOR, SDLoc(), VT, Outputs);
+ } else {
+ return SDValue();
+ }
+ }
+ }
- ConstantSDNode *Scalar1 = dyn_cast<ConstantSDNode>(Cst1);
- ConstantSDNode *Scalar2 = dyn_cast<ConstantSDNode>(Cst2);
- if (Scalar1 && Scalar2 && (Scalar1->isOpaque() || Scalar2->isOpaque()))
+ // For vectors extract each constant element into Inputs so we can constant
+ // fold them individually.
+ BuildVectorSDNode *BV1 = dyn_cast<BuildVectorSDNode>(Cst1);
+ BuildVectorSDNode *BV2 = dyn_cast<BuildVectorSDNode>(Cst2);
+ if (!BV1 || !BV2)
return SDValue();
- if (Scalar1 && Scalar2)
- // Scalar instruction.
- Inputs.push_back(std::make_pair(Scalar1, Scalar2));
- else {
- // For vectors extract each constant element into Inputs so we can constant
- // fold them individually.
- BuildVectorSDNode *BV1 = dyn_cast<BuildVectorSDNode>(Cst1);
- BuildVectorSDNode *BV2 = dyn_cast<BuildVectorSDNode>(Cst2);
- if (!BV1 || !BV2)
- return SDValue();
-
- assert(BV1->getNumOperands() == BV2->getNumOperands() && "Out of sync!");
-
- for (unsigned I = 0, E = BV1->getNumOperands(); I != E; ++I) {
- ConstantSDNode *V1 = dyn_cast<ConstantSDNode>(BV1->getOperand(I));
- ConstantSDNode *V2 = dyn_cast<ConstantSDNode>(BV2->getOperand(I));
- if (!V1 || !V2) // Not a constant, bail.
- return SDValue();
+ assert(BV1->getNumOperands() == BV2->getNumOperands() && "Out of sync!");
- if (V1->isOpaque() || V2->isOpaque())
- return SDValue();
-
- // Avoid BUILD_VECTOR nodes that perform implicit truncation.
- // FIXME: This is valid and could be handled by truncating the APInts.
- if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT)
- return SDValue();
+ EVT SVT = VT.getScalarType();
+ SmallVector<SDValue, 4> Outputs;
+ for (unsigned I = 0, E = BV1->getNumOperands(); I != E; ++I) {
+ ConstantSDNode *V1 = dyn_cast<ConstantSDNode>(BV1->getOperand(I));
+ ConstantSDNode *V2 = dyn_cast<ConstantSDNode>(BV2->getOperand(I));
+ if (!V1 || !V2) // Not a constant, bail.
+ return SDValue();
- Inputs.push_back(std::make_pair(V1, V2));
- }
- }
+ if (V1->isOpaque() || V2->isOpaque())
+ return SDValue();
- // We have a number of constant values, constant fold them element by element.
- for (unsigned I = 0, E = Inputs.size(); I != E; ++I) {
- const APInt &C1 = Inputs[I].first->getAPIntValue();
- const APInt &C2 = Inputs[I].second->getAPIntValue();
+ // Avoid BUILD_VECTOR nodes that perform implicit truncation.
+ // FIXME: This is valid and could be handled by truncating the APInts.
+ if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT)
+ return SDValue();
- switch (Opcode) {
- case ISD::ADD:
- Outputs.push_back(getConstant(C1 + C2, DL, SVT));
- break;
- case ISD::SUB:
- Outputs.push_back(getConstant(C1 - C2, DL, SVT));
- break;
- case ISD::MUL:
- Outputs.push_back(getConstant(C1 * C2, DL, SVT));
- break;
- case ISD::UDIV:
- if (!C2.getBoolValue())
- return SDValue();
- Outputs.push_back(getConstant(C1.udiv(C2), DL, SVT));
- break;
- case ISD::UREM:
- if (!C2.getBoolValue())
- return SDValue();
- Outputs.push_back(getConstant(C1.urem(C2), DL, SVT));
- break;
- case ISD::SDIV:
- if (!C2.getBoolValue())
- return SDValue();
- Outputs.push_back(getConstant(C1.sdiv(C2), DL, SVT));
- break;
- case ISD::SREM:
- if (!C2.getBoolValue())
- return SDValue();
- Outputs.push_back(getConstant(C1.srem(C2), DL, SVT));
- break;
- case ISD::AND:
- Outputs.push_back(getConstant(C1 & C2, DL, SVT));
- break;
- case ISD::OR:
- Outputs.push_back(getConstant(C1 | C2, DL, SVT));
- break;
- case ISD::XOR:
- Outputs.push_back(getConstant(C1 ^ C2, DL, SVT));
- break;
- case ISD::SHL:
- Outputs.push_back(getConstant(C1 << C2, DL, SVT));
- break;
- case ISD::SRL:
- Outputs.push_back(getConstant(C1.lshr(C2), DL, SVT));
- break;
- case ISD::SRA:
- Outputs.push_back(getConstant(C1.ashr(C2), DL, SVT));
- break;
- case ISD::ROTL:
- Outputs.push_back(getConstant(C1.rotl(C2), DL, SVT));
- break;
- case ISD::ROTR:
- Outputs.push_back(getConstant(C1.rotr(C2), DL, SVT));
- break;
- default:
+ // Fold one vector element.
+ std::pair<APInt, bool> Folded = FoldValue(Opcode, V1->getAPIntValue(),
+ V2->getAPIntValue());
+ if (!Folded.second)
return SDValue();
- }
+ Outputs.push_back(getConstant(Folded.first, DL, SVT));
}
- assert((Scalar1 && Scalar2) || (VT.getVectorNumElements() == Outputs.size() &&
- "Expected a scalar or vector!"));
-
- // Handle the scalar case first.
- if (!VT.isVector())
- return Outputs.back();
+ assert(VT.getVectorNumElements() == Outputs.size() &&
+ "Vector size mismatch!");
// We may have a vector type but a scalar result. Create a splat.
Outputs.resize(VT.getVectorNumElements(), Outputs.back());