From 73be7b823e848437bf41189053276924cf1d2bfa Mon Sep 17 00:00:00 2001 From: Sam Parker Date: Thu, 11 Jul 2019 07:47:50 +0000 Subject: [PATCH] [ARM][ParallelDSP] Change the search for smlads Two functional changes have been made here: - Now search up from any add instruction to find the chains of operations that we may turn into a smlad. This allows the generation of a smlad which doesn't accumulate into a phi. - The search function has been corrected to stop it falsely searching up through an invalid path. The bulk of the changes have been making the Reduction struct a class and making it more C++y with getters and setters. Differential Revision: https://reviews.llvm.org/D61780 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@365740 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/ARM/ARMParallelDSP.cpp | 568 ++++++++++-------- test/CodeGen/ARM/ParallelDSP/aliasing.ll | 4 +- .../ARM/ParallelDSP/inner-full-unroll.ll | 151 +++++ 3 files changed, 470 insertions(+), 253 deletions(-) create mode 100644 test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll diff --git a/lib/Target/ARM/ARMParallelDSP.cpp b/lib/Target/ARM/ARMParallelDSP.cpp index 3cff9b56851..5389d09bf7d 100644 --- a/lib/Target/ARM/ARMParallelDSP.cpp +++ b/lib/Target/ARM/ARMParallelDSP.cpp @@ -48,7 +48,7 @@ DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false), namespace { struct OpChain; struct BinOpChain; - struct Reduction; + class Reduction; using OpChainList = SmallVector, 8>; using ReductionList = SmallVector; @@ -79,10 +79,8 @@ namespace { unsigned size() const { return AllValues.size(); } }; - // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures. - // 'Reduction' contains the phi-node and accumulator statement from where we - // start pattern matching, and 'BinOpChain' the multiplication - // instructions that are candidates for parallel execution. 
+ // 'BinOpChain' holds the multiplication instructions that are candidates + // for parallel execution. struct BinOpChain : public OpChain { ValueList LHS; // List of all (narrow) left hand operands. ValueList RHS; // List of all (narrow) right hand operands. @@ -97,15 +95,70 @@ namespace { bool AreSymmetrical(BinOpChain *Other); }; - struct Reduction { - PHINode *Phi; // The Phi-node from where we start - // pattern matching. - Instruction *AccIntAdd; // The accumulating integer add statement, - // i.e, the reduction statement. - OpChainList MACCandidates; // The MAC candidates associated with - // this reduction statement. - PMACPairList PMACPairs; - Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { }; + /// Represent a sequence of multiply-accumulate operations with the aim to + /// perform the multiplications in parallel. + class Reduction { + Instruction *Root = nullptr; + Value *Acc = nullptr; + OpChainList Muls; + PMACPairList MulPairs; + SmallPtrSet Adds; + + public: + Reduction() = delete; + + Reduction (Instruction *Add) : Root(Add) { } + + /// Record an Add instruction that is a part of this reduction. + void InsertAdd(Instruction *I) { Adds.insert(I); } + + /// Record a BinOpChain, rooted at a Mul instruction, that is a part of + /// this reduction. + void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) { + Muls.push_back(make_unique(I, LHS, RHS)); + } + + /// Add the incoming accumulator value, returns true if a value had not + /// already been added. Returning false signals to the user that this + /// reduction already has a value to initialise the accumulator. + bool InsertAcc(Value *V) { + if (Acc) + return false; + Acc = V; + return true; + } + + /// Set two BinOpChains, rooted at muls, that can be executed as a single + /// parallel operation. 
+ void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) { + MulPairs.push_back(std::make_pair(Mul0, Mul1)); + } + + /// Return true if enough mul operations are found that can be executed in + /// parallel. + bool CreateParallelPairs(); + + /// Return the add instruction which is the root of the reduction. + Instruction *getRoot() { return Root; } + + /// Return the incoming value to be accumulated. This may be null. + Value *getAccumulator() { return Acc; } + + /// Return the set of adds that comprise the reduction. + SmallPtrSetImpl &getAdds() { return Adds; } + + /// Return the BinOpChains, rooted at mul instructions, that comprise + /// the reduction. + OpChainList &getMuls() { return Muls; } + + /// Return the BinOpChain, rooted at mul instructions, that have been + /// paired for parallel execution. + PMACPairList &getMulPairs() { return MulPairs; } + + /// To finalise, replace the uses of the root with the intrinsic call. + void UpdateRoot(Instruction *SMLAD) { + Root->replaceAllUsesWith(SMLAD); + } }; class WidenedLoad { @@ -133,25 +186,25 @@ namespace { const DataLayout *DL; Module *M; std::map LoadPairs; + SmallPtrSet OffsetLoads; std::map> WideLoads; + template + bool IsNarrowSequence(Value *V, ValueList &VL); + bool RecordMemoryOps(BasicBlock *BB); - bool InsertParallelMACs(Reduction &Reduction); + void InsertParallelMACs(Reduction &Reduction); bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem); LoadInst* CreateWideLoad(SmallVectorImpl &Loads, IntegerType *LoadTy); - void CreateParallelMACPairs(Reduction &R); - Instruction *CreateSMLADCall(SmallVectorImpl &VecLd0, - SmallVectorImpl &VecLd1, - Instruction *Acc, bool Exchange, - Instruction *InsertAfter); + bool CreateParallelPairs(Reduction &R); /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate /// Dual performs two signed 16x16-bit multiplications. It adds the /// products to a 32-bit accumulate operand. 
Optionally, the instruction can /// exchange the halfwords of the second operand before performing the /// arithmetic. - bool MatchSMLAD(Function &F); + bool MatchSMLAD(Loop *L); public: static char ID; @@ -201,11 +254,8 @@ namespace { return false; } - // We need a preheader as getIncomingValueForBlock assumes there is one. - if (!TheLoop->getLoopPreheader()) { - LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n"); - return false; - } + if (!TheLoop->getLoopPreheader()) + InsertPreheaderForLoop(L, DT, LI, nullptr, true); Function &F = *Header->getParent(); M = F.getParent(); @@ -242,7 +292,7 @@ namespace { return false; } - bool Changes = MatchSMLAD(F); + bool Changes = MatchSMLAD(L); return Changes; } }; @@ -275,6 +325,51 @@ bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, return true; } +// MaxBitwidth: the maximum supported bitwidth of the elements in the DSP +// instructions, which is set to 16. So here we should collect all i8 and i16 +// narrow operations. +// TODO: we currently only collect i16, and will support i8 later, so that's +// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth. +template +bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) { + ConstantInt *CInt; + + if (match(V, m_ConstantInt(CInt))) { + // TODO: if a constant is used, it needs to fit within the bit width. + return false; + } + + auto *I = dyn_cast(V); + if (!I) + return false; + + Value *Val, *LHS, *RHS; + if (match(V, m_Trunc(m_Value(Val)))) { + if (cast(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth) + return IsNarrowSequence(Val, VL); + } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) { + // TODO: we need to implement sadd16/sadd8 for this, which enables to + // also do the rewrite for smlad8.ll, but it is unsupported for now. 
+ return false; + } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) { + if (cast(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) + return false; + + if (match(Val, m_Load(m_Value()))) { + auto *Ld = cast(Val); + + // Check that these load could be paired. + if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld)) + return false; + + VL.push_back(Val); + VL.push_back(I); + return true; + } + } + return false; +} + /// Iterate through the block and record base, offset pairs of loads which can /// be widened into a single load. bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) { @@ -342,6 +437,7 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) { if (AreSequentialAccesses(Base, Offset, *DL, *SE) && SafeToPair(Base, Offset)) { LoadPairs[Base] = Offset; + OffsetLoads.insert(Offset); break; } } @@ -357,15 +453,150 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) { return LoadPairs.size() > 1; } -void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) { - OpChainList &Candidates = R.MACCandidates; - PMACPairList &PMACPairs = R.PMACPairs; - const unsigned Elems = Candidates.size(); +// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector +// multiplications. +// To use SMLAD: +// 1) we first need to find integer add then look for this pattern: +// +// acc0 = ... +// ld0 = load i16 +// sext0 = sext i16 %ld0 to i32 +// ld1 = load i16 +// sext1 = sext i16 %ld1 to i32 +// mul0 = mul %sext0, %sext1 +// ld2 = load i16 +// sext2 = sext i16 %ld2 to i32 +// ld3 = load i16 +// sext3 = sext i16 %ld3 to i32 +// mul1 = mul i32 %sext2, %sext3 +// add0 = add i32 %mul0, %acc0 +// acc1 = add i32 %add0, %mul1 +// +// Which can be selected to: +// +// ldr r0 +// ldr r1 +// smlad r2, r0, r1, r2 +// +// If constants are used instead of loads, these will need to be hoisted +// out and into a register. +// +// If loop invariants are used instead of loads, these need to be packed +// before the loop begins. 
+// +bool ARMParallelDSP::MatchSMLAD(Loop *L) { + // Search recursively back through the operands to find a tree of values that + // form a multiply-accumulate chain. The search records the Add and Mul + // instructions that form the reduction and allows us to find a single value + // to be used as the initial input to the accumulator. + std::function Search = [&] + (Value *V, Reduction &R) -> bool { + + // If we find a non-instruction, try to use it as the initial accumulator + // value. This may have already been found during the search in which case + // this function will return false, signaling a search fail. + auto *I = dyn_cast(V); + if (!I) + return R.InsertAcc(V); + + switch (I->getOpcode()) { + default: + break; + case Instruction::PHI: + // Could be the accumulator value. + return R.InsertAcc(V); + case Instruction::Add: { + // Adds should be adding together two muls, or another add and a mul to + // be within the mac chain. If only one operand is part of the chain, + // this add is tried as the accumulator value and the search stops here. 
+ bool ValidLHS = Search(I->getOperand(0), R); + bool ValidRHS = Search(I->getOperand(1), R); + if (!ValidLHS && !ValidRHS) + return false; + else if (ValidLHS && ValidRHS) { + R.InsertAdd(I); + return true; + } else { + R.InsertAdd(I); + return R.InsertAcc(I); + } + } + case Instruction::Mul: { + Value *MulOp0 = I->getOperand(0); + Value *MulOp1 = I->getOperand(1); + if (isa(MulOp0) && isa(MulOp1)) { + ValueList LHS; + ValueList RHS; + if (IsNarrowSequence<16>(MulOp0, LHS) && + IsNarrowSequence<16>(MulOp1, RHS)) { + R.InsertMul(I, LHS, RHS); + return true; + } + } + return false; + } + case Instruction::SExt: + return Search(I->getOperand(0), R); + } + return false; + }; + + bool Changed = false; + SmallPtrSet AllAdds; + BasicBlock *Latch = L->getLoopLatch(); + + for (Instruction &I : reverse(*Latch)) { + if (I.getOpcode() != Instruction::Add) + continue; + + if (AllAdds.count(&I)) + continue; + + const auto *Ty = I.getType(); + if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64)) + continue; + + Reduction R(&I); + if (!Search(&I, R)) + continue; + + if (!CreateParallelPairs(R)) + continue; + + InsertParallelMACs(R); + Changed = true; + AllAdds.insert(R.getAdds().begin(), R.getAdds().end()); + } + + return Changed; +} + +bool ARMParallelDSP::CreateParallelPairs(Reduction &R) { + + // Not enough mul operations to make a pair. + if (R.getMuls().size() < 2) + return false; + + // Check that the muls operate directly upon sign extended loads. + for (auto &MulChain : R.getMuls()) { + // A mul has 2 operands, and a narrow op consist of sext and a load; thus + // we expect at least 4 items in this operand value list. + if (MulChain->size() < 4) { + LLVM_DEBUG(dbgs() << "Operand list too short.\n"); + return false; + } + MulChain->PopulateLoads(); + ValueList &LHS = static_cast(MulChain.get())->LHS; + ValueList &RHS = static_cast(MulChain.get())->RHS; - if (Elems < 2) - return; + // Use +=2 to skip over the expected extend instructions. 
+ for (unsigned i = 0, e = LHS.size(); i < e; i += 2) { + if (!isa(LHS[i]) || !isa(RHS[i])) + return false; + } + } - auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) { + auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) { if (!PMul0->AreSymmetrical(PMul1)) return false; @@ -391,13 +622,13 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) { if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) { if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) { LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n"); - PMACPairs.push_back(std::make_pair(PMul0, PMul1)); + R.AddMulPair(PMul0, PMul1); return true; } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) { LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n"); LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n"); PMul1->Exchange = true; - PMACPairs.push_back(std::make_pair(PMul0, PMul1)); + R.AddMulPair(PMul0, PMul1); return true; } } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) && @@ -407,16 +638,18 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) { LLVM_DEBUG(dbgs() << " and swapping muls\n"); PMul0->Exchange = true; // Only the second operand can be exchanged, so swap the muls. 
- PMACPairs.push_back(std::make_pair(PMul1, PMul0)); + R.AddMulPair(PMul1, PMul0); return true; } } return false; }; + OpChainList &Muls = R.getMuls(); + const unsigned Elems = Muls.size(); SmallPtrSet Paired; for (unsigned i = 0; i < Elems; ++i) { - BinOpChain *PMul0 = static_cast(Candidates[i].get()); + BinOpChain *PMul0 = static_cast(Muls[i].get()); if (Paired.count(PMul0->Root)) continue; @@ -424,7 +657,7 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) { if (i == j) continue; - BinOpChain *PMul1 = static_cast(Candidates[j].get()); + BinOpChain *PMul1 = static_cast(Muls[j].get()); if (Paired.count(PMul1->Root)) continue; @@ -435,199 +668,67 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) { assert(PMul0 != PMul1 && "expected different chains"); - if (CanPair(PMul0, PMul1)) { + if (CanPair(R, PMul0, PMul1)) { Paired.insert(Mul0); Paired.insert(Mul1); break; } } } + return !R.getMulPairs().empty(); } -bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) { - Instruction *Acc = Reduction.Phi; - Instruction *InsertAfter = Reduction.AccIntAdd; - for (auto &Pair : Reduction.PMACPairs) { +void ARMParallelDSP::InsertParallelMACs(Reduction &R) { + + auto CreateSMLADCall = [&](SmallVectorImpl &VecLd0, + SmallVectorImpl &VecLd1, + Value *Acc, bool Exchange, + Instruction *InsertAfter) { + // Replace the reduction chain with an intrinsic call + IntegerType *Ty = IntegerType::get(M->getContext(), 32); + LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ? + WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty); + LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ? + WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty); + + Value* Args[] = { WideLd0, WideLd1, Acc }; + Function *SMLAD = nullptr; + if (Exchange) + SMLAD = Acc->getType()->isIntegerTy(32) ? + Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) : + Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx); + else + SMLAD = Acc->getType()->isIntegerTy(32) ? 
+ Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) : + Intrinsic::getDeclaration(M, Intrinsic::arm_smlald); + + IRBuilder Builder(InsertAfter->getParent(), + ++BasicBlock::iterator(InsertAfter)); + Instruction *Call = Builder.CreateCall(SMLAD, Args); + NumSMLAD++; + return Call; + }; + + Instruction *InsertAfter = R.getRoot(); + Value *Acc = R.getAccumulator(); + if (!Acc) + Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0); + + LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n" + << "Acc: " << *Acc << "\n"); + for (auto &Pair : R.getMulPairs()) { BinOpChain *PMul0 = Pair.first; BinOpChain *PMul1 = Pair.second; - LLVM_DEBUG(dbgs() << "Found parallel MACs:\n" + LLVM_DEBUG(dbgs() << "Muls:\n" << "- " << *PMul0->Root << "\n" << "- " << *PMul1->Root << "\n"); Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange, InsertAfter); - InsertAfter = Acc; - } - - if (Acc != Reduction.Phi) { - LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump()); - Reduction.AccIntAdd->replaceAllUsesWith(Acc); - return true; + InsertAfter = cast(Acc); } - return false; -} - -template -bool IsExtendingLoad(Value *V) { - auto *I = dyn_cast(V); - if (!I) - return false; - - if (I->getSrcTy()->getIntegerBitWidth() != BitWidth) - return false; - - return isa(I->getOperand(0)); -} - -static void MatchParallelMACSequences(Reduction &R, - OpChainList &Candidates) { - Instruction *Acc = R.AccIntAdd; - LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc << "\n"); - - // Returns false to signal the search should be stopped. 
- std::function Match = - [&Candidates, &Match](Value *V) -> bool { - - auto *I = dyn_cast(V); - if (!I) - return false; - - switch (I->getOpcode()) { - case Instruction::Add: - if (Match(I->getOperand(0)) || (Match(I->getOperand(1)))) - return true; - break; - case Instruction::Mul: { - Value *Op0 = I->getOperand(0); - Value *Op1 = I->getOperand(1); - if (IsExtendingLoad(Op0) && - IsExtendingLoad(Op1)) { - ValueList LHS = { cast(Op0)->getOperand(0), Op0 }; - ValueList RHS = { cast(Op1)->getOperand(0), Op1 }; - Candidates.push_back(make_unique(I, LHS, RHS)); - } - return false; - } - case Instruction::SExt: - return Match(I->getOperand(0)); - } - return false; - }; - - while (Match (Acc)); - LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found " - << Candidates.size() << " candidates.\n"); -} - -static bool CheckMACMemory(OpChainList &Candidates) { - for (auto &C : Candidates) { - // A mul has 2 operands, and a narrow op consist of sext and a load; thus - // we expect at least 4 items in this operand value list. - if (C->size() < 4) { - LLVM_DEBUG(dbgs() << "Operand list too short.\n"); - return false; - } - C->PopulateLoads(); - ValueList &LHS = static_cast(C.get())->LHS; - ValueList &RHS = static_cast(C.get())->RHS; - - // Use +=2 to skip over the expected extend instructions. - for (unsigned i = 0, e = LHS.size(); i < e; i += 2) { - if (!isa(LHS[i]) || !isa(RHS[i])) - return false; - } - } - return true; -} - -// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector -// multiplications. 
-// To use SMLAD: -// 1) we first need to find integer add reduction PHIs, -// 2) then from the PHI, look for this pattern: -// -// acc0 = phi i32 [0, %entry], [%acc1, %loop.body] -// ld0 = load i16 -// sext0 = sext i16 %ld0 to i32 -// ld1 = load i16 -// sext1 = sext i16 %ld1 to i32 -// mul0 = mul %sext0, %sext1 -// ld2 = load i16 -// sext2 = sext i16 %ld2 to i32 -// ld3 = load i16 -// sext3 = sext i16 %ld3 to i32 -// mul1 = mul i32 %sext2, %sext3 -// add0 = add i32 %mul0, %acc0 -// acc1 = add i32 %add0, %mul1 -// -// Which can be selected to: -// -// ldr.h r0 -// ldr.h r1 -// smlad r2, r0, r1, r2 -// -// If constants are used instead of loads, these will need to be hoisted -// out and into a register. -// -// If loop invariants are used instead of loads, these need to be packed -// before the loop begins. -// -bool ARMParallelDSP::MatchSMLAD(Function &F) { - - auto FindReductions = [&](ReductionList &Reductions) { - RecurrenceDescriptor RecDesc; - const bool HasFnNoNaNAttr = - F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; - BasicBlock *Latch = L->getLoopLatch(); - - for (PHINode &Phi : Latch->phis()) { - const auto *Ty = Phi.getType(); - if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64)) - continue; - - const bool IsReduction = RecurrenceDescriptor::AddReductionVar( - &Phi, RecurrenceDescriptor::RK_IntegerAdd, L, HasFnNoNaNAttr, RecDesc); - - if (!IsReduction) - continue; - - Instruction *Acc = dyn_cast(Phi.getIncomingValueForBlock(Latch)); - if (!Acc) - continue; - - Reductions.push_back(Reduction(&Phi, Acc)); - } - return !Reductions.empty(); - }; - - ReductionList Reductions; - if (!FindReductions(Reductions)) - return false; - - for (auto &R : Reductions) { - OpChainList MACCandidates; - MatchParallelMACSequences(R, MACCandidates); - if (!CheckMACMemory(MACCandidates)) - continue; - - R.MACCandidates = std::move(MACCandidates); - - LLVM_DEBUG(dbgs() << "MAC candidates:\n"; - for (auto &M : R.MACCandidates) - M->Root->dump(); - dbgs() << 
"\n";); - } - - bool Changed = false; - // Check whether statements in the basic block that write to memory alias - // with the memory locations accessed by the MAC-chains. - for (auto &R : Reductions) { - CreateParallelMACPairs(R); - Changed |= InsertParallelMACs(R); - } - - return Changed; + R.UpdateRoot(cast(Acc)); } LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl &Loads, @@ -696,43 +797,6 @@ LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl &Loads, return WideLoad; } -Instruction *ARMParallelDSP::CreateSMLADCall(SmallVectorImpl &VecLd0, - SmallVectorImpl &VecLd1, - Instruction *Acc, bool Exchange, - Instruction *InsertAfter) { - LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n" - << "- " << *VecLd0[0] << "\n" - << "- " << *VecLd0[1] << "\n" - << "- " << *VecLd1[0] << "\n" - << "- " << *VecLd1[1] << "\n" - << "- " << *Acc << "\n" - << "- Exchange: " << Exchange << "\n"); - - // Replace the reduction chain with an intrinsic call - IntegerType *Ty = IntegerType::get(M->getContext(), 32); - LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ? - WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty); - LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ? - WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty); - - Value* Args[] = { WideLd0, WideLd1, Acc }; - Function *SMLAD = nullptr; - if (Exchange) - SMLAD = Acc->getType()->isIntegerTy(32) ? - Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) : - Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx); - else - SMLAD = Acc->getType()->isIntegerTy(32) ? - Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) : - Intrinsic::getDeclaration(M, Intrinsic::arm_smlald); - - IRBuilder Builder(InsertAfter->getParent(), - ++BasicBlock::iterator(InsertAfter)); - CallInst *Call = Builder.CreateCall(SMLAD, Args); - NumSMLAD++; - return Call; -} - // Compare the value lists in Other to this chain. 
bool BinOpChain::AreSymmetrical(BinOpChain *Other) { // Element-by-element comparison of Value lists returning true if they are diff --git a/test/CodeGen/ARM/ParallelDSP/aliasing.ll b/test/CodeGen/ARM/ParallelDSP/aliasing.ll index 47047c7f44b..4edf5bfbbef 100644 --- a/test/CodeGen/ARM/ParallelDSP/aliasing.ll +++ b/test/CodeGen/ARM/ParallelDSP/aliasing.ll @@ -451,8 +451,10 @@ for.body: br i1 %exitcond, label %for.body, label %for.cond.cleanup } +; TODO: I think we should be able to generate one smlad here. The search fails +; when it finds the alias. ; CHECK-LABEL: one_pair_alias -; FIXME: This tests shows we have a bug with smlad insertion +; CHECK-NOT: call i32 @llvm.arm.smlad define i32 @one_pair_alias(i16* noalias nocapture readonly %b, i16* noalias nocapture readonly %c) { entry: br label %for.body diff --git a/test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll b/test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll new file mode 100644 index 00000000000..052fb51a8dd --- /dev/null +++ b/test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll @@ -0,0 +1,151 @@ +; RUN: opt -mtriple=thumbv7em -arm-parallel-dsp -dce -S %s -o - | FileCheck %s + +; CHECK-LABEL: full_unroll +; CHECK: [[IV:%[^ ]+]] = phi i32 +; CHECK: [[AI:%[^ ]+]] = getelementptr inbounds i32, i32* %a, i32 [[IV]] +; CHECK: [[BI:%[^ ]+]] = getelementptr inbounds i16*, i16** %b, i32 [[IV]] +; CHECK: [[BIJ:%[^ ]+]] = load i16*, i16** %arrayidx5, align 4 +; CHECK: [[CI:%[^ ]+]] = getelementptr inbounds i16*, i16** %c, i32 [[IV]] +; CHECK: [[CIJ:%[^ ]+]] = load i16*, i16** [[CI]], align 4 +; CHECK: [[BIJ_CAST:%[^ ]+]] = bitcast i16* [[BIJ]] to i32* +; CHECK: [[BIJ_LD:%[^ ]+]] = load i32, i32* [[BIJ_CAST]], align 2 +; CHECK: [[CIJ_CAST:%[^ ]+]] = bitcast i16* [[CIJ]] to i32* +; CHECK: [[CIJ_LD:%[^ ]+]] = load i32, i32* [[CIJ_CAST]], align 2 +; CHECK: [[BIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[BIJ]], i32 2 +; CHECK: [[BIJ_2_CAST:%[^ ]+]] = bitcast i16* [[BIJ_2]] to i32* +; CHECK: [[BIJ_2_LD:%[^ ]+]] = load 
i32, i32* [[BIJ_2_CAST]], align 2 +; CHECK: [[CIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[CIJ]], i32 2 +; CHECK: [[CIJ_2_CAST:%[^ ]+]] = bitcast i16* [[CIJ_2]] to i32* +; CHECK: [[CIJ_2_LD:%[^ ]+]] = load i32, i32* [[CIJ_2_CAST]], align 2 +; CHECK: [[SMLAD0:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_2_LD]], i32 [[BIJ_2_LD]], i32 0) +; CHECK: [[SMLAD1:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_LD]], i32 [[BIJ_LD]], i32 [[SMLAD0]]) +; CHECK: store i32 [[SMLAD1]], i32* %arrayidx, align 4 + +define void @full_unroll(i32* noalias nocapture %a, i16** noalias nocapture readonly %b, i16** noalias nocapture readonly %c, i32 %N) { +entry: + %cmp29 = icmp eq i32 %N, 0 + br i1 %cmp29, label %for.cond.cleanup, label %for.body + +for.cond.cleanup: ; preds = %for.body, %entry + ret void + +for.body: ; preds = %entry, %for.body + %i.030 = phi i32 [ %inc12, %for.body ], [ 0, %entry ] + %arrayidx = getelementptr inbounds i32, i32* %a, i32 %i.030 + %arrayidx5 = getelementptr inbounds i16*, i16** %b, i32 %i.030 + %0 = load i16*, i16** %arrayidx5, align 4 + %arrayidx7 = getelementptr inbounds i16*, i16** %c, i32 %i.030 + %1 = load i16*, i16** %arrayidx7, align 4 + %2 = load i16, i16* %0, align 2 + %conv = sext i16 %2 to i32 + %3 = load i16, i16* %1, align 2 + %conv9 = sext i16 %3 to i32 + %mul = mul nsw i32 %conv9, %conv + %arrayidx6.1 = getelementptr inbounds i16, i16* %0, i32 1 + %4 = load i16, i16* %arrayidx6.1, align 2 + %conv.1 = sext i16 %4 to i32 + %arrayidx8.1 = getelementptr inbounds i16, i16* %1, i32 1 + %5 = load i16, i16* %arrayidx8.1, align 2 + %conv9.1 = sext i16 %5 to i32 + %mul.1 = mul nsw i32 %conv9.1, %conv.1 + %add.1 = add nsw i32 %mul.1, %mul + %arrayidx6.2 = getelementptr inbounds i16, i16* %0, i32 2 + %6 = load i16, i16* %arrayidx6.2, align 2 + %conv.2 = sext i16 %6 to i32 + %arrayidx8.2 = getelementptr inbounds i16, i16* %1, i32 2 + %7 = load i16, i16* %arrayidx8.2, align 2 + %conv9.2 = sext i16 %7 to i32 + %mul.2 = mul nsw i32 %conv9.2, %conv.2 + 
%add.2 = add nsw i32 %mul.2, %add.1 + %arrayidx6.3 = getelementptr inbounds i16, i16* %0, i32 3 + %8 = load i16, i16* %arrayidx6.3, align 2 + %conv.3 = sext i16 %8 to i32 + %arrayidx8.3 = getelementptr inbounds i16, i16* %1, i32 3 + %9 = load i16, i16* %arrayidx8.3, align 2 + %conv9.3 = sext i16 %9 to i32 + %mul.3 = mul nsw i32 %conv9.3, %conv.3 + %add.3 = add nsw i32 %mul.3, %add.2 + store i32 %add.3, i32* %arrayidx, align 4 + %inc12 = add nuw i32 %i.030, 1 + %exitcond = icmp eq i32 %inc12, %N + br i1 %exitcond, label %for.cond.cleanup, label %for.body +} + +; CHECK-LABEL: full_unroll_sub +; CHECK: [[IV:%[^ ]+]] = phi i32 +; CHECK: [[AI:%[^ ]+]] = getelementptr inbounds i32, i32* %a, i32 [[IV]] +; CHECK: [[BI:%[^ ]+]] = getelementptr inbounds i16*, i16** %b, i32 [[IV]] +; CHECK: [[BIJ:%[^ ]+]] = load i16*, i16** [[BI]], align 4 +; CHECK: [[CI:%[^ ]+]] = getelementptr inbounds i16*, i16** %c, i32 [[IV]] +; CHECK: [[CIJ:%[^ ]+]] = load i16*, i16** [[CI]], align 4 +; CHECK: [[BIJ_LD:%[^ ]+]] = load i16, i16* [[BIJ]], align 2 +; CHECK: [[BIJ_LD_SXT:%[^ ]+]] = sext i16 [[BIJ_LD]] to i32 +; CHECK: [[CIJ_LD:%[^ ]+]] = load i16, i16* [[CIJ]], align 2 +; CHECK: [[CIJ_LD_SXT:%[^ ]+]] = sext i16 [[CIJ_LD]] to i32 +; CHECK: [[SUB:%[^ ]+]] = sub nsw i32 [[CIJ_LD_SXT]], [[BIJ_LD_SXT]] +; CHECK: [[BIJ_1:%[^ ]+]] = getelementptr inbounds i16, i16* [[BIJ]], i32 1 +; CHECK: [[BIJ_1_LD:%[^ ]+]] = load i16, i16* [[BIJ_1]], align 2 +; CHECK: [[BIJ_1_LD_SXT:%[^ ]+]] = sext i16 [[BIJ_1_LD]] to i32 +; CHECK: [[CIJ_1:%[^ ]+]] = getelementptr inbounds i16, i16* [[CIJ]], i32 1 +; CHECK: [[CIJ_1_LD:%[^ ]+]] = load i16, i16* [[CIJ_1]], align 2 +; CHECK: [[CIJ_1_LD_SXT:%[^ ]+]] = sext i16 [[CIJ_1_LD]] to i32 +; CHECK: [[MUL:%[^ ]+]] = mul nsw i32 [[CIJ_1_LD_SXT]], [[BIJ_1_LD_SXT]] +; CHECK: [[ACC:%[^ ]+]] = add nsw i32 [[MUL]], [[SUB]] +; CHECK: [[BIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[BIJ]], i32 2 +; CHECK: [[BIJ_2_CAST:%[^ ]+]] = bitcast i16* [[BIJ_2]] to i32* +; CHECK: 
[[BIJ_2_LD:%[^ ]+]] = load i32, i32* [[BIJ_2_CAST]], align 2 +; CHECK: [[CIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[CIJ]], i32 2 +; CHECK: [[CIJ_2_CAST:%[^ ]+]] = bitcast i16* [[CIJ_2]] to i32* +; CHECK: [[CIJ_2_LD:%[^ ]+]] = load i32, i32* [[CIJ_2_CAST]], align 2 +; CHECK: [[SMLAD0:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_2_LD]], i32 [[BIJ_2_LD]], i32 [[ACC]]) +; CHECK: store i32 [[SMLAD0]], i32* %arrayidx, align 4 + +define void @full_unroll_sub(i32* noalias nocapture %a, i16** noalias nocapture readonly %b, i16** noalias nocapture readonly %c, i32 %N) { +entry: + %cmp29 = icmp eq i32 %N, 0 + br i1 %cmp29, label %for.cond.cleanup, label %for.body + +for.cond.cleanup: ; preds = %for.body, %entry + ret void + +for.body: ; preds = %entry, %for.body + %i.030 = phi i32 [ %inc12, %for.body ], [ 0, %entry ] + %arrayidx = getelementptr inbounds i32, i32* %a, i32 %i.030 + %arrayidx5 = getelementptr inbounds i16*, i16** %b, i32 %i.030 + %0 = load i16*, i16** %arrayidx5, align 4 + %arrayidx7 = getelementptr inbounds i16*, i16** %c, i32 %i.030 + %1 = load i16*, i16** %arrayidx7, align 4 + %2 = load i16, i16* %0, align 2 + %conv = sext i16 %2 to i32 + %3 = load i16, i16* %1, align 2 + %conv9 = sext i16 %3 to i32 + %sub = sub nsw i32 %conv9, %conv + %arrayidx6.1 = getelementptr inbounds i16, i16* %0, i32 1 + %4 = load i16, i16* %arrayidx6.1, align 2 + %conv.1 = sext i16 %4 to i32 + %arrayidx8.1 = getelementptr inbounds i16, i16* %1, i32 1 + %5 = load i16, i16* %arrayidx8.1, align 2 + %conv9.1 = sext i16 %5 to i32 + %mul.1 = mul nsw i32 %conv9.1, %conv.1 + %add.1 = add nsw i32 %mul.1, %sub + %arrayidx6.2 = getelementptr inbounds i16, i16* %0, i32 2 + %6 = load i16, i16* %arrayidx6.2, align 2 + %conv.2 = sext i16 %6 to i32 + %arrayidx8.2 = getelementptr inbounds i16, i16* %1, i32 2 + %7 = load i16, i16* %arrayidx8.2, align 2 + %conv9.2 = sext i16 %7 to i32 + %mul.2 = mul nsw i32 %conv9.2, %conv.2 + %add.2 = add nsw i32 %mul.2, %add.1 + %arrayidx6.3 = getelementptr 
inbounds i16, i16* %0, i32 3 + %8 = load i16, i16* %arrayidx6.3, align 2 + %conv.3 = sext i16 %8 to i32 + %arrayidx8.3 = getelementptr inbounds i16, i16* %1, i32 3 + %9 = load i16, i16* %arrayidx8.3, align 2 + %conv9.3 = sext i16 %9 to i32 + %mul.3 = mul nsw i32 %conv9.3, %conv.3 + %add.3 = add nsw i32 %mul.3, %add.2 + store i32 %add.3, i32* %arrayidx, align 4 + %inc12 = add nuw i32 %i.030, 1 + %exitcond = icmp eq i32 %inc12, %N + br i1 %exitcond, label %for.cond.cleanup, label %for.body +} -- 2.40.0