From 42cff157db9cef189027dc30645c768f632656c4 Mon Sep 17 00:00:00 2001 From: Matthias Braun Date: Tue, 19 May 2015 01:40:21 +0000 Subject: [PATCH] SelectionDAG: Cleanup and simplify FoldConstantArithmetic This cleans up the FoldConstantArithmetic code by factoring out the case of two ConstantSDNodes into an own function. This avoids unnecessary complexity for many callers who already have ConstantSDNode arguments. This also avoids an intermeidate SmallVector datastructure and a loop over that datastructure. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@237651 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/CodeGen/SelectionDAG.h | 4 + lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 191 +++++++++++----------- 2 files changed, 95 insertions(+), 100 deletions(-) diff --git a/include/llvm/CodeGen/SelectionDAG.h b/include/llvm/CodeGen/SelectionDAG.h index 09fdb626de9..a76843200be 100644 --- a/include/llvm/CodeGen/SelectionDAG.h +++ b/include/llvm/CodeGen/SelectionDAG.h @@ -1128,6 +1128,10 @@ public: SDValue FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, SDNode *Cst1, SDNode *Cst2); + SDValue FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, + const ConstantSDNode *Cst1, + const ConstantSDNode *Cst2); + /// Constant fold a setcc to true or false. SDValue FoldSetCC(EVT VT, SDValue N1, SDValue N2, ISD::CondCode Cond, SDLoc dl); diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 7e898d582a4..6d75a7c9533 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -49,6 +49,7 @@ #include "llvm/Target/TargetSubtargetInfo.h" #include #include +#include using namespace llvm; @@ -3109,6 +3110,53 @@ SDValue SelectionDAG::getNode(unsigned Opcode, SDLoc DL, return SDValue(N, 0); } +static std::pair FoldValue(unsigned Opcode, const APInt &C1, + const APInt &C2) { + switch (Opcode) { + case ISD::ADD: return std::make_pair(C1 + C2, true); + case ISD::SUB: return std::make_pair(C1 - C2, true); + case ISD::MUL: return std::make_pair(C1 * C2, true); + case ISD::AND: return std::make_pair(C1 & C2, true); + case ISD::OR: return std::make_pair(C1 | C2, true); + case ISD::XOR: return std::make_pair(C1 ^ C2, true); + case ISD::SHL: return std::make_pair(C1 << C2, true); + case ISD::SRL: return std::make_pair(C1.lshr(C2), true); + case ISD::SRA: return std::make_pair(C1.ashr(C2), true); + case ISD::ROTL: return std::make_pair(C1.rotl(C2), true); + case ISD::ROTR: return std::make_pair(C1.rotr(C2), true); + case ISD::UDIV: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.udiv(C2), true); + case ISD::UREM: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.urem(C2), true); + case ISD::SDIV: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.sdiv(C2), true); + case ISD::SREM: + if (!C2.getBoolValue()) + break; + return std::make_pair(C1.srem(C2), true); + } + return std::make_pair(APInt(1, 0), false); +} + +SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, + const ConstantSDNode *Cst1, + const ConstantSDNode *Cst2) { + if (Cst1->isOpaque() || Cst2->isOpaque()) + return SDValue(); + + std::pair Folded = FoldValue(Opcode, Cst1->getAPIntValue(), + Cst2->getAPIntValue()); + if (!Folded.second) + return SDValue(); + return getConstant(Folded.first, DL, VT); +} + SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, SDNode *Cst1, SDNode *Cst2) { // If the opcode is a target-specific ISD node, there's nothing we can @@ -3117,116 +3165,59 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, SDLoc DL, EVT VT, if (Opcode >= ISD::BUILTIN_OP_END) return SDValue(); - SmallVector, 4> Inputs; - SmallVector Outputs; - EVT SVT = VT.getScalarType(); + // Handle the case of two scalars. + if (const ConstantSDNode *Scalar1 = dyn_cast(Cst1)) { + if (const ConstantSDNode *Scalar2 = dyn_cast(Cst2)) { + if (SDValue Folded = + FoldConstantArithmetic(Opcode, DL, VT, Scalar1, Scalar2)) { + if (!VT.isVector()) + return Folded; + SmallVector Outputs; + // We may have a vector type but a scalar result. Create a splat. + Outputs.resize(VT.getVectorNumElements(), Outputs.back()); + // Build a big vector out of the scalar elements we generated. + return getNode(ISD::BUILD_VECTOR, SDLoc(), VT, Outputs); + } else { + return SDValue(); + } + } + } - ConstantSDNode *Scalar1 = dyn_cast(Cst1); - ConstantSDNode *Scalar2 = dyn_cast(Cst2); - if (Scalar1 && Scalar2 && (Scalar1->isOpaque() || Scalar2->isOpaque())) + // For vectors extract each constant element into Inputs so we can constant + // fold them individually. + BuildVectorSDNode *BV1 = dyn_cast(Cst1); + BuildVectorSDNode *BV2 = dyn_cast(Cst2); + if (!BV1 || !BV2) return SDValue(); - if (Scalar1 && Scalar2) - // Scalar instruction. - Inputs.push_back(std::make_pair(Scalar1, Scalar2)); - else { - // For vectors extract each constant element into Inputs so we can constant - // fold them individually. - BuildVectorSDNode *BV1 = dyn_cast(Cst1); - BuildVectorSDNode *BV2 = dyn_cast(Cst2); - if (!BV1 || !BV2) - return SDValue(); - - assert(BV1->getNumOperands() == BV2->getNumOperands() && "Out of sync!"); - - for (unsigned I = 0, E = BV1->getNumOperands(); I != E; ++I) { - ConstantSDNode *V1 = dyn_cast(BV1->getOperand(I)); - ConstantSDNode *V2 = dyn_cast(BV2->getOperand(I)); - if (!V1 || !V2) // Not a constant, bail. - return SDValue(); + assert(BV1->getNumOperands() == BV2->getNumOperands() && "Out of sync!"); - if (V1->isOpaque() || V2->isOpaque()) - return SDValue(); - - // Avoid BUILD_VECTOR nodes that perform implicit truncation. - // FIXME: This is valid and could be handled by truncating the APInts. - if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT) - return SDValue(); + EVT SVT = VT.getScalarType(); + SmallVector Outputs; + for (unsigned I = 0, E = BV1->getNumOperands(); I != E; ++I) { + ConstantSDNode *V1 = dyn_cast(BV1->getOperand(I)); + ConstantSDNode *V2 = dyn_cast(BV2->getOperand(I)); + if (!V1 || !V2) // Not a constant, bail. + return SDValue(); - Inputs.push_back(std::make_pair(V1, V2)); - } - } + if (V1->isOpaque() || V2->isOpaque()) + return SDValue(); - // We have a number of constant values, constant fold them element by element. - for (unsigned I = 0, E = Inputs.size(); I != E; ++I) { - const APInt &C1 = Inputs[I].first->getAPIntValue(); - const APInt &C2 = Inputs[I].second->getAPIntValue(); + // Avoid BUILD_VECTOR nodes that perform implicit truncation. + // FIXME: This is valid and could be handled by truncating the APInts. + if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT) + return SDValue(); - switch (Opcode) { - case ISD::ADD: - Outputs.push_back(getConstant(C1 + C2, DL, SVT)); - break; - case ISD::SUB: - Outputs.push_back(getConstant(C1 - C2, DL, SVT)); - break; - case ISD::MUL: - Outputs.push_back(getConstant(C1 * C2, DL, SVT)); - break; - case ISD::UDIV: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.udiv(C2), DL, SVT)); - break; - case ISD::UREM: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.urem(C2), DL, SVT)); - break; - case ISD::SDIV: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.sdiv(C2), DL, SVT)); - break; - case ISD::SREM: - if (!C2.getBoolValue()) - return SDValue(); - Outputs.push_back(getConstant(C1.srem(C2), DL, SVT)); - break; - case ISD::AND: - Outputs.push_back(getConstant(C1 & C2, DL, SVT)); - break; - case ISD::OR: - Outputs.push_back(getConstant(C1 | C2, DL, SVT)); - break; - case ISD::XOR: - Outputs.push_back(getConstant(C1 ^ C2, DL, SVT)); - break; - case ISD::SHL: - Outputs.push_back(getConstant(C1 << C2, DL, SVT)); - break; - case ISD::SRL: - Outputs.push_back(getConstant(C1.lshr(C2), DL, SVT)); - break; - case ISD::SRA: - Outputs.push_back(getConstant(C1.ashr(C2), DL, SVT)); - break; - case ISD::ROTL: - Outputs.push_back(getConstant(C1.rotl(C2), DL, SVT)); - break; - case ISD::ROTR: - Outputs.push_back(getConstant(C1.rotr(C2), DL, SVT)); - break; - default: + // Fold one vector element. + std::pair Folded = FoldValue(Opcode, V1->getAPIntValue(), + V2->getAPIntValue()); + if (!Folded.second) return SDValue(); - } + Outputs.push_back(getConstant(Folded.first, DL, SVT)); } - assert((Scalar1 && Scalar2) || (VT.getVectorNumElements() == Outputs.size() && - "Expected a scalar or vector!")); - - // Handle the scalar case first. - if (!VT.isVector()) - return Outputs.back(); + assert(VT.getVectorNumElements() == Outputs.size() && + "Vector size mismatch!"); // We may have a vector type but a scalar result. Create a splat. Outputs.resize(VT.getVectorNumElements(), Outputs.back()); -- 2.40.0