From: Argyrios Kyrtzidis Date: Tue, 8 Jan 2013 00:58:25 +0000 (+0000) Subject: [arcmt] Follow-up for r171484; make sure when adding brackets enclosing case statements, X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=4ce831cba3ae3030674fd9d90f8a69f7b3938d63;p=clang [arcmt] Follow-up for r171484; make sure when adding brackets enclosing case statements, that the case does not "contain" a declaration that is referenced "outside" of it, otherwise we will emit un-compilable code. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@171828 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/ARCMigrate/TransProtectedScope.cpp b/lib/ARCMigrate/TransProtectedScope.cpp index b202600625..38d72ff428 100644 --- a/lib/ARCMigrate/TransProtectedScope.cpp +++ b/lib/ARCMigrate/TransProtectedScope.cpp @@ -15,6 +15,7 @@ #include "Transforms.h" #include "Internals.h" #include "clang/Sema/SemaDiagnostic.h" +#include "clang/AST/ASTContext.h" using namespace clang; using namespace arcmt; @@ -22,26 +23,58 @@ using namespace trans; namespace { +class LocalRefsCollector : public RecursiveASTVisitor { + SmallVectorImpl &Refs; + +public: + LocalRefsCollector(SmallVectorImpl &refs) + : Refs(refs) { } + + bool VisitDeclRefExpr(DeclRefExpr *E) { + if (ValueDecl *D = E->getDecl()) + if (D->getDeclContext()->getRedeclContext()->isFunctionOrMethod()) + Refs.push_back(E); + return true; + } +}; + struct CaseInfo { SwitchCase *SC; SourceRange Range; - bool FixedBypass; + enum { + St_Unchecked, + St_CannotFix, + St_Fixed + } State; - CaseInfo() : SC(0), FixedBypass(false) {} + CaseInfo() : SC(0), State(St_Unchecked) {} CaseInfo(SwitchCase *S, SourceRange Range) - : SC(S), Range(Range), FixedBypass(false) {} + : SC(S), Range(Range), State(St_Unchecked) {} }; class CaseCollector : public RecursiveASTVisitor { + ParentMap &PMap; llvm::SmallVectorImpl &Cases; public: - CaseCollector(llvm::SmallVectorImpl &Cases) - : Cases(Cases) { } + CaseCollector(ParentMap &PMap, llvm::SmallVectorImpl &Cases) + : PMap(PMap), Cases(Cases) { } bool VisitSwitchStmt(SwitchStmt *S) { - SourceLocation NextLoc = S->getLocEnd(); SwitchCase *Curr = S->getSwitchCaseList(); + if (!Curr) + return true; + Stmt *Parent = getCaseParent(Curr); + Curr = Curr->getNextSwitchCase(); + // Make sure all case statements are in the same scope. + while (Curr) { + if (getCaseParent(Curr) != Parent) + return true; + Curr = Curr->getNextSwitchCase(); + } + + SourceLocation NextLoc = S->getLocEnd(); + Curr = S->getSwitchCaseList(); // We iterate over case statements in reverse source-order. while (Curr) { Cases.push_back(CaseInfo(Curr,SourceRange(Curr->getLocStart(), NextLoc))); @@ -50,69 +83,114 @@ public: } return true; } -}; -} // anonymous namespace + Stmt *getCaseParent(SwitchCase *S) { + Stmt *Parent = PMap.getParent(S); + while (Parent && (isa(Parent) || isa(Parent))) + Parent = PMap.getParent(Parent); + return Parent; + } +}; -static bool isInRange(FullSourceLoc Loc, SourceRange R) { - return !Loc.isBeforeInTranslationUnitThan(R.getBegin()) && - Loc.isBeforeInTranslationUnitThan(R.getEnd()); -} +class ProtectedScopeFixer { + MigrationPass &Pass; + SourceManager &SM; + SmallVector Cases; + SmallVector LocalRefs; -static bool handleProtectedNote(const StoredDiagnostic &Diag, - llvm::SmallVectorImpl &Cases, - TransformActions &TA) { - assert(Diag.getLevel() == DiagnosticsEngine::Note); - - for (unsigned i = 0; i != Cases.size(); i++) { - CaseInfo &info = Cases[i]; - if (isInRange(Diag.getLocation(), info.Range)) { - TA.clearDiagnostic(Diag.getID(), Diag.getLocation()); - if (!info.FixedBypass) { - TA.insertAfterToken(info.SC->getColonLoc(), " {"); - TA.insert(info.Range.getEnd(), "}\n"); - info.FixedBypass = true; +public: + ProtectedScopeFixer(BodyContext &BodyCtx) + : Pass(BodyCtx.getMigrationContext().Pass), + SM(Pass.Ctx.getSourceManager()) { + + CaseCollector(BodyCtx.getParentMap(), Cases) + .TraverseStmt(BodyCtx.getTopStmt()); + LocalRefsCollector(LocalRefs).TraverseStmt(BodyCtx.getTopStmt()); + + SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange(); + const CapturedDiagList &DiagList = Pass.getDiags(); + CapturedDiagList::iterator I = DiagList.begin(), E = DiagList.end(); + while (I != E) { + if (I->getID() == diag::err_switch_into_protected_scope && + isInRange(I->getLocation(), BodyRange)) { + handleProtectedScopeError(I, E); + continue; } - return true; + ++I; } } - return false; -} + void handleProtectedScopeError(CapturedDiagList::iterator &DiagI, + CapturedDiagList::iterator DiagE) { + Transaction Trans(Pass.TA); + assert(DiagI->getID() == diag::err_switch_into_protected_scope); + SourceLocation ErrLoc = DiagI->getLocation(); + bool handledAllNotes = true; + ++DiagI; + for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note; + ++DiagI) { + if (!handleProtectedNote(*DiagI)) + handledAllNotes = false; + } -static void handleProtectedScopeError(CapturedDiagList::iterator &DiagI, - CapturedDiagList::iterator DiagE, - llvm::SmallVectorImpl &Cases, - TransformActions &TA) { - Transaction Trans(TA); - assert(DiagI->getID() == diag::err_switch_into_protected_scope); - SourceLocation ErrLoc = DiagI->getLocation(); - bool handledAllNotes = true; - ++DiagI; - for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note; - ++DiagI) { - if (!handleProtectedNote(*DiagI, Cases, TA)) - handledAllNotes = false; + if (handledAllNotes) + Pass.TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc); } - if (handledAllNotes) - TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc); -} + bool handleProtectedNote(const StoredDiagnostic &Diag) { + assert(Diag.getLevel() == DiagnosticsEngine::Note); -void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) { - MigrationPass &Pass = BodyCtx.getMigrationContext().Pass; - SmallVector Cases; - CaseCollector(Cases).TraverseStmt(BodyCtx.getTopStmt()); - - SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange(); - const CapturedDiagList &DiagList = Pass.getDiags(); - CapturedDiagList::iterator I = DiagList.begin(), E = DiagList.end(); - while (I != E) { - if (I->getID() == diag::err_switch_into_protected_scope && - isInRange(I->getLocation(), BodyRange)) { - handleProtectedScopeError(I, E, Cases, Pass.TA); - continue; + for (unsigned i = 0; i != Cases.size(); i++) { + CaseInfo &info = Cases[i]; + if (isInRange(Diag.getLocation(), info.Range)) { + + if (info.State == CaseInfo::St_Unchecked) + tryFixing(info); + assert(info.State != CaseInfo::St_Unchecked); + + if (info.State == CaseInfo::St_Fixed) { + Pass.TA.clearDiagnostic(Diag.getID(), Diag.getLocation()); + return true; + } + return false; + } + } + + return false; + } + + void tryFixing(CaseInfo &info) { + assert(info.State == CaseInfo::St_Unchecked); + if (hasVarReferencedOutside(info)) { + info.State = CaseInfo::St_CannotFix; + return; } - ++I; + + Pass.TA.insertAfterToken(info.SC->getColonLoc(), " {"); + Pass.TA.insert(info.Range.getEnd(), "}\n"); + info.State = CaseInfo::St_Fixed; + } + + bool hasVarReferencedOutside(CaseInfo &info) { + for (unsigned i = 0, e = LocalRefs.size(); i != e; ++i) { + DeclRefExpr *DRE = LocalRefs[i]; + if (isInRange(DRE->getDecl()->getLocation(), info.Range) && + !isInRange(DRE->getLocation(), info.Range)) + return true; + } + return false; + } + + bool isInRange(SourceLocation Loc, SourceRange R) { + if (Loc.isInvalid()) + return false; + return !SM.isBeforeInTranslationUnit(Loc, R.getBegin()) && + SM.isBeforeInTranslationUnit(Loc, R.getEnd()); } +}; + +} // anonymous namespace + +void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) { + ProtectedScopeFixer Fix(BodyCtx); } diff --git a/test/ARCMT/checking.m b/test/ARCMT/checking.m index b0d3243254..9fd50029d0 100644 --- a/test/ARCMT/checking.m +++ b/test/ARCMT/checking.m @@ -181,9 +181,10 @@ void test6(unsigned cond) { switch (cond) { case 0: ; - id x; + id x; // expected-note {{jump bypasses initialization of retaining variable}} - case 1: + case 1: // expected-error {{switch case is in protected scope}} + x = 0; break; } } diff --git a/test/ARCMT/protected-scope.m b/test/ARCMT/protected-scope.m index b33382ed50..8aece44d4f 100644 --- a/test/ARCMT/protected-scope.m +++ b/test/ARCMT/protected-scope.m @@ -18,6 +18,7 @@ void test(id p, int x) { id w3 = p; break; case 2: + case 3: break; default: break; diff --git a/test/ARCMT/protected-scope.m.result b/test/ARCMT/protected-scope.m.result index 42d58b8221..f385d8825d 100644 --- a/test/ARCMT/protected-scope.m.result +++ b/test/ARCMT/protected-scope.m.result @@ -20,6 +20,7 @@ void test(id p, int x) { break; } case 2: + case 3: break; default: break;