/// their prof branch_weights metadata.
class SwitchInstProfUpdateWrapper {
SwitchInst &SI;
- Optional<SmallVector<uint32_t, 8> > Weights;
- bool Changed = false;
+ Optional<SmallVector<uint32_t, 8> > Weights = None;
+
+ // Sticky invalid state is needed to safely ignore operations with prof data
+ // in cases where SwitchInstProfUpdateWrapper is created from SwitchInst
+ // with inconsistent prof data. TODO: once we fix all prof data
+ // inconsistencies we can turn invalid state to assertions.
+ enum {
+ Invalid,
+ Initialized,
+ Changed
+ } State = Invalid;
protected:
static MDNode *getProfBranchWeightsMD(const SwitchInst &SI);
MDNode *buildProfBranchWeightsMD();
- Optional<SmallVector<uint32_t, 8> > getProfBranchWeights();
+ void init();
public:
using CaseWeightOpt = Optional<uint32_t>;
SwitchInst &operator*() { return SI; }
operator SwitchInst *() { return &SI; }
- SwitchInstProfUpdateWrapper(SwitchInst &SI)
- : SI(SI), Weights(getProfBranchWeights()) {}
+ SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); }
~SwitchInstProfUpdateWrapper() {
- if (Changed)
+ if (State == Changed)
SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
}
using namespace llvm;
+static cl::opt<bool> SwitchInstProfUpdateWrapperStrict(
+ "switch-inst-prof-update-wrapper-strict", cl::Hidden,
+ cl::desc("Assert that prof branch_weights metadata is valid when creating "
+ "an instance of SwitchInstProfUpdateWrapper"),
+ cl::init(false));
+
//===----------------------------------------------------------------------===//
// AllocaInst Class
//===----------------------------------------------------------------------===//
}
MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
- assert(Changed && "called only if metadata has changed");
+ assert(State == Changed && "called only if metadata has changed");
if (!Weights)
return nullptr;
return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights);
}
-Optional<SmallVector<uint32_t, 8> >
-SwitchInstProfUpdateWrapper::getProfBranchWeights() {
+void SwitchInstProfUpdateWrapper::init() {
MDNode *ProfileData = getProfBranchWeightsMD(SI);
- if (!ProfileData)
- return None;
+ if (!ProfileData) {
+ State = Initialized;
+ return;
+ }
+
+ if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) {
+ State = Invalid;
+ if (SwitchInstProfUpdateWrapperStrict)
+ assert(!"number of prof branch_weights metadata operands corresponds to"
+ " number of succesors");
+ return;
+ }
SmallVector<uint32_t, 8> Weights;
for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) {
uint32_t CW = C->getValue().getZExtValue();
Weights.push_back(CW);
}
- return Weights;
+ State = Initialized;
+ this->Weights = std::move(Weights);
}
SwitchInst::CaseIt
if (Weights) {
assert(SI.getNumSuccessors() == Weights->size() &&
"num of prof branch_weights must accord with num of successors");
- Changed = true;
+ State = Changed;
// Copy the last case to the place of the removed one and shrink.
// This is tightly coupled with the way SwitchInst::removeCase() removes
// the cases in SwitchInst::removeCase(CaseIt).
SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
SI.addCase(OnVal, Dest);
+ if (State == Invalid)
+ return;
+
if (!Weights && W && *W) {
- Changed = true;
+ State = Changed;
Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0);
Weights.getValue()[SI.getNumSuccessors() - 1] = *W;
} else if (Weights) {
- Changed = true;
+ State = Changed;
Weights.getValue().push_back(W ? *W : 0);
}
if (Weights)
SymbolTableList<Instruction>::iterator
SwitchInstProfUpdateWrapper::eraseFromParent() {
// Instruction is erased. Mark as unchanged to not touch it in the destructor.
- Changed = false;
-
- if (Weights)
- Weights->resize(0);
+ if (State != Invalid) {
+ State = Initialized;
+ if (Weights)
+ Weights->resize(0);
+ }
return SI.eraseFromParent();
}
void SwitchInstProfUpdateWrapper::setSuccessorWeight(
unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) {
- if (!W)
+ if (!W || State == Invalid)
return;
if (!Weights && *W)
if (Weights) {
auto &OldW = Weights.getValue()[idx];
if (*W != OldW) {
- Changed = true;
+ State = Changed;
OldW = *W;
}
}
SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI,
unsigned idx) {
if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
- return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
- ->getValue()
- .getZExtValue();
+ if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1)
+ return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
+ ->getValue()
+ .getZExtValue();
return None;
}
EXPECT_EQ(BB1.get(), Handle.getCaseSuccessor());
}
+TEST(InstructionsTest, SwitchInstProfUpdateWrapper) {
+ LLVMContext C;
+
+ std::unique_ptr<BasicBlock> BB1, BB2, BB3;
+ BB1.reset(BasicBlock::Create(C));
+ BB2.reset(BasicBlock::Create(C));
+ BB3.reset(BasicBlock::Create(C));
+
+ // We create block 0 after the others so that it gets destroyed first and
+ // clears the uses of the other basic blocks.
+ std::unique_ptr<BasicBlock> BB0(BasicBlock::Create(C));
+
+ auto *Int32Ty = Type::getInt32Ty(C);
+
+ SwitchInst *SI =
+ SwitchInst::Create(UndefValue::get(Int32Ty), BB0.get(), 4, BB0.get());
+ SI->addCase(ConstantInt::get(Int32Ty, 1), BB1.get());
+ SI->addCase(ConstantInt::get(Int32Ty, 2), BB2.get());
+ SI->setMetadata(LLVMContext::MD_prof,
+ MDBuilder(C).createBranchWeights({ 9, 1, 22 }));
+
+ {
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ EXPECT_EQ(*SIW.getSuccessorWeight(0), 9u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(1), 1u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+ SIW.setSuccessorWeight(0, 99u);
+ SIW.setSuccessorWeight(1, 11u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+ }
+
+ { // Create another wrapper and check that the data persist.
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+ }
+
+ // Make prof data invalid by adding one extra weight.
+ SI->setMetadata(LLVMContext::MD_prof, MDBuilder(C).createBranchWeights(
+ { 99, 11, 22, 33 })); // extra
+ { // Invalid prof data makes wrapper act as if there were no prof data.
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
+ ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
+ ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
+ SIW.addCase(ConstantInt::get(Int32Ty, 3), BB3.get(), 39);
+ ASSERT_FALSE(SIW.getSuccessorWeight(3).hasValue()); // did not add weight 39
+ }
+
+ { // With added 3rd case the prof data become consistent with num of cases.
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(3), 33u);
+ }
+
+ // Make prof data invalid by removing one extra weight.
+ SI->setMetadata(LLVMContext::MD_prof,
+ MDBuilder(C).createBranchWeights({ 99, 11, 22 })); // shorter
+ { // Invalid prof data makes wrapper act as if there were no prof data.
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ ASSERT_FALSE(SIW.getSuccessorWeight(0).hasValue());
+ ASSERT_FALSE(SIW.getSuccessorWeight(1).hasValue());
+ ASSERT_FALSE(SIW.getSuccessorWeight(2).hasValue());
+ SIW.removeCase(SwitchInst::CaseIt(SI, 2));
+ }
+
+ { // With removed 3rd case the prof data become consistent with num of cases.
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ EXPECT_EQ(*SIW.getSuccessorWeight(0), 99u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(1), 11u);
+ EXPECT_EQ(*SIW.getSuccessorWeight(2), 22u);
+ }
+}
+
TEST(InstructionsTest, CommuteShuffleMask) {
SmallVector<int, 16> Indices({-1, 0, 7});
ShuffleVectorInst::commuteShuffleMask(Indices, 4);