#include "SafeStackLayout.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Triple.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
/// determined statically), and the unsafe stack, which contains all
/// local variables that are accessed in ways that we can't prove to
/// be safe.
-class SafeStack : public FunctionPass {
- const TargetMachine *TM;
- const TargetLoweringBase *TL;
- const DataLayout *DL;
- ScalarEvolution *SE;
+class SafeStack {
+ Function &F;
+ const TargetLoweringBase &TL;
+ const DataLayout &DL;
+ ScalarEvolution &SE;
Type *StackPtrTy;
Type *IntPtrTy;
uint64_t AllocaSize);
public:
- static char ID; // Pass identification, replacement for typeid.
- SafeStack(const TargetMachine *TM)
- : FunctionPass(ID), TM(TM), TL(nullptr), DL(nullptr) {
- initializeSafeStackPass(*PassRegistry::getPassRegistry());
- }
- SafeStack() : SafeStack(nullptr) {}
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<ScalarEvolutionWrapperPass>();
- }
-
- bool doInitialization(Module &M) override {
- DL = &M.getDataLayout();
-
- StackPtrTy = Type::getInt8PtrTy(M.getContext());
- IntPtrTy = DL->getIntPtrType(M.getContext());
- Int32Ty = Type::getInt32Ty(M.getContext());
- Int8Ty = Type::getInt8Ty(M.getContext());
-
- return false;
- }
-
- bool runOnFunction(Function &F) override;
-}; // class SafeStack
+ SafeStack(Function &F, const TargetLoweringBase &TL, const DataLayout &DL,
+ ScalarEvolution &SE)
+ : F(F), TL(TL), DL(DL), SE(SE),
+ StackPtrTy(Type::getInt8PtrTy(F.getContext())),
+ IntPtrTy(DL.getIntPtrType(F.getContext())),
+ Int32Ty(Type::getInt32Ty(F.getContext())),
+ Int8Ty(Type::getInt8Ty(F.getContext())) {}
+
+ // Run the transformation on the associated function.
+ // Returns whether the function was changed.
+ bool run();
+};
uint64_t SafeStack::getStaticAllocaAllocationSize(const AllocaInst* AI) {
- uint64_t Size = DL->getTypeAllocSize(AI->getAllocatedType());
+ uint64_t Size = DL.getTypeAllocSize(AI->getAllocatedType());
if (AI->isArrayAllocation()) {
auto C = dyn_cast<ConstantInt>(AI->getArraySize());
if (!C)
bool SafeStack::IsAccessSafe(Value *Addr, uint64_t AccessSize,
const Value *AllocaPtr, uint64_t AllocaSize) {
- AllocaOffsetRewriter Rewriter(*SE, AllocaPtr);
- const SCEV *Expr = Rewriter.visit(SE->getSCEV(Addr));
+ AllocaOffsetRewriter Rewriter(SE, AllocaPtr);
+ const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr));
- uint64_t BitWidth = SE->getTypeSizeInBits(Expr->getType());
- ConstantRange AccessStartRange = SE->getUnsignedRange(Expr);
+ uint64_t BitWidth = SE.getTypeSizeInBits(Expr->getType());
+ ConstantRange AccessStartRange = SE.getUnsignedRange(Expr);
ConstantRange SizeRange =
ConstantRange(APInt(BitWidth, 0), APInt(BitWidth, AccessSize));
ConstantRange AccessRange = AccessStartRange.add(SizeRange);
<< *AllocaPtr << "\n"
<< " Access " << *Addr << "\n"
<< " SCEV " << *Expr
- << " U: " << SE->getUnsignedRange(Expr)
- << ", S: " << SE->getSignedRange(Expr) << "\n"
+ << " U: " << SE.getUnsignedRange(Expr)
+ << ", S: " << SE.getSignedRange(Expr) << "\n"
<< " Range " << AccessRange << "\n"
<< " AllocaRange " << AllocaRange << "\n"
<< " " << (Safe ? "safe" : "unsafe") << "\n");
switch (I->getOpcode()) {
case Instruction::Load: {
- if (!IsAccessSafe(UI, DL->getTypeStoreSize(I->getType()), AllocaPtr,
+ if (!IsAccessSafe(UI, DL.getTypeStoreSize(I->getType()), AllocaPtr,
AllocaSize))
return false;
break;
return false;
}
- if (!IsAccessSafe(UI, DL->getTypeStoreSize(I->getOperand(0)->getType()),
+ if (!IsAccessSafe(UI, DL.getTypeStoreSize(I->getOperand(0)->getType()),
AllocaPtr, AllocaSize))
return false;
break;
}
Value *SafeStack::getStackGuard(IRBuilder<> &IRB, Function &F) {
- Value *StackGuardVar = TL->getIRStackGuard(IRB);
+ Value *StackGuardVar = TL.getIRStackGuard(IRB);
if (!StackGuardVar)
StackGuardVar =
F.getParent()->getOrInsertGlobal("__stack_chk_guard", StackPtrTy);
if (!Arg.hasByValAttr())
continue;
uint64_t Size =
- DL->getTypeStoreSize(Arg.getType()->getPointerElementType());
+ DL.getTypeStoreSize(Arg.getType()->getPointerElementType());
if (IsSafeStackAlloca(&Arg, Size))
continue;
if (StackGuardSlot) {
Type *Ty = StackGuardSlot->getAllocatedType();
unsigned Align =
- std::max(DL->getPrefTypeAlignment(Ty), StackGuardSlot->getAlignment());
+ std::max(DL.getPrefTypeAlignment(Ty), StackGuardSlot->getAlignment());
SSL.addObject(StackGuardSlot, getStaticAllocaAllocationSize(StackGuardSlot),
Align, SSC.getFullLiveRange());
}
for (Argument *Arg : ByValArguments) {
Type *Ty = Arg->getType()->getPointerElementType();
- uint64_t Size = DL->getTypeStoreSize(Ty);
+ uint64_t Size = DL.getTypeStoreSize(Ty);
if (Size == 0)
Size = 1; // Don't create zero-sized stack objects.
// Ensure the object is properly aligned.
- unsigned Align = std::max((unsigned)DL->getPrefTypeAlignment(Ty),
+ unsigned Align = std::max((unsigned)DL.getPrefTypeAlignment(Ty),
Arg->getParamAlignment());
SSL.addObject(Arg, Size, Align, SSC.getFullLiveRange());
}
// Ensure the object is properly aligned.
unsigned Align =
- std::max((unsigned)DL->getPrefTypeAlignment(Ty), AI->getAlignment());
+ std::max((unsigned)DL.getPrefTypeAlignment(Ty), AI->getAlignment());
SSL.addObject(AI, Size, Align, SSC.getLiveRange(AI));
}
unsigned Offset = SSL.getObjectOffset(Arg);
Type *Ty = Arg->getType()->getPointerElementType();
- uint64_t Size = DL->getTypeStoreSize(Ty);
+ uint64_t Size = DL.getTypeStoreSize(Ty);
if (Size == 0)
Size = 1; // Don't create zero-sized stack objects.
ArraySize = IRB.CreateIntCast(ArraySize, IntPtrTy, false);
Type *Ty = AI->getAllocatedType();
- uint64_t TySize = DL->getTypeAllocSize(Ty);
+ uint64_t TySize = DL.getTypeAllocSize(Ty);
Value *Size = IRB.CreateMul(ArraySize, ConstantInt::get(IntPtrTy, TySize));
Value *SP = IRB.CreatePtrToInt(IRB.CreateLoad(UnsafeStackPtr), IntPtrTy);
// Align the SP value to satisfy the AllocaInst, type and stack alignments.
unsigned Align = std::max(
- std::max((unsigned)DL->getPrefTypeAlignment(Ty), AI->getAlignment()),
+ std::max((unsigned)DL.getPrefTypeAlignment(Ty), AI->getAlignment()),
(unsigned)StackAlignment);
assert(isPowerOf2_32(Align));
}
}
-bool SafeStack::runOnFunction(Function &F) {
- DEBUG(dbgs() << "[SafeStack] Function: " << F.getName() << "\n");
-
- if (!F.hasFnAttribute(Attribute::SafeStack)) {
- DEBUG(dbgs() << "[SafeStack] safestack is not requested"
- " for this function\n");
- return false;
- }
-
- if (F.isDeclaration()) {
- DEBUG(dbgs() << "[SafeStack] function definition"
- " is not available\n");
- return false;
- }
-
- if (!TM)
- report_fatal_error("Target machine is required");
- TL = TM->getSubtargetImpl(F)->getTargetLowering();
- SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+bool SafeStack::run() {
+ assert(F.hasFnAttribute(Attribute::SafeStack) &&
+ "Can't run SafeStack on a function without the attribute");
+ assert(!F.isDeclaration() && "Can't run SafeStack on a function declaration");
++NumFunctions;
++NumUnsafeStackRestorePointsFunctions;
IRBuilder<> IRB(&F.front(), F.begin()->getFirstInsertionPt());
- UnsafeStackPtr = TL->getSafeStackPointerLocation(IRB);
+ UnsafeStackPtr = TL.getSafeStackPointerLocation(IRB);
// Load the current stack pointer (we'll also use it as a base pointer).
// FIXME: use a dedicated register for it ?
return true;
}
+class SafeStackLegacyPass : public FunctionPass {
+ const TargetMachine *TM;
+
+public:
+ static char ID; // Pass identification, replacement for typeid..
+ SafeStackLegacyPass(const TargetMachine *TM) : FunctionPass(ID), TM(TM) {
+ initializeSafeStackLegacyPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ SafeStackLegacyPass() : SafeStackLegacyPass(nullptr) {}
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<ScalarEvolutionWrapperPass>();
+ }
+
+ bool runOnFunction(Function &F) override {
+ DEBUG(dbgs() << "[SafeStack] Function: " << F.getName() << "\n");
+
+ if (!F.hasFnAttribute(Attribute::SafeStack)) {
+ DEBUG(dbgs() << "[SafeStack] safestack is not requested"
+ " for this function\n");
+ return false;
+ }
+
+ if (F.isDeclaration()) {
+ DEBUG(dbgs() << "[SafeStack] function definition"
+ " is not available\n");
+ return false;
+ }
+
+ if (!TM)
+ report_fatal_error("Target machine is required");
+ auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
+ if (!TL)
+ report_fatal_error("TargetLowering instance is required");
+
+ auto *DL = &F.getParent()->getDataLayout();
+ auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+
+ return SafeStack(F, *TL, *DL, SE).run();
+ }
+};
+
} // anonymous namespace
-char SafeStack::ID = 0;
-INITIALIZE_TM_PASS_BEGIN(SafeStack, "safe-stack",
+char SafeStackLegacyPass::ID = 0;
+INITIALIZE_TM_PASS_BEGIN(SafeStackLegacyPass, "safe-stack",
"Safe Stack instrumentation pass", false, false)
-INITIALIZE_TM_PASS_END(SafeStack, "safe-stack",
+INITIALIZE_TM_PASS_END(SafeStackLegacyPass, "safe-stack",
"Safe Stack instrumentation pass", false, false)
FunctionPass *llvm::createSafeStackPass(const llvm::TargetMachine *TM) {
- return new SafeStack(TM);
+ return new SafeStackLegacyPass(TM);
}