granicus.if.org Git - llvm/commitdiff
Refactor synthetic profile count computation. NFC.
author: Easwaran Raman <eraman@google.com>
Wed, 9 Jan 2019 20:10:27 +0000 (20:10 +0000)
committer: Easwaran Raman <eraman@google.com>
Wed, 9 Jan 2019 20:10:27 +0000 (20:10 +0000)
Summary:
Instead of using two separate callbacks to return the entry count and the
relative block frequency, use a single callback to return the callsite
count. This allows better support for a hybrid mode in the future, since
the count of a callsite need not always be derived from the entry count
(as in sample PGO).

Reviewers: davidxl

Subscribers: mehdi_amini, steven_wu, dexonsmith, dang, llvm-commits

Differential Revision: https://reviews.llvm.org/D56464

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@350755 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/SyntheticCountsUtils.h
lib/Analysis/SyntheticCountsUtils.cpp
lib/LTO/SummaryBasedOptimizations.cpp
lib/Transforms/IPO/SyntheticCountsPropagation.cpp

index 87f4a0100b38901f51f9ee36cbbfe93a7d42dc30..db80bef001e24e4a236cd5220f3096e4cdf17801 100644 (file)
@@ -36,16 +36,17 @@ public:
   using EdgeRef = typename CGT::EdgeRef;
   using SccTy = std::vector<NodeRef>;
 
-  using GetRelBBFreqTy = function_ref<Optional<Scaled64>(EdgeRef)>;
-  using GetCountTy = function_ref<uint64_t(NodeRef)>;
-  using AddCountTy = function_ref<void(NodeRef, uint64_t)>;
+  // Not all EdgeRef have information about the source of the edge. Hence
+  // NodeRef corresponding to the source of the EdgeRef is explicitly passed.
+  using GetProfCountTy = function_ref<Optional<Scaled64>(NodeRef, EdgeRef)>;
+  using AddCountTy = function_ref<void(NodeRef, Scaled64)>;
 
-  static void propagate(const CallGraphType &CG, GetRelBBFreqTy GetRelBBFreq,
-                        GetCountTy GetCount, AddCountTy AddCount);
+  static void propagate(const CallGraphType &CG, GetProfCountTy GetProfCount,
+                        AddCountTy AddCount);
 
 private:
-  static void propagateFromSCC(const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq,
-                               GetCountTy GetCount, AddCountTy AddCount);
+  static void propagateFromSCC(const SccTy &SCC, GetProfCountTy GetProfCount,
+                               AddCountTy AddCount);
 };
 } // namespace llvm
 
index 386396bcff364f3cf0222a3f7994fd4c46023686..c2d7bb11a4cf2213a5848f5229ae09a623955117 100644 (file)
@@ -26,8 +26,7 @@ using namespace llvm;
 // Given an SCC, propagate entry counts along the edge of the SCC nodes.
 template <typename CallGraphType>
 void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
-    const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, GetCountTy GetCount,
-    AddCountTy AddCount) {
+    const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) {
 
   DenseSet<NodeRef> SCCNodes;
   SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
@@ -54,17 +53,13 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
   // This ensures that the order of
   // traversal of nodes within the SCC doesn't affect the final result.
 
-  DenseMap<NodeRef, uint64_t> AdditionalCounts;
+  DenseMap<NodeRef, Scaled64> AdditionalCounts;
   for (auto &E : SCCEdges) {
-    auto OptRelFreq = GetRelBBFreq(E.second);
-    if (!OptRelFreq)
+    auto OptProfCount = GetProfCount(E.first, E.second);
+    if (!OptProfCount)
       continue;
-    Scaled64 RelFreq = OptRelFreq.getValue();
-    auto Caller = E.first;
     auto Callee = CGT::edge_dest(E.second);
-    RelFreq *= Scaled64(GetCount(Caller), 0);
-    uint64_t AdditionalCount = RelFreq.toInt<uint64_t>();
-    AdditionalCounts[Callee] += AdditionalCount;
+    AdditionalCounts[Callee] += OptProfCount.getValue();
   }
 
   // Update the counts for the nodes in the SCC.
@@ -73,14 +68,11 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
 
   // Now update the counts for nodes outside the SCC.
   for (auto &E : NonSCCEdges) {
-    auto OptRelFreq = GetRelBBFreq(E.second);
-    if (!OptRelFreq)
+    auto OptProfCount = GetProfCount(E.first, E.second);
+    if (!OptProfCount)
       continue;
-    Scaled64 RelFreq = OptRelFreq.getValue();
-    auto Caller = E.first;
     auto Callee = CGT::edge_dest(E.second);
-    RelFreq *= Scaled64(GetCount(Caller), 0);
-    AddCount(Callee, RelFreq.toInt<uint64_t>());
+    AddCount(Callee, OptProfCount.getValue());
   }
 }
 
@@ -94,8 +86,7 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
 
 template <typename CallGraphType>
 void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
-                                                    GetRelBBFreqTy GetRelBBFreq,
-                                                    GetCountTy GetCount,
+                                                    GetProfCountTy GetProfCount,
                                                     AddCountTy AddCount) {
   std::vector<SccTy> SCCs;
 
@@ -107,7 +98,7 @@ void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
   // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
   // and call propagateFromSCC.
   for (auto &SCC : reverse(SCCs))
-    propagateFromSCC(SCC, GetRelBBFreq, GetCount, AddCount);
+    propagateFromSCC(SCC, GetProfCount, AddCount);
 }
 
 template class llvm::SyntheticCountsUtils<const CallGraph *>;
index 8b1abb78462b1ab9044bae7e72c8f31a98f4cd67..bcdd984daa58422fc485585a20d6ae7dbe4ee83b 100644 (file)
@@ -60,21 +60,27 @@ void llvm::computeSyntheticCounts(ModuleSummaryIndex &Index) {
       return UINT64_C(0);
     }
   };
-  auto AddToEntryCount = [](ValueInfo V, uint64_t New) {
+  auto AddToEntryCount = [](ValueInfo V, Scaled64 New) {
     if (!V.getSummaryList().size())
       return;
     for (auto &GVS : V.getSummaryList()) {
       auto S = GVS.get()->getBaseObject();
       auto *F = cast<FunctionSummary>(S);
-      F->setEntryCount(SaturatingAdd(F->entryCount(), New));
+      F->setEntryCount(
+          SaturatingAdd(F->entryCount(), New.template toInt<uint64_t>()));
     }
   };
 
+  auto GetProfileCount = [&](ValueInfo V, FunctionSummary::EdgeTy &Edge) {
+    auto RelFreq = GetCallSiteRelFreq(Edge);
+    Scaled64 EC(GetEntryCount(V), 0);
+    return RelFreq * EC;
+  };
   // After initializing the counts in initializeCounts above, the counts have to
   // be propagated across the combined callgraph.
   // SyntheticCountsUtils::propagate takes care of this propagation on any
   // callgraph that specialized GraphTraits.
-  SyntheticCountsUtils<ModuleSummaryIndex *>::propagate(
-      &Index, GetCallSiteRelFreq, GetEntryCount, AddToEntryCount);
+  SyntheticCountsUtils<ModuleSummaryIndex *>::propagate(&Index, GetProfileCount,
+                                                        AddToEntryCount);
   Index.setHasSyntheticEntryCounts();
 }
index 64837d4f5d6196a60c2a2625f35c15e9c7c10c26..ba4efb3ff60d99c05a34cc9e67d9840dff98cd3b 100644 (file)
@@ -30,6 +30,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/CallGraph.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/Analysis/SyntheticCountsUtils.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Function.h"
@@ -98,13 +99,15 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M,
                                                   ModuleAnalysisManager &MAM) {
   FunctionAnalysisManager &FAM =
       MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
-  DenseMap<Function *, uint64_t> Counts;
+  DenseMap<Function *, Scaled64> Counts;
   // Set initial entry counts.
-  initializeCounts(M, [&](Function *F, uint64_t Count) { Counts[F] = Count; });
+  initializeCounts(
+      M, [&](Function *F, uint64_t Count) { Counts[F] = Scaled64(Count, 0); });
 
-  // Compute the relative block frequency for a call edge. Use scaled numbers
-  // and not integers since the relative block frequency could be less than 1.
-  auto GetCallSiteRelFreq = [&](const CallGraphNode::CallRecord &Edge) {
+  // Edge includes information about the source. Hence ignore the first
+  // parameter.
+  auto GetCallSiteProfCount = [&](const CallGraphNode *,
+                                  const CallGraphNode::CallRecord &Edge) {
     Optional<Scaled64> Res = None;
     if (!Edge.first)
       return Res;
@@ -112,29 +115,33 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M,
     CallSite CS(cast<Instruction>(Edge.first));
     Function *Caller = CS.getCaller();
     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*Caller);
+
+    // Now compute the callsite count from relative frequency and
+    // entry count:
     BasicBlock *CSBB = CS.getInstruction()->getParent();
     Scaled64 EntryFreq(BFI.getEntryFreq(), 0);
-    Scaled64 BBFreq(BFI.getBlockFreq(CSBB).getFrequency(), 0);
-    BBFreq /= EntryFreq;
-    return Optional<Scaled64>(BBFreq);
+    Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0);
+    BBCount /= EntryFreq;
+    BBCount *= Counts[Caller];
+    return Optional<Scaled64>(BBCount);
   };
 
   CallGraph CG(M);
   // Propgate the entry counts on the callgraph.
   SyntheticCountsUtils<const CallGraph *>::propagate(
-      &CG, GetCallSiteRelFreq,
-      [&](const CallGraphNode *N) { return Counts[N->getFunction()]; },
-      [&](const CallGraphNode *N, uint64_t New) {
+      &CG, GetCallSiteProfCount, [&](const CallGraphNode *N, Scaled64 New) {
         auto F = N->getFunction();
         if (!F || F->isDeclaration())
           return;
+
         Counts[F] += New;
       });
 
   // Set the counts as metadata.
-  for (auto Entry : Counts)
-    Entry.first->setEntryCount(
-        ProfileCount(Entry.second, Function::PCT_Synthetic));
+  for (auto Entry : Counts) {
+    Entry.first->setEntryCount(ProfileCount(
+        Entry.second.template toInt<uint64_t>(), Function::PCT_Synthetic));
+  }
 
   return PreservedAnalyses::all();
 }