From a00a56b905a1e6a1102f21fad0b4ca71e647d8b5 Mon Sep 17 00:00:00 2001 From: Bob Wilson Date: Fri, 11 Apr 2014 17:16:13 +0000 Subject: [PATCH] [PGO] Change MapRegionCounters to be a RecursiveASTVisitor. This avoids the overhead of specifying all the traversal code when using ConstStmtVisitor and makes it a lot easier to maintain this. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@206039 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/CodeGen/CodeGenPGO.cpp | 207 ++++++++++++++----------------------- 1 file changed, 75 insertions(+), 132 deletions(-) diff --git a/lib/CodeGen/CodeGenPGO.cpp b/lib/CodeGen/CodeGenPGO.cpp index ea069cebfe..37ba0fb885 100644 --- a/lib/CodeGen/CodeGenPGO.cpp +++ b/lib/CodeGen/CodeGenPGO.cpp @@ -321,8 +321,8 @@ llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) { } namespace { - /// A StmtVisitor that fills a map of statements to PGO counters. - struct MapRegionCounters : public ConstStmtVisitor { + /// A RecursiveASTVisitor that fills a map of statements to PGO counters. + struct MapRegionCounters : public RecursiveASTVisitor { /// The next counter value to assign. unsigned NextCounter; /// The map of statements to counters. @@ -331,135 +331,55 @@ namespace { MapRegionCounters(llvm::DenseMap &CounterMap) : NextCounter(0), CounterMap(CounterMap) {} - void VisitChildren(const Stmt *S) { - for (Stmt::const_child_range I = S->children(); I; ++I) - if (*I) - this->Visit(*I); + // Do not traverse the BlockDecl inside a BlockExpr since each BlockDecl + // is handled as a separate function. + bool TraverseBlockExpr(BlockExpr *block) { return true; } + + bool VisitDecl(const Decl *D) { + switch (D->getKind()) { + default: + break; + case Decl::Function: + case Decl::CXXMethod: + case Decl::CXXConstructor: + case Decl::CXXDestructor: + case Decl::CXXConversion: + case Decl::ObjCMethod: + case Decl::Block: + CounterMap[D->getBody()] = NextCounter++; + break; + } + return true; } - void VisitStmt(const Stmt *S) { VisitChildren(S); } - /// Assign a counter to track entry to the function body. - void VisitFunctionDecl(const FunctionDecl *D) { - CounterMap[D->getBody()] = NextCounter++; - Visit(D->getBody()); - } - void VisitObjCMethodDecl(const ObjCMethodDecl *D) { - CounterMap[D->getBody()] = NextCounter++; - Visit(D->getBody()); - } - void VisitBlockDecl(const BlockDecl *D) { - CounterMap[D->getBody()] = NextCounter++; - Visit(D->getBody()); - } - /// Assign a counter to track the block following a label. - void VisitLabelStmt(const LabelStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getSubStmt()); - } - /// Assign a counter for the body of a while loop. - void VisitWhileStmt(const WhileStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getCond()); - Visit(S->getBody()); - } - /// Assign a counter for the body of a do-while loop. - void VisitDoStmt(const DoStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getBody()); - Visit(S->getCond()); - } - /// Assign a counter for the body of a for loop. - void VisitForStmt(const ForStmt *S) { - CounterMap[S] = NextCounter++; - if (S->getInit()) - Visit(S->getInit()); - const Expr *E; - if ((E = S->getCond())) - Visit(E); - if ((E = S->getInc())) - Visit(E); - Visit(S->getBody()); - } - /// Assign a counter for the body of a for-range loop. - void VisitCXXForRangeStmt(const CXXForRangeStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getRangeStmt()); - Visit(S->getBeginEndStmt()); - Visit(S->getCond()); - Visit(S->getLoopVarStmt()); - Visit(S->getBody()); - Visit(S->getInc()); - } - /// Assign a counter for the body of a for-collection loop. - void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getElement()); - Visit(S->getBody()); - } - /// Assign a counter for the exit block of the switch statement. - void VisitSwitchStmt(const SwitchStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getCond()); - Visit(S->getBody()); - } - /// Assign a counter for a particular case in a switch. This counts jumps - /// from the switch header as well as fallthrough from the case before this - /// one. - void VisitCaseStmt(const CaseStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getSubStmt()); - } - /// Assign a counter for the default case of a switch statement. The count - /// is the number of branches from the loop header to the default, and does - /// not include fallthrough from previous cases. If we have multiple - /// conditional branch blocks from the switch instruction to the default - /// block, as with large GNU case ranges, this is the counter for the last - /// edge in that series, rather than the first. - void VisitDefaultStmt(const DefaultStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getSubStmt()); - } - /// Assign a counter for the "then" part of an if statement. The count for - /// the "else" part, if it exists, will be calculated from this counter. - void VisitIfStmt(const IfStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getCond()); - Visit(S->getThen()); - if (S->getElse()) - Visit(S->getElse()); - } - /// Assign a counter for the continuation block of a C++ try statement. - void VisitCXXTryStmt(const CXXTryStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getTryBlock()); - for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) - Visit(S->getHandler(I)); - } - /// Assign a counter for a catch statement's handler block. - void VisitCXXCatchStmt(const CXXCatchStmt *S) { - CounterMap[S] = NextCounter++; - Visit(S->getHandlerBlock()); - } - /// Assign a counter for the "true" part of a conditional operator. The - /// count in the "false" part will be calculated from this counter. - void VisitAbstractConditionalOperator( - const AbstractConditionalOperator *E) { - CounterMap[E] = NextCounter++; - Visit(E->getCond()); - Visit(E->getTrueExpr()); - Visit(E->getFalseExpr()); - } - /// Assign a counter for the right hand side of a logical and operator. - void VisitBinLAnd(const BinaryOperator *E) { - CounterMap[E] = NextCounter++; - Visit(E->getLHS()); - Visit(E->getRHS()); - } - /// Assign a counter for the right hand side of a logical or operator. - void VisitBinLOr(const BinaryOperator *E) { - CounterMap[E] = NextCounter++; - Visit(E->getLHS()); - Visit(E->getRHS()); + bool VisitStmt(const Stmt *S) { + switch (S->getStmtClass()) { + default: + break; + case Stmt::LabelStmtClass: + case Stmt::WhileStmtClass: + case Stmt::DoStmtClass: + case Stmt::ForStmtClass: + case Stmt::CXXForRangeStmtClass: + case Stmt::ObjCForCollectionStmtClass: + case Stmt::SwitchStmtClass: + case Stmt::CaseStmtClass: + case Stmt::DefaultStmtClass: + case Stmt::IfStmtClass: + case Stmt::CXXTryStmtClass: + case Stmt::CXXCatchStmtClass: + case Stmt::ConditionalOperatorClass: + case Stmt::BinaryConditionalOperatorClass: + CounterMap[S] = NextCounter++; + break; + case Stmt::BinaryOperatorClass: { + const BinaryOperator *BO = cast(S); + if (BO->getOpcode() == BO_LAnd || BO->getOpcode() == BO_LOr) + CounterMap[S] = NextCounter++; + break; + } + } + return true; } }; @@ -504,6 +424,7 @@ namespace { } void VisitFunctionDecl(const FunctionDecl *D) { + // Counter tracks entry to the function body. RegionCounter Cnt(PGO, D->getBody()); Cnt.beginRegion(); CountMap[D->getBody()] = PGO.getCurrentRegionCount(); @@ -511,6 +432,7 @@ namespace { } void VisitObjCMethodDecl(const ObjCMethodDecl *D) { + // Counter tracks entry to the method body. RegionCounter Cnt(PGO, D->getBody()); Cnt.beginRegion(); CountMap[D->getBody()] = PGO.getCurrentRegionCount(); @@ -518,6 +440,7 @@ namespace { } void VisitBlockDecl(const BlockDecl *D) { + // Counter tracks entry to the block body. RegionCounter Cnt(PGO, D->getBody()); Cnt.beginRegion(); CountMap[D->getBody()] = PGO.getCurrentRegionCount(); @@ -540,6 +463,7 @@ namespace { void VisitLabelStmt(const LabelStmt *S) { RecordNextStmtCount = false; + // Counter tracks the block following the label. RegionCounter Cnt(PGO, S); Cnt.beginRegion(); CountMap[S] = PGO.getCurrentRegionCount(); @@ -564,6 +488,7 @@ namespace { void VisitWhileStmt(const WhileStmt *S) { RecordStmtCount(S); + // Counter tracks the body of the loop. RegionCounter Cnt(PGO, S); BreakContinueStack.push_back(BreakContinue()); // Visit the body region first so the break/continue adjustments can be @@ -589,6 +514,7 @@ namespace { void VisitDoStmt(const DoStmt *S) { RecordStmtCount(S); + // Counter tracks the body of the loop. RegionCounter Cnt(PGO, S); BreakContinueStack.push_back(BreakContinue()); Cnt.beginRegion(/*AddIncomingFallThrough=*/true); @@ -615,6 +541,7 @@ namespace { RecordStmtCount(S); if (S->getInit()) Visit(S->getInit()); + // Counter tracks the body of the loop. RegionCounter Cnt(PGO, S); BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while @@ -653,6 +580,7 @@ namespace { RecordStmtCount(S); Visit(S->getRangeStmt()); Visit(S->getBeginEndStmt()); + // Counter tracks the body of the loop. RegionCounter Cnt(PGO, S); BreakContinueStack.push_back(BreakContinue()); // Visit the body region first. (This is basically the same as a while @@ -687,6 +615,7 @@ namespace { void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) { RecordStmtCount(S); Visit(S->getElement()); + // Counter tracks the body of the loop. RegionCounter Cnt(PGO, S); BreakContinueStack.push_back(BreakContinue()); Cnt.beginRegion(); @@ -708,6 +637,7 @@ namespace { BreakContinue BC = BreakContinueStack.pop_back_val(); if (!BreakContinueStack.empty()) BreakContinueStack.back().ContinueCount += BC.ContinueCount; + // Counter tracks the exit block of the switch. RegionCounter ExitCnt(PGO, S); ExitCnt.beginRegion(); RecordNextStmtCount = true; @@ -715,6 +645,9 @@ namespace { void VisitCaseStmt(const CaseStmt *S) { RecordNextStmtCount = false; + // Counter for this particular case. This counts only jumps from the + // switch header and does not include fallthrough from the case before + // this one. RegionCounter Cnt(PGO, S); Cnt.beginRegion(/*AddIncomingFallThrough=*/true); CountMap[S] = Cnt.getCount(); @@ -724,6 +657,8 @@ namespace { void VisitDefaultStmt(const DefaultStmt *S) { RecordNextStmtCount = false; + // Counter for this default case. This does not include fallthrough from + // the previous case. RegionCounter Cnt(PGO, S); Cnt.beginRegion(/*AddIncomingFallThrough=*/true); CountMap[S] = Cnt.getCount(); @@ -733,6 +668,8 @@ namespace { void VisitIfStmt(const IfStmt *S) { RecordStmtCount(S); + // Counter tracks the "then" part of an if statement. The count for + // the "else" part, if it exists, will be calculated from this counter. RegionCounter Cnt(PGO, S); Visit(S->getCond()); @@ -756,6 +693,7 @@ namespace { Visit(S->getTryBlock()); for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I) Visit(S->getHandler(I)); + // Counter tracks the continuation block of the try statement. RegionCounter Cnt(PGO, S); Cnt.beginRegion(); RecordNextStmtCount = true; @@ -763,6 +701,7 @@ namespace { void VisitCXXCatchStmt(const CXXCatchStmt *S) { RecordNextStmtCount = false; + // Counter tracks the catch statement's handler block. RegionCounter Cnt(PGO, S); Cnt.beginRegion(); CountMap[S] = PGO.getCurrentRegionCount(); @@ -772,6 +711,8 @@ namespace { void VisitAbstractConditionalOperator( const AbstractConditionalOperator *E) { RecordStmtCount(E); + // Counter tracks the "true" part of a conditional operator. The + // count in the "false" part will be calculated from this counter. RegionCounter Cnt(PGO, E); Visit(E->getCond()); @@ -791,6 +732,7 @@ namespace { void VisitBinLAnd(const BinaryOperator *E) { RecordStmtCount(E); + // Counter tracks the right hand side of a logical and operator. RegionCounter Cnt(PGO, E); Visit(E->getLHS()); Cnt.beginRegion(); @@ -803,6 +745,7 @@ namespace { void VisitBinLOr(const BinaryOperator *E) { RecordStmtCount(E); + // Counter tracks the right hand side of a logical or operator. RegionCounter Cnt(PGO, E); Visit(E->getLHS()); Cnt.beginRegion(); @@ -884,11 +827,11 @@ void CodeGenPGO::mapRegionCounters(const Decl *D) { RegionCounterMap.reset(new llvm::DenseMap); MapRegionCounters Walker(*RegionCounterMap); if (const FunctionDecl *FD = dyn_cast_or_null(D)) - Walker.VisitFunctionDecl(FD); + Walker.TraverseDecl(const_cast(FD)); else if (const ObjCMethodDecl *MD = dyn_cast_or_null(D)) - Walker.VisitObjCMethodDecl(MD); + Walker.TraverseDecl(const_cast(MD)); else if (const BlockDecl *BD = dyn_cast_or_null(D)) - Walker.VisitBlockDecl(BD); + Walker.TraverseDecl(const_cast(BD)); NumRegionCounters = Walker.NextCounter; // FIXME: The number of counters isn't sufficient for the hash FunctionHash = NumRegionCounters; -- 2.40.0