namespace llvm {
-template <class LatticeVal> class SparseSolver;
+/// A template for translating between LLVM Values and LatticeKeys. Clients must
+/// provide a specialization of LatticeKeyInfo for their LatticeKey type.
+template <class LatticeKey> struct LatticeKeyInfo {
+ // static inline Value *getValueFromLatticeKey(LatticeKey Key);
+ // static inline LatticeKey getLatticeKeyFromValue(Value *V);
+};
+
+template <class LatticeKey, class LatticeVal,
+ class KeyInfo = LatticeKeyInfo<LatticeKey>>
+class SparseSolver;
/// AbstractLatticeFunction - This class is implemented by the dataflow instance
/// to specify what the lattice values are and how they handle merges etc. This
/// gives the client the power to compute lattice values from instructions,
/// constants, etc. The current requirement is that lattice values must be
-/// copyable. At the moment, nothing tries to avoid copying.
-
-
-template <class LatticeVal> class AbstractLatticeFunction {
+/// copyable. At the moment, nothing tries to avoid copying. Additionally,
+/// lattice keys must be able to be used as keys of a mapping data structure.
+/// Internally, the generic solver currently uses a DenseMap to map lattice keys
+/// to lattice values. If the lattice key is a non-standard type, a
+/// specialization of DenseMapInfo must be provided.
+template <class LatticeKey, class LatticeVal> class AbstractLatticeFunction {
private:
LatticeVal UndefVal, OverdefinedVal, UntrackedVal;
LatticeVal getOverdefinedVal() const { return OverdefinedVal; }
LatticeVal getUntrackedVal() const { return UntrackedVal; }
- /// IsUntrackedValue - If the specified Value is something that is obviously
- /// uninteresting to the analysis (and would always return UntrackedVal),
- /// this function can return true to avoid pointless work.
- virtual bool IsUntrackedValue(Value *V) { return false; }
+ /// IsUntrackedValue - If the specified LatticeKey is obviously uninteresting
+ /// to the analysis (i.e., it would always return UntrackedVal), this
+ /// function can return true to avoid pointless work.
+ virtual bool IsUntrackedValue(LatticeKey Key) { return false; }
- /// ComputeConstant - Given a constant value, compute and return a lattice
- /// value corresponding to the specified constant.
- virtual LatticeVal ComputeConstant(Constant *C) {
- return getOverdefinedVal(); // always safe
+ /// ComputeLatticeVal - Compute and return a LatticeVal corresponding to the
+ /// given LatticeKey.
+ virtual LatticeVal ComputeLatticeVal(LatticeKey Key) {
+ return getOverdefinedVal();
}
/// IsSpecialCasedPHI - Given a PHI node, determine whether this PHI node is
/// one that the we want to handle through ComputeInstructionState.
virtual bool IsSpecialCasedPHI(PHINode *PN) { return false; }
- /// GetConstant - If the specified lattice value is representable as an LLVM
- /// constant value, return it. Otherwise return null. The returned value
- /// must be in the same LLVM type as Val.
- virtual Constant *GetConstant(LatticeVal LV, Value *Val,
- SparseSolver<LatticeVal> &SS) {
- return nullptr;
- }
-
- /// ComputeArgument - Given a formal argument value, compute and return a
- /// lattice value corresponding to the specified argument.
- virtual LatticeVal ComputeArgument(Argument *I) {
- return getOverdefinedVal(); // always safe
- }
-
/// MergeValues - Compute and return the merge of the two specified lattice
/// values. Merging should only move one direction down the lattice to
/// guarantee convergence (toward overdefined).
return getOverdefinedVal(); // always safe, never useful.
}
- /// ComputeInstructionState - Given an instruction and a vector of its operand
- /// values, compute the result value of the instruction.
- virtual LatticeVal ComputeInstructionState(Instruction &I,
- SparseSolver<LatticeVal> &SS) {
- return getOverdefinedVal(); // always safe, never useful.
+ /// ComputeInstructionState - Compute the LatticeKeys that change as a result
+ /// of executing instruction \p I. Their associated LatticeVals are store in
+ /// \p ChangedValues.
+ virtual void
+ ComputeInstructionState(Instruction &I,
+ DenseMap<LatticeKey, LatticeVal> &ChangedValues,
+ SparseSolver<LatticeKey, LatticeVal> &SS) = 0;
+
+ /// PrintLatticeVal - Render the given LatticeVal to the specified stream.
+ virtual void PrintLatticeVal(LatticeVal LV, raw_ostream &OS);
+
+ /// PrintLatticeKey - Render the given LatticeKey to the specified stream.
+ virtual void PrintLatticeKey(LatticeKey Key, raw_ostream &OS);
+
+ /// GetValueFromLatticeVal - If the given LatticeVal is representable as an
+ /// LLVM value, return it; otherwise, return nullptr. If a type is given, the
+ /// returned value must have the same type. This function is used by the
+ /// generic solver in attempting to resolve branch and switch conditions.
+ virtual Value *GetValueFromLatticeVal(LatticeVal LV, Type *Ty = nullptr) {
+ return nullptr;
}
-
- /// PrintValue - Render the specified lattice value to the specified stream.
- virtual void PrintValue(LatticeVal V, raw_ostream &OS);
};
/// SparseSolver - This class is a general purpose solver for Sparse Conditional
/// Propagation with a programmable lattice function.
-template <class LatticeVal> class SparseSolver {
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+class SparseSolver {
/// LatticeFunc - This is the object that knows the lattice and how to
/// compute transfer functions.
- AbstractLatticeFunction<LatticeVal> *LatticeFunc;
+ AbstractLatticeFunction<LatticeKey, LatticeVal> *LatticeFunc;
- /// ValueState - Holds the lattice state associated with LLVM values.
- DenseMap<Value *, LatticeVal> ValueState;
+ /// ValueState - Holds the LatticeVals associated with LatticeKeys.
+ DenseMap<LatticeKey, LatticeVal> ValueState;
/// BBExecutable - Holds the basic blocks that are executable.
SmallPtrSet<BasicBlock *, 16> BBExecutable;
std::set<Edge> KnownFeasibleEdges;
public:
- explicit SparseSolver(AbstractLatticeFunction<LatticeVal> *Lattice)
+ explicit SparseSolver(
+ AbstractLatticeFunction<LatticeKey, LatticeVal> *Lattice)
: LatticeFunc(Lattice) {}
SparseSolver(const SparseSolver &) = delete;
SparseSolver &operator=(const SparseSolver &) = delete;
/// Solve - Solve for constants and executable blocks.
- void Solve(Function &F);
+ void Solve();
- void Print(Function &F, raw_ostream &OS) const;
+ void Print(raw_ostream &OS) const;
/// getExistingValueState - Return the LatticeVal object corresponding to the
/// given value from the ValueState map. If the value is not in the map,
/// UntrackedVal is returned, unlike the getValueState method.
- LatticeVal getExistingValueState(Value *V) const {
- auto I = ValueState.find(V);
+ LatticeVal getExistingValueState(LatticeKey Key) const {
+ auto I = ValueState.find(Key);
return I != ValueState.end() ? I->second : LatticeFunc->getUntrackedVal();
}
/// getValueState - Return the LatticeVal object corresponding to the given
/// value from the ValueState map. If the value is not in the map, its state
/// is initialized.
- LatticeVal getValueState(Value *V);
+ LatticeVal getValueState(LatticeKey Key);
/// isEdgeFeasible - Return true if the control flow edge from the 'From'
/// basic block to the 'To' basic block is currently feasible. If
return BBExecutable.count(BB);
}
-private:
- /// UpdateState - When the state for some instruction is potentially updated,
- /// this function notices and adds I to the worklist if needed.
- void UpdateState(Instruction &Inst, LatticeVal V);
-
/// MarkBlockExecutable - This method can be used by clients to mark all of
/// the blocks that are known to be intrinsically live in the processed unit.
void MarkBlockExecutable(BasicBlock *BB);
+private:
+ /// UpdateState - When the state of some LatticeKey is potentially updated to
+ /// the given LatticeVal, this function notices and adds the LLVM value
+ /// corresponding the key to the work list, if needed.
+ void UpdateState(LatticeKey Key, LatticeVal LV);
+
/// markEdgeExecutable - Mark a basic block as executable, adding it to the BB
/// work list if it is not already executable.
void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest);
// AbstractLatticeFunction Implementation
//===----------------------------------------------------------------------===//
-template <class LatticeVal>
-void AbstractLatticeFunction<LatticeVal>::PrintValue(LatticeVal V,
- raw_ostream &OS) {
+template <class LatticeKey, class LatticeVal>
+void AbstractLatticeFunction<LatticeKey, LatticeVal>::PrintLatticeVal(
+ LatticeVal V, raw_ostream &OS) {
if (V == UndefVal)
OS << "undefined";
else if (V == OverdefinedVal)
OS << "unknown lattice value";
}
+template <class LatticeKey, class LatticeVal>
+void AbstractLatticeFunction<LatticeKey, LatticeVal>::PrintLatticeKey(
+ LatticeKey Key, raw_ostream &OS) {
+ OS << "unknown lattice key";
+}
+
//===----------------------------------------------------------------------===//
// SparseSolver Implementation
//===----------------------------------------------------------------------===//
-template <class LatticeVal>
-LatticeVal SparseSolver<LatticeVal>::getValueState(Value *V) {
- auto I = ValueState.find(V);
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+LatticeVal
+SparseSolver<LatticeKey, LatticeVal, KeyInfo>::getValueState(LatticeKey Key) {
+ auto I = ValueState.find(Key);
if (I != ValueState.end())
return I->second; // Common case, in the map
- LatticeVal LV;
- if (LatticeFunc->IsUntrackedValue(V))
+ if (LatticeFunc->IsUntrackedValue(Key))
return LatticeFunc->getUntrackedVal();
- else if (Constant *C = dyn_cast<Constant>(V))
- LV = LatticeFunc->ComputeConstant(C);
- else if (Argument *A = dyn_cast<Argument>(V))
- LV = LatticeFunc->ComputeArgument(A);
- else if (!isa<Instruction>(V))
- // All other non-instructions are overdefined.
- LV = LatticeFunc->getOverdefinedVal();
- else
- // All instructions are underdefined by default.
- LV = LatticeFunc->getUndefVal();
+ LatticeVal LV = LatticeFunc->ComputeLatticeVal(Key);
// If this value is untracked, don't add it to the map.
if (LV == LatticeFunc->getUntrackedVal())
return LV;
- return ValueState[V] = LV;
+ return ValueState[Key] = LV;
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::UpdateState(Instruction &Inst, LatticeVal V) {
- auto I = ValueState.find(&Inst);
- if (I != ValueState.end() && I->second == V)
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::UpdateState(LatticeKey Key,
+ LatticeVal LV) {
+ auto I = ValueState.find(Key);
+ if (I != ValueState.end() && I->second == LV)
return; // No change.
- // An update. Visit uses of I.
- ValueState[&Inst] = V;
- ValueWorkList.push_back(&Inst);
+ // Update the state of the given LatticeKey and add its corresponding LLVM
+ // value to the work list.
+ ValueState[Key] = LV;
+ if (Value *V = KeyInfo::getValueFromLatticeKey(Key))
+ ValueWorkList.push_back(V);
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::MarkBlockExecutable(BasicBlock *BB) {
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::MarkBlockExecutable(
+ BasicBlock *BB) {
+ if (!BBExecutable.insert(BB).second)
+ return;
DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << "\n");
- BBExecutable.insert(BB); // Basic block is executable!
BBWorkList.push_back(BB); // Add the block to the work list!
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::markEdgeExecutable(BasicBlock *Source,
- BasicBlock *Dest) {
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::markEdgeExecutable(
+ BasicBlock *Source, BasicBlock *Dest) {
if (!KnownFeasibleEdges.insert(Edge(Source, Dest)).second)
return; // This edge is already known to be executable!
}
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::getFeasibleSuccessors(
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::getFeasibleSuccessors(
TerminatorInst &TI, SmallVectorImpl<bool> &Succs, bool AggressiveUndef) {
Succs.resize(TI.getNumSuccessors());
if (TI.getNumSuccessors() == 0)
LatticeVal BCValue;
if (AggressiveUndef)
- BCValue = getValueState(BI->getCondition());
+ BCValue =
+ getValueState(KeyInfo::getLatticeKeyFromValue(BI->getCondition()));
else
- BCValue = getExistingValueState(BI->getCondition());
+ BCValue = getExistingValueState(
+ KeyInfo::getLatticeKeyFromValue(BI->getCondition()));
if (BCValue == LatticeFunc->getOverdefinedVal() ||
BCValue == LatticeFunc->getUntrackedVal()) {
if (BCValue == LatticeFunc->getUndefVal())
return;
- Constant *C = LatticeFunc->GetConstant(BCValue, BI->getCondition(), *this);
+ Constant *C =
+ dyn_cast_or_null<Constant>(LatticeFunc->GetValueFromLatticeVal(
+ BCValue, BI->getCondition()->getType()));
if (!C || !isa<ConstantInt>(C)) {
// Non-constant values can go either way.
Succs[0] = Succs[1] = true;
return;
}
- if (isa<InvokeInst>(TI)) {
- // Invoke instructions successors are always executable.
- // TODO: Could ask the lattice function if the value can throw.
- Succs[0] = Succs[1] = true;
+ if (TI.isExceptional()) {
+ Succs.assign(Succs.size(), true);
return;
}
SwitchInst &SI = cast<SwitchInst>(TI);
LatticeVal SCValue;
if (AggressiveUndef)
- SCValue = getValueState(SI.getCondition());
+ SCValue = getValueState(KeyInfo::getLatticeKeyFromValue(SI.getCondition()));
else
- SCValue = getExistingValueState(SI.getCondition());
+ SCValue = getExistingValueState(
+ KeyInfo::getLatticeKeyFromValue(SI.getCondition()));
if (SCValue == LatticeFunc->getOverdefinedVal() ||
SCValue == LatticeFunc->getUntrackedVal()) {
if (SCValue == LatticeFunc->getUndefVal())
return;
- Constant *C = LatticeFunc->GetConstant(SCValue, SI.getCondition(), *this);
+ Constant *C = dyn_cast_or_null<Constant>(LatticeFunc->GetValueFromLatticeVal(
+ SCValue, SI.getCondition()->getType()));
if (!C || !isa<ConstantInt>(C)) {
// All destinations are executable!
Succs.assign(TI.getNumSuccessors(), true);
Succs[Case.getSuccessorIndex()] = true;
}
-template <class LatticeVal>
-bool SparseSolver<LatticeVal>::isEdgeFeasible(BasicBlock *From, BasicBlock *To,
- bool AggressiveUndef) {
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+bool SparseSolver<LatticeKey, LatticeVal, KeyInfo>::isEdgeFeasible(
+ BasicBlock *From, BasicBlock *To, bool AggressiveUndef) {
SmallVector<bool, 16> SuccFeasible;
TerminatorInst *TI = From->getTerminator();
getFeasibleSuccessors(*TI, SuccFeasible, AggressiveUndef);
return false;
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::visitTerminatorInst(TerminatorInst &TI) {
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::visitTerminatorInst(
+ TerminatorInst &TI) {
SmallVector<bool, 16> SuccFeasible;
getFeasibleSuccessors(TI, SuccFeasible, true);
markEdgeExecutable(BB, TI.getSuccessor(i));
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::visitPHINode(PHINode &PN) {
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::visitPHINode(PHINode &PN) {
// The lattice function may store more information on a PHINode than could be
// computed from its incoming values. For example, SSI form stores its sigma
// functions as PHINodes with a single incoming value.
if (LatticeFunc->IsSpecialCasedPHI(&PN)) {
- LatticeVal IV = LatticeFunc->ComputeInstructionState(PN, *this);
- if (IV != LatticeFunc->getUntrackedVal())
- UpdateState(PN, IV);
+ DenseMap<LatticeKey, LatticeVal> ChangedValues;
+ LatticeFunc->ComputeInstructionState(PN, ChangedValues, *this);
+ for (auto &ChangedValue : ChangedValues)
+ if (ChangedValue.second != LatticeFunc->getUntrackedVal())
+ UpdateState(ChangedValue.first, ChangedValue.second);
return;
}
- LatticeVal PNIV = getValueState(&PN);
+ LatticeKey Key = KeyInfo::getLatticeKeyFromValue(&PN);
+ LatticeVal PNIV = getValueState(Key);
LatticeVal Overdefined = LatticeFunc->getOverdefinedVal();
// If this value is already overdefined (common) just return.
// Super-extra-high-degree PHI nodes are unlikely to ever be interesting,
// and slow us down a lot. Just mark them overdefined.
if (PN.getNumIncomingValues() > 64) {
- UpdateState(PN, Overdefined);
+ UpdateState(Key, Overdefined);
return;
}
continue;
// Merge in this value.
- LatticeVal OpVal = getValueState(PN.getIncomingValue(i));
+ LatticeVal OpVal =
+ getValueState(KeyInfo::getLatticeKeyFromValue(PN.getIncomingValue(i)));
if (OpVal != PNIV)
PNIV = LatticeFunc->MergeValues(PNIV, OpVal);
}
// Update the PHI with the compute value, which is the merge of the inputs.
- UpdateState(PN, PNIV);
+ UpdateState(Key, PNIV);
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::visitInst(Instruction &I) {
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::visitInst(Instruction &I) {
// PHIs are handled by the propagation logic, they are never passed into the
// transfer functions.
if (PHINode *PN = dyn_cast<PHINode>(&I))
// Otherwise, ask the transfer function what the result is. If this is
// something that we care about, remember it.
- LatticeVal IV = LatticeFunc->ComputeInstructionState(I, *this);
- if (IV != LatticeFunc->getUntrackedVal())
- UpdateState(I, IV);
+ DenseMap<LatticeKey, LatticeVal> ChangedValues;
+ LatticeFunc->ComputeInstructionState(I, ChangedValues, *this);
+ for (auto &ChangedValue : ChangedValues)
+ if (ChangedValue.second != LatticeFunc->getUntrackedVal())
+ UpdateState(ChangedValue.first, ChangedValue.second);
if (TerminatorInst *TI = dyn_cast<TerminatorInst>(&I))
visitTerminatorInst(*TI);
}
-template <class LatticeVal> void SparseSolver<LatticeVal>::Solve(Function &F) {
- MarkBlockExecutable(&F.getEntryBlock());
-
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::Solve() {
// Process the work lists until they are empty!
while (!BBWorkList.empty() || !ValueWorkList.empty()) {
// Process the value work list.
}
}
-template <class LatticeVal>
-void SparseSolver<LatticeVal>::Print(Function &F, raw_ostream &OS) const {
- OS << "\nFUNCTION: " << F.getName() << "\n";
- for (auto &BB : F) {
- if (!BBExecutable.count(&BB))
- OS << "INFEASIBLE: ";
- OS << "\t";
- if (BB.hasName())
- OS << BB.getName() << ":\n";
- else
- OS << "; anon bb\n";
- for (auto &I : BB) {
- LatticeFunc->PrintValue(getExistingValueState(&I), OS);
- OS << I << "\n";
- }
+template <class LatticeKey, class LatticeVal, class KeyInfo>
+void SparseSolver<LatticeKey, LatticeVal, KeyInfo>::Print(
+ raw_ostream &OS) const {
+ if (ValueState.empty())
+ return;
+ LatticeKey Key;
+ LatticeVal LV;
+
+ OS << "ValueState:\n";
+ for (auto &Entry : ValueState) {
+ std::tie(Key, LV) = Entry;
+ if (LV == LatticeFunc->getUntrackedVal())
+ continue;
+ OS << "\t";
+ LatticeFunc->PrintLatticeVal(LV, OS);
+ OS << ": ";
+ LatticeFunc->PrintLatticeKey(Key, OS);
OS << "\n";
}
}
--- /dev/null
+//===- SparsePropagation.cpp - Unit tests for the generic solver ----------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/SparsePropagation.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/IRBuilder.h"
+#include "gtest/gtest.h"
+using namespace llvm;
+
+namespace {
+/// To enable interprocedural analysis, we assign LLVM values to the following
+/// groups. The register group represents SSA registers, the return group
+/// represents the return values of functions, and the memory group represents
+/// in-memory values. An LLVM Value can technically be in more than one group.
+/// It's necessary to distinguish these groups so we can, for example, track a
+/// global variable separately from the value stored at its location.
+enum class IPOGrouping { Register, Return, Memory };
+
+/// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings.
+/// The PointerIntPair header provides a DenseMapInfo specialization, so using
+/// these as LatticeKeys is fine.
+using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>;
+} // namespace
+
+namespace llvm {
+/// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver
+/// must translate between LatticeKeys and LLVM Values when adding Values to
+/// its work list and inspecting the state of control-flow related values.
+template <> struct LatticeKeyInfo<TestLatticeKey> {
+ static inline Value *getValueFromLatticeKey(TestLatticeKey Key) {
+ return Key.getPointer();
+ }
+ static inline TestLatticeKey getLatticeKeyFromValue(Value *V) {
+ return TestLatticeKey(V, IPOGrouping::Register);
+ }
+};
+} // namespace llvm
+
+namespace {
+/// This class defines a simple test lattice value that could be used for
+/// solving problems similar to constant propagation. The value is maintained
+/// as a PointerIntPair.
+class TestLatticeVal {
+public:
+ /// The states of the lattices value. Only the ConstantVal state is
+ /// interesting; the rest are special states used by the generic solver. The
+ /// UntrackedVal state differs from the other three in that the generic
+ /// solver uses it to avoid doing unnecessary work. In particular, when a
+ /// value moves to the UntrackedVal state, it's users are not notified.
+ enum TestLatticeStateTy {
+ UndefinedVal,
+ ConstantVal,
+ OverdefinedVal,
+ UntrackedVal
+ };
+
+ TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {}
+ TestLatticeVal(Constant *C, TestLatticeStateTy State)
+ : LatticeVal(C, State) {}
+
+ /// Return true if this lattice value is in the Constant state. This is used
+ /// for checking the solver results.
+ bool isConstant() const { return LatticeVal.getInt() == ConstantVal; }
+
+ /// Return true if this lattice value is in the Overdefined state. This is
+ /// used for checking the solver results.
+ bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; }
+
+ bool operator==(const TestLatticeVal &RHS) const {
+ return LatticeVal == RHS.LatticeVal;
+ }
+
+ bool operator!=(const TestLatticeVal &RHS) const {
+ return LatticeVal != RHS.LatticeVal;
+ }
+
+private:
+ /// A simple lattice value type for problems similar to constant propagation.
+ /// It holds the constant value and the lattice state.
+ PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal;
+};
+
+/// This class defines a simple test lattice function that could be used for
+/// solving problems similar to constant propagation. The test lattice differs
+/// from a "real" lattice in a few ways. First, it initializes all return
+/// values, values stored in global variables, and arguments in the undefined
+/// state. This means that there are no limitations on what we can track
+/// interprocedurally. For simplicity, all global values in the tests will be
+/// given internal linkage, since this is not something this lattice function
+/// tracks. Second, it only handles the few instructions necessary for the
+/// tests.
+class TestLatticeFunc
+ : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> {
+public:
+ /// Construct a new test lattice function with special values for the
+ /// Undefined, Overdefined, and Untracked states.
+ TestLatticeFunc()
+ : AbstractLatticeFunction(
+ TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal),
+ TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal),
+ TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {}
+
+ /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the
+ /// test analysis, a LatticeKey will begin in the undefined state, unless it
+ /// represents an LLVM Constant in the register grouping.
+ TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override {
+ if (Key.getInt() == IPOGrouping::Register)
+ if (auto *C = dyn_cast<Constant>(Key.getPointer()))
+ return TestLatticeVal(C, TestLatticeVal::ConstantVal);
+ return getUndefVal();
+ }
+
+ /// Merge the two given lattice values. This merge should be equivalent to
+ /// what is done for constant propagation. That is, the resulting lattice
+ /// value is constant only if the two given lattice values are constant and
+ /// hold the same value.
+ TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override {
+ if (X == getUntrackedVal() || Y == getUntrackedVal())
+ return getUntrackedVal();
+ if (X == getOverdefinedVal() || Y == getOverdefinedVal())
+ return getOverdefinedVal();
+ if (X == getUndefVal() && Y == getUndefVal())
+ return getUndefVal();
+ if (X == getUndefVal())
+ return Y;
+ if (Y == getUndefVal())
+ return X;
+ if (X == Y)
+ return X;
+ return getOverdefinedVal();
+ }
+
+ /// Compute the lattice values that change as a result of executing the given
+ /// instruction. We only handle the few instructions needed for the tests.
+ void ComputeInstructionState(
+ Instruction &I, DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
+ SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override {
+ switch (I.getOpcode()) {
+ case Instruction::Call:
+ return visitCallSite(cast<CallInst>(&I), ChangedValues, SS);
+ case Instruction::Ret:
+ return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS);
+ case Instruction::Store:
+ return visitStore(*cast<StoreInst>(&I), ChangedValues, SS);
+ default:
+ return visitInst(I, ChangedValues, SS);
+ }
+ }
+
+private:
+ /// Handle call sites. The state of a called function's argument is the merge
+ /// of the current formal argument state with the call site's corresponding
+ /// actual argument state. The call site state is the merge of the call site
+ /// state with the returned value state of the called function.
+ void visitCallSite(CallSite CS,
+ DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
+ SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
+ Function *F = CS.getCalledFunction();
+ Instruction *I = CS.getInstruction();
+ auto RegI = TestLatticeKey(I, IPOGrouping::Register);
+ if (!F) {
+ ChangedValues[RegI] = getOverdefinedVal();
+ return;
+ }
+ SS.MarkBlockExecutable(&F->front());
+ for (Argument &A : F->args()) {
+ auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register);
+ auto RegActual =
+ TestLatticeKey(CS.getArgument(A.getArgNo()), IPOGrouping::Register);
+ ChangedValues[RegFormal] =
+ MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual));
+ }
+ auto RetF = TestLatticeKey(F, IPOGrouping::Return);
+ ChangedValues[RegI] =
+ MergeValues(SS.getValueState(RegI), SS.getValueState(RetF));
+ }
+
+ /// Handle return instructions. The function's return state is the merge of
+ /// the returned value state and the function's current return state.
+ void visitReturn(ReturnInst &I,
+ DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
+ SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
+ Function *F = I.getParent()->getParent();
+ if (F->getReturnType()->isVoidTy())
+ return;
+ auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register);
+ auto RetF = TestLatticeKey(F, IPOGrouping::Return);
+ ChangedValues[RetF] =
+ MergeValues(SS.getValueState(RegR), SS.getValueState(RetF));
+ }
+
+ /// Handle store instructions. If the pointer operand of the store is a
+ /// global variable, we attempt to track the value. The global variable state
+ /// is the merge of the stored value state with the current global variable
+ /// state.
+ void visitStore(StoreInst &I,
+ DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
+ SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
+ auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand());
+ if (!GV)
+ return;
+ auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register);
+ auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory);
+ ChangedValues[MemPtr] =
+ MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr));
+ }
+
+ /// Handle all other instructions. All other instructions are marked
+ /// overdefined.
+ void visitInst(Instruction &I,
+ DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues,
+ SparseSolver<TestLatticeKey, TestLatticeVal> &SS) {
+ auto RegI = TestLatticeKey(&I, IPOGrouping::Register);
+ ChangedValues[RegI] = getOverdefinedVal();
+ }
+};
+
+/// This class defines the common data used for all of the tests. The tests
+/// should add code to the module and then run the solver.
+class SparsePropagationTest : public testing::Test {
+protected:
+ LLVMContext Context;
+ Module M;
+ IRBuilder<> Builder;
+ TestLatticeFunc Lattice;
+ SparseSolver<TestLatticeKey, TestLatticeVal> Solver;
+
+public:
+ SparsePropagationTest()
+ : M("", Context), Builder(Context), Solver(&Lattice) {}
+};
+} // namespace
+
+/// Test that we mark discovered functions executable.
+///
+/// define internal void @f() {
+/// call void @g()
+/// ret void
+/// }
+///
+/// define internal void @g() {
+/// call void @f()
+/// ret void
+/// }
+///
+/// For this test, we initially mark "f" executable, and the solver discovers
+/// "g" because of the call in "f". The mutually recursive call in "g" also
+/// tests that we don't add a block to the basic block work list if it is
+/// already executable. Doing so would put the solver into an infinite loop.
+TEST_F(SparsePropagationTest, MarkBlockExecutable) {
+ Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "f", &M);
+ Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "g", &M);
+ BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+ BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+ Builder.SetInsertPoint(FEntry);
+ Builder.CreateCall(G);
+ Builder.CreateRetVoid();
+ Builder.SetInsertPoint(GEntry);
+ Builder.CreateCall(F);
+ Builder.CreateRetVoid();
+
+ Solver.MarkBlockExecutable(FEntry);
+ Solver.Solve();
+
+ EXPECT_TRUE(Solver.isBlockExecutable(GEntry));
+}
+
+/// Test that we propagate information through global variables.
+///
+/// @gv = internal global i64
+///
+/// define internal void @f() {
+/// store i64 1, i64* @gv
+/// ret void
+/// }
+///
+/// define internal void @g() {
+/// store i64 1, i64* @gv
+/// ret void
+/// }
+///
+/// For this test, we initially mark both "f" and "g" executable, and the
+/// solver computes the lattice state of the global variable as constant.
+TEST_F(SparsePropagationTest, GlobalVariableConstant) {
+ Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "f", &M);
+ Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "g", &M);
+ GlobalVariable *GV =
+ new GlobalVariable(M, Builder.getInt64Ty(), false,
+ GlobalValue::InternalLinkage, nullptr, "gv");
+ BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+ BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+ Builder.SetInsertPoint(FEntry);
+ Builder.CreateStore(Builder.getInt64(1), GV);
+ Builder.CreateRetVoid();
+ Builder.SetInsertPoint(GEntry);
+ Builder.CreateStore(Builder.getInt64(1), GV);
+ Builder.CreateRetVoid();
+
+ Solver.MarkBlockExecutable(FEntry);
+ Solver.MarkBlockExecutable(GEntry);
+ Solver.Solve();
+
+ auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
+ EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant());
+}
+
+/// Test that we propagate information through global variables.
+///
+/// @gv = internal global i64
+///
+/// define internal void @f() {
+/// store i64 0, i64* @gv
+/// ret void
+/// }
+///
+/// define internal void @g() {
+/// store i64 1, i64* @gv
+/// ret void
+/// }
+///
+/// For this test, we initially mark both "f" and "g" executable, and the
+/// solver computes the lattice state of the global variable as overdefined.
+TEST_F(SparsePropagationTest, GlobalVariableOverDefined) {
+ Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "f", &M);
+ Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "g", &M);
+ GlobalVariable *GV =
+ new GlobalVariable(M, Builder.getInt64Ty(), false,
+ GlobalValue::InternalLinkage, nullptr, "gv");
+ BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+ BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+ Builder.SetInsertPoint(FEntry);
+ Builder.CreateStore(Builder.getInt64(0), GV);
+ Builder.CreateRetVoid();
+ Builder.SetInsertPoint(GEntry);
+ Builder.CreateStore(Builder.getInt64(1), GV);
+ Builder.CreateRetVoid();
+
+ Solver.MarkBlockExecutable(FEntry);
+ Solver.MarkBlockExecutable(GEntry);
+ Solver.Solve();
+
+ auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory);
+ EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined());
+}
+
+/// Test that we propagate information through function returns.
+///
+/// define internal i64 @f(i1* %cond) {
+/// if:
+/// %0 = load i1, i1* %cond
+/// br i1 %0, label %then, label %else
+///
+/// then:
+/// ret i64 1
+///
+/// else:
+/// ret i64 1
+/// }
+///
+/// For this test, we initially mark "f" executable, and the solver computes
+/// the return value of the function as constant.
+TEST_F(SparsePropagationTest, FunctionDefined) {
+ Function *F =
+ Function::Create(FunctionType::get(Builder.getInt64Ty(),
+ {Type::getInt1PtrTy(Context)}, false),
+ GlobalValue::InternalLinkage, "f", &M);
+ BasicBlock *If = BasicBlock::Create(Context, "if", F);
+ BasicBlock *Then = BasicBlock::Create(Context, "then", F);
+ BasicBlock *Else = BasicBlock::Create(Context, "else", F);
+ F->arg_begin()->setName("cond");
+ Builder.SetInsertPoint(If);
+ LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
+ Builder.CreateCondBr(Cond, Then, Else);
+ Builder.SetInsertPoint(Then);
+ Builder.CreateRet(Builder.getInt64(1));
+ Builder.SetInsertPoint(Else);
+ Builder.CreateRet(Builder.getInt64(1));
+
+ Solver.MarkBlockExecutable(If);
+ Solver.Solve();
+
+ auto RetF = TestLatticeKey(F, IPOGrouping::Return);
+ EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant());
+}
+
+/// Test that we propagate information through function returns.
+///
+/// define internal i64 @f(i1* %cond) {
+/// if:
+/// %0 = load i1, i1* %cond
+/// br i1 %0, label %then, label %else
+///
+/// then:
+/// ret i64 0
+///
+/// else:
+/// ret i64 1
+/// }
+///
+/// For this test, we initially mark "f" executable, and the solver computes
+/// the return value of the function as overdefined.
+TEST_F(SparsePropagationTest, FunctionOverDefined) {
+ Function *F =
+ Function::Create(FunctionType::get(Builder.getInt64Ty(),
+ {Type::getInt1PtrTy(Context)}, false),
+ GlobalValue::InternalLinkage, "f", &M);
+ BasicBlock *If = BasicBlock::Create(Context, "if", F);
+ BasicBlock *Then = BasicBlock::Create(Context, "then", F);
+ BasicBlock *Else = BasicBlock::Create(Context, "else", F);
+ F->arg_begin()->setName("cond");
+ Builder.SetInsertPoint(If);
+ LoadInst *Cond = Builder.CreateLoad(F->arg_begin());
+ Builder.CreateCondBr(Cond, Then, Else);
+ Builder.SetInsertPoint(Then);
+ Builder.CreateRet(Builder.getInt64(0));
+ Builder.SetInsertPoint(Else);
+ Builder.CreateRet(Builder.getInt64(1));
+
+ Solver.MarkBlockExecutable(If);
+ Solver.Solve();
+
+ auto RetF = TestLatticeKey(F, IPOGrouping::Return);
+ EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined());
+}
+
+/// Test that we propagate information through arguments.
+///
+/// define internal void @f() {
+/// call void @g(i64 0, i64 1)
+/// call void @g(i64 1, i64 1)
+/// ret void
+/// }
+///
+/// define internal void @g(i64 %a, i64 %b) {
+/// ret void
+/// }
+///
+/// For this test, we initially mark "f" executable, and the solver discovers
+/// "g" because of the calls in "f". The solver computes the state of argument
+/// "a" as overdefined and the state of "b" as constant.
+///
+/// In addition, this test demonstrates that ComputeInstructionState can alter
+/// the state of multiple lattice values, in addition to the one associated
+/// with the instruction definition. Each call instruction in this test updates
+/// the state of arguments "a" and "b".
+TEST_F(SparsePropagationTest, ComputeInstructionState) {
+ Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "f", &M);
+ Function *G = Function::Create(
+ FunctionType::get(Builder.getVoidTy(),
+ {Builder.getInt64Ty(), Builder.getInt64Ty()}, false),
+ GlobalValue::InternalLinkage, "g", &M);
+ Argument *A = G->arg_begin();
+ Argument *B = std::next(G->arg_begin());
+ A->setName("a");
+ B->setName("b");
+ BasicBlock *FEntry = BasicBlock::Create(Context, "", F);
+ BasicBlock *GEntry = BasicBlock::Create(Context, "", G);
+ Builder.SetInsertPoint(FEntry);
+ Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)});
+ Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)});
+ Builder.CreateRetVoid();
+ Builder.SetInsertPoint(GEntry);
+ Builder.CreateRetVoid();
+
+ Solver.MarkBlockExecutable(FEntry);
+ Solver.Solve();
+
+ auto RegA = TestLatticeKey(A, IPOGrouping::Register);
+ auto RegB = TestLatticeKey(B, IPOGrouping::Register);
+ EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined());
+ EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant());
+}
+
+/// Test that we can handle exceptional terminator instructions.
+///
+/// declare internal void @p()
+///
+/// declare internal void @g()
+///
+/// define internal void @f() personality i8* bitcast (void ()* @p to i8*) {
+/// entry:
+/// invoke void @g()
+/// to label %exit unwind label %catch.pad
+///
+/// catch.pad:
+/// %0 = catchswitch within none [label %catch.body] unwind to caller
+///
+/// catch.body:
+/// %1 = catchpad within %0 []
+/// catchret from %1 to label %exit
+///
+/// exit:
+/// ret void
+/// }
+///
+/// For this test, we initially mark the entry block executable. The solver
+/// then discovers the rest of the blocks in the function are executable.
+TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) {
+ Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "p", &M);
+ Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "g", &M);
+ Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false),
+ GlobalValue::InternalLinkage, "f", &M);
+ Constant *C =
+ ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy());
+ F->setPersonalityFn(C);
+ BasicBlock *Entry = BasicBlock::Create(Context, "entry", F);
+ BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F);
+ BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F);
+ BasicBlock *Exit = BasicBlock::Create(Context, "exit", F);
+ Builder.SetInsertPoint(Entry);
+ Builder.CreateInvoke(G, Exit, Pad);
+ Builder.SetInsertPoint(Pad);
+ CatchSwitchInst *CatchSwitch =
+ Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1);
+ CatchSwitch->addHandler(Body);
+ Builder.SetInsertPoint(Body);
+ CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {});
+ Builder.CreateCatchRet(CatchPad, Exit);
+ Builder.SetInsertPoint(Exit);
+ Builder.CreateRetVoid();
+
+ Solver.MarkBlockExecutable(Entry);
+ Solver.Solve();
+
+ EXPECT_TRUE(Solver.isBlockExecutable(Pad));
+ EXPECT_TRUE(Solver.isBlockExecutable(Body));
+ EXPECT_TRUE(Solver.isBlockExecutable(Exit));
+}