From a06891a3ca56bda101354b9a2a0b10ac18428ddf Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Wed, 25 Sep 2019 14:04:36 +0000 Subject: [PATCH] [Dominators][AMDGPU] Don't use virtual exit node in findNearestCommonDominator. Cleanup MachinePostDominators. Summary: This patch fixes a bug that originated from passing a virtual exit block (nullptr) to `MachinePostDominatorTree::findNearestCommonDominator` and resulted in assertion failures inside its callee. It also applies a small cleanup to the class. The patch introduces a new function in PDT that given a list of `MachineBasicBlock`s finds their NCD. The new overload of `findNearestCommonDominator` handles virtual root correctly. Note that similar handling of virtual root nodes is not necessary in (forward) `DominatorTree`s, as right now they don't use virtual roots. Reviewers: tstellar, tpr, nhaehnle, arsenm, NutshellySima, grosser, hliao Reviewed By: hliao Subscribers: hliao, kzhuravl, jvesely, wdng, yaxunl, dstuttard, t-tye, hiraditya, llvm-commits Tags: #amdgpu, #llvm Differential Revision: https://reviews.llvm.org/D67974 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@372874 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/CodeGen/MachinePostDominators.h | 45 +++++++++++--------- include/llvm/CodeGen/MachineRegionInfo.h | 2 +- lib/CodeGen/MachinePostDominators.cpp | 44 ++++++++++++------- lib/Target/AMDGPU/SILowerI1Copies.cpp | 20 ++++----- 4 files changed, 65 insertions(+), 46 deletions(-) diff --git a/include/llvm/CodeGen/MachinePostDominators.h b/include/llvm/CodeGen/MachinePostDominators.h index b67e6b52ac8..a0c2c78de8d 100644 --- a/include/llvm/CodeGen/MachinePostDominators.h +++ b/include/llvm/CodeGen/MachinePostDominators.h @@ -16,68 +16,75 @@ #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineFunctionPass.h" +#include namespace llvm { /// -/// PostDominatorTree Class - Concrete subclass of DominatorTree that is used -/// to compute the post-dominator tree. 
+/// MachinePostDominatorTree - an analysis pass wrapper for DominatorTree +/// used to compute the post-dominator tree for MachineFunctions. /// -struct MachinePostDominatorTree : public MachineFunctionPass { -private: - PostDomTreeBase *DT; +class MachinePostDominatorTree : public MachineFunctionPass { + using PostDomTreeT = PostDomTreeBase; + std::unique_ptr PDT; public: static char ID; MachinePostDominatorTree(); - ~MachinePostDominatorTree() override; - FunctionPass *createMachinePostDominatorTreePass(); const SmallVectorImpl &getRoots() const { - return DT->getRoots(); + return PDT->getRoots(); } - MachineDomTreeNode *getRootNode() const { - return DT->getRootNode(); - } + MachineDomTreeNode *getRootNode() const { return PDT->getRootNode(); } MachineDomTreeNode *operator[](MachineBasicBlock *BB) const { - return DT->getNode(BB); + return PDT->getNode(BB); } MachineDomTreeNode *getNode(MachineBasicBlock *BB) const { - return DT->getNode(BB); + return PDT->getNode(BB); } bool dominates(const MachineDomTreeNode *A, const MachineDomTreeNode *B) const { - return DT->dominates(A, B); + return PDT->dominates(A, B); } bool dominates(const MachineBasicBlock *A, const MachineBasicBlock *B) const { - return DT->dominates(A, B); + return PDT->dominates(A, B); } bool properlyDominates(const MachineDomTreeNode *A, const MachineDomTreeNode *B) const { - return DT->properlyDominates(A, B); + return PDT->properlyDominates(A, B); } bool properlyDominates(const MachineBasicBlock *A, const MachineBasicBlock *B) const { - return DT->properlyDominates(A, B); + return PDT->properlyDominates(A, B); + } + + bool isVirtualRoot(const MachineDomTreeNode *Node) const { + return PDT->isVirtualRoot(Node); } MachineBasicBlock *findNearestCommonDominator(MachineBasicBlock *A, - MachineBasicBlock *B) { - return DT->findNearestCommonDominator(A, B); + MachineBasicBlock *B) const { + return PDT->findNearestCommonDominator(A, B); } + /// Returns the nearest common dominator of the given blocks. 
+ /// If that tree node is a virtual root, a nullptr will be returned. + MachineBasicBlock * + findNearestCommonDominator(ArrayRef Blocks) const; + bool runOnMachineFunction(MachineFunction &MF) override; void getAnalysisUsage(AnalysisUsage &AU) const override; + void releaseMemory() override { PDT.reset(nullptr); } void print(llvm::raw_ostream &OS, const Module *M = nullptr) const override; }; } //end of namespace llvm diff --git a/include/llvm/CodeGen/MachineRegionInfo.h b/include/llvm/CodeGen/MachineRegionInfo.h index 6d9fb9b9100..eeb69fef2c6 100644 --- a/include/llvm/CodeGen/MachineRegionInfo.h +++ b/include/llvm/CodeGen/MachineRegionInfo.h @@ -22,7 +22,7 @@ namespace llvm { -struct MachinePostDominatorTree; +class MachinePostDominatorTree; class MachineRegion; class MachineRegionNode; class MachineRegionInfo; diff --git a/lib/CodeGen/MachinePostDominators.cpp b/lib/CodeGen/MachinePostDominators.cpp index 7f220ed1fd8..f2fc9f814f8 100644 --- a/lib/CodeGen/MachinePostDominators.cpp +++ b/lib/CodeGen/MachinePostDominators.cpp @@ -13,6 +13,8 @@ #include "llvm/CodeGen/MachinePostDominators.h" +#include "llvm/ADT/STLExtras.h" + using namespace llvm; namespace llvm { @@ -25,33 +27,43 @@ char MachinePostDominatorTree::ID = 0; INITIALIZE_PASS(MachinePostDominatorTree, "machinepostdomtree", "MachinePostDominator Tree Construction", true, true) -MachinePostDominatorTree::MachinePostDominatorTree() : MachineFunctionPass(ID) { +MachinePostDominatorTree::MachinePostDominatorTree() + : MachineFunctionPass(ID), PDT(nullptr) { initializeMachinePostDominatorTreePass(*PassRegistry::getPassRegistry()); - DT = new PostDomTreeBase(); } -FunctionPass * -MachinePostDominatorTree::createMachinePostDominatorTreePass() { +FunctionPass *MachinePostDominatorTree::createMachinePostDominatorTreePass() { return new MachinePostDominatorTree(); } -bool -MachinePostDominatorTree::runOnMachineFunction(MachineFunction &F) { - DT->recalculate(F); +bool 
MachinePostDominatorTree::runOnMachineFunction(MachineFunction &F) { + PDT = std::make_unique(); + PDT->recalculate(F); return false; } -MachinePostDominatorTree::~MachinePostDominatorTree() { - delete DT; -} - -void -MachinePostDominatorTree::getAnalysisUsage(AnalysisUsage &AU) const { +void MachinePostDominatorTree::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); MachineFunctionPass::getAnalysisUsage(AU); } -void -MachinePostDominatorTree::print(llvm::raw_ostream &OS, const Module *M) const { - DT->print(OS); +MachineBasicBlock *MachinePostDominatorTree::findNearestCommonDominator( + ArrayRef Blocks) const { + assert(!Blocks.empty()); + + MachineBasicBlock *NCD = Blocks.front(); + for (MachineBasicBlock *BB : Blocks.drop_front()) { + NCD = PDT->findNearestCommonDominator(NCD, BB); + + // Stop when the root is reached. + if (PDT->isVirtualRoot(PDT->getNode(NCD))) + return nullptr; + } + + return NCD; +} + +void MachinePostDominatorTree::print(llvm::raw_ostream &OS, + const Module *M) const { + PDT->print(OS); } diff --git a/lib/Target/AMDGPU/SILowerI1Copies.cpp b/lib/Target/AMDGPU/SILowerI1Copies.cpp index 86242f1fb64..b4541253635 100644 --- a/lib/Target/AMDGPU/SILowerI1Copies.cpp +++ b/lib/Target/AMDGPU/SILowerI1Copies.cpp @@ -589,12 +589,12 @@ void SILowerI1Copies::lowerPhis() { // Phis in a loop that are observed outside the loop receive a simple but // conservatively correct treatment. 
- MachineBasicBlock *PostDomBound = &MBB; - for (MachineInstr &Use : MRI->use_instructions(DstReg)) { - PostDomBound = - PDT->findNearestCommonDominator(PostDomBound, Use.getParent()); - } + std::vector DomBlocks = {&MBB}; + for (MachineInstr &Use : MRI->use_instructions(DstReg)) + DomBlocks.push_back(Use.getParent()); + MachineBasicBlock *PostDomBound = + PDT->findNearestCommonDominator(DomBlocks); unsigned FoundLoopLevel = LF.findLoop(PostDomBound); SSAUpdater.Initialize(DstReg); @@ -711,12 +711,12 @@ void SILowerI1Copies::lowerCopiesToI1() { // Defs in a loop that are observed outside the loop must be transformed // into appropriate bit manipulation. - MachineBasicBlock *PostDomBound = &MBB; - for (MachineInstr &Use : MRI->use_instructions(DstReg)) { - PostDomBound = - PDT->findNearestCommonDominator(PostDomBound, Use.getParent()); - } + std::vector DomBlocks = {&MBB}; + for (MachineInstr &Use : MRI->use_instructions(DstReg)) + DomBlocks.push_back(Use.getParent()); + MachineBasicBlock *PostDomBound = + PDT->findNearestCommonDominator(DomBlocks); unsigned FoundLoopLevel = LF.findLoop(PostDomBound); if (FoundLoopLevel) { SSAUpdater.Initialize(DstReg); -- 2.40.0