From fe7beeb421b90f66b4a96c2358c5a838c376a0a8 Mon Sep 17 00:00:00 2001 From: Daniel Jasper Date: Mon, 16 Jul 2012 09:18:17 +0000 Subject: [PATCH] Add refactoring callbacks to make common kinds of refactorings easy. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@160255 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../clang/ASTMatchers/RefactoringCallbacks.h | 90 +++++++++++++++++ include/clang/Tooling/Tooling.h | 6 +- lib/ASTMatchers/CMakeLists.txt | 1 + lib/ASTMatchers/RefactoringCallbacks.cpp | 78 +++++++++++++++ unittests/ASTMatchers/CMakeLists.txt | 3 +- .../ASTMatchers/RefactoringCallbacksTest.cpp | 98 +++++++++++++++++++ 6 files changed, 273 insertions(+), 3 deletions(-) create mode 100644 include/clang/ASTMatchers/RefactoringCallbacks.h create mode 100644 lib/ASTMatchers/RefactoringCallbacks.cpp create mode 100644 unittests/ASTMatchers/RefactoringCallbacksTest.cpp diff --git a/include/clang/ASTMatchers/RefactoringCallbacks.h b/include/clang/ASTMatchers/RefactoringCallbacks.h new file mode 100644 index 0000000000..5d9c99ff2c --- /dev/null +++ b/include/clang/ASTMatchers/RefactoringCallbacks.h @@ -0,0 +1,90 @@ +//===--- RefactoringCallbacks.h - Structural query framework ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Provides callbacks to make common kinds of refactorings easy. +// +// The general idea is to construct a matcher expression that describes a +// subtree match on the AST and then replace the corresponding source code +// either by some specific text or some other AST node. +// +// Example: +// int main(int argc, char **argv) { +// ClangTool Tool(argc, argv); +// MatchFinder Finder; +// ReplaceStmtWithText Callback("integer", "42"); +// Finder.AddMatcher(id("integer", expression(integerLiteral())), Callback); +// return Tool.run(newFrontendActionFactory(&Finder)); +// } +// +// This will replace all integer literals with "42". +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CLANG_AST_MATCHERS_REFACTORING_CALLBACKS_H +#define LLVM_CLANG_AST_MATCHERS_REFACTORING_CALLBACKS_H + +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/Tooling/Refactoring.h" + +namespace clang { +namespace ast_matchers { + +/// \brief Base class for RefactoringCallbacks. +/// +/// Collects \c tooling::Replacements while running. +class RefactoringCallback : public MatchFinder::MatchCallback { +public: + RefactoringCallback(); + tooling::Replacements &getReplacements(); + +protected: + tooling::Replacements Replace; +}; + +/// \brief Replace the text of the statement bound to \c FromId with the text in +/// \c ToText. +class ReplaceStmtWithText : public RefactoringCallback { +public: + ReplaceStmtWithText(StringRef FromId, StringRef ToText); + virtual void run(const MatchFinder::MatchResult &Result); + +private: + std::string FromId; + std::string ToText; +}; + +/// \brief Replace the text of the statement bound to \c FromId with the text of +/// the statement bound to \c ToId. +class ReplaceStmtWithStmt : public RefactoringCallback { +public: + ReplaceStmtWithStmt(StringRef FromId, StringRef ToId); + virtual void run(const MatchFinder::MatchResult &Result); + +private: + std::string FromId; + std::string ToId; +}; + +/// \brief Replace an if-statement bound to \c Id with the outdented text of its +/// body, choosing the consequent or the alternative based on whether +/// \c PickTrueBranch is true. +class ReplaceIfStmtWithItsBody : public RefactoringCallback { +public: + ReplaceIfStmtWithItsBody(StringRef Id, bool PickTrueBranch); + virtual void run(const MatchFinder::MatchResult &Result); + +private: + std::string Id; + const bool PickTrueBranch; +}; + +} // end namespace ast_matchers +} // end namespace clang + +#endif // LLVM_CLANG_AST_MATCHERS_REFACTORING_CALLBACKS_H diff --git a/include/clang/Tooling/Tooling.h b/include/clang/Tooling/Tooling.h index 03f9d0b4bc..e06705f027 100644 --- a/include/clang/Tooling/Tooling.h +++ b/include/clang/Tooling/Tooling.h @@ -86,7 +86,8 @@ FrontendActionFactory *newFrontendActionFactory(); /// FrontendActionFactory *FactoryAdapter = /// newFrontendActionFactory(&Factory); template -FrontendActionFactory *newFrontendActionFactory(FactoryT *ConsumerFactory); +inline FrontendActionFactory *newFrontendActionFactory( + FactoryT *ConsumerFactory); /// \brief Runs (and deletes) the tool on 'Code' with the -fsyntax-only flag. /// @@ -202,7 +203,8 @@ FrontendActionFactory *newFrontendActionFactory() { } template -FrontendActionFactory *newFrontendActionFactory(FactoryT *ConsumerFactory) { +inline FrontendActionFactory *newFrontendActionFactory( + FactoryT *ConsumerFactory) { class FrontendActionFactoryAdapter : public FrontendActionFactory { public: explicit FrontendActionFactoryAdapter(FactoryT *ConsumerFactory) diff --git a/lib/ASTMatchers/CMakeLists.txt b/lib/ASTMatchers/CMakeLists.txt index 8fc7d4b208..ac7988d86f 100644 --- a/lib/ASTMatchers/CMakeLists.txt +++ b/lib/ASTMatchers/CMakeLists.txt @@ -4,6 +4,7 @@ set(LLVM_USED_LIBS clangBasic clangAST) add_clang_library(clangASTMatchers ASTMatchFinder.cpp ASTMatchersInternal.cpp + RefactoringCallbacks.cpp ) add_dependencies(clangASTMatchers diff --git a/lib/ASTMatchers/RefactoringCallbacks.cpp b/lib/ASTMatchers/RefactoringCallbacks.cpp new file mode 100644 index 0000000000..e747cd7bdc --- /dev/null +++ b/lib/ASTMatchers/RefactoringCallbacks.cpp @@ -0,0 +1,78 @@ +//===--- RefactoringCallbacks.cpp - Structural query framework ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// +#include "clang/Lex/Lexer.h" +#include "clang/ASTMatchers/RefactoringCallbacks.h" + +namespace clang { +namespace ast_matchers { + +RefactoringCallback::RefactoringCallback() {} +tooling::Replacements &RefactoringCallback::getReplacements() { + return Replace; +} + +static tooling::Replacement replaceStmtWithText(SourceManager &Sources, + const Stmt &From, + StringRef Text) { + return tooling::Replacement(Sources, CharSourceRange::getTokenRange( + From.getSourceRange()), Text); +} +static tooling::Replacement replaceStmtWithStmt(SourceManager &Sources, + const Stmt &From, + const Stmt &To) { + return replaceStmtWithText(Sources, From, Lexer::getSourceText( + CharSourceRange::getTokenRange(To.getSourceRange()), + Sources, LangOptions())); +} + +ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) + : FromId(FromId), ToText(ToText) {} + +void ReplaceStmtWithText::run(const MatchFinder::MatchResult &Result) { + if (const Stmt *FromMatch = Result.Nodes.getStmtAs(FromId)) { + Replace.insert(tooling::Replacement( + *Result.SourceManager, + CharSourceRange::getTokenRange(FromMatch->getSourceRange()), + ToText)); + } +} + +ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId) + : FromId(FromId), ToId(ToId) {} + +void ReplaceStmtWithStmt::run(const MatchFinder::MatchResult &Result) { + const Stmt *FromMatch = Result.Nodes.getStmtAs(FromId); + const Stmt *ToMatch = Result.Nodes.getStmtAs(ToId); + if (FromMatch && ToMatch) + Replace.insert(replaceStmtWithStmt( + *Result.SourceManager, *FromMatch, *ToMatch)); +} + +ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id, + bool PickTrueBranch) + : Id(Id), PickTrueBranch(PickTrueBranch) {} + +void ReplaceIfStmtWithItsBody::run(const MatchFinder::MatchResult &Result) { + if (const IfStmt *Node = Result.Nodes.getStmtAs(Id)) { + const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse(); + if (Body) { + Replace.insert(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body)); + } else if (!PickTrueBranch) { + // If we want to use the 'else'-branch, but it doesn't exist, delete + // the whole 'if'. + Replace.insert(replaceStmtWithText(*Result.SourceManager, *Node, "")); + } + } +} + +} // end namespace ast_matchers +} // end namespace clang diff --git a/unittests/ASTMatchers/CMakeLists.txt b/unittests/ASTMatchers/CMakeLists.txt index 8e61732141..9e02634778 100644 --- a/unittests/ASTMatchers/CMakeLists.txt +++ b/unittests/ASTMatchers/CMakeLists.txt @@ -1,5 +1,6 @@ add_clang_unittest(ASTMatchersTests - ASTMatchersTest.cpp) + ASTMatchersTest.cpp + RefactoringCallbacksTest.cpp) target_link_libraries(ASTMatchersTests gtest gtest_main clangASTMatchers clangTooling) diff --git a/unittests/ASTMatchers/RefactoringCallbacksTest.cpp b/unittests/ASTMatchers/RefactoringCallbacksTest.cpp new file mode 100644 index 0000000000..bb9f504a03 --- /dev/null +++ b/unittests/ASTMatchers/RefactoringCallbacksTest.cpp @@ -0,0 +1,98 @@ +//===- unittest/ASTMatchers/RefactoringCallbacksTest.cpp ------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "clang/ASTMatchers/ASTMatchers.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" +#include "clang/ASTMatchers/RefactoringCallbacks.h" +#include "../Tooling/RewriterTestContext.h" +#include "gtest/gtest.h" + +namespace clang { +namespace ast_matchers { + +template +void expectRewritten(const std::string &Code, + const std::string &Expected, + const T &AMatcher, + RefactoringCallback &Callback) { + MatchFinder Finder; + Finder.addMatcher(AMatcher, &Callback); + OwningPtr Factory( + tooling::newFrontendActionFactory(&Finder)); + ASSERT_TRUE(tooling::runToolOnCode(Factory->create(), Code)) + << "Parsing error in \"" << Code << "\""; + RewriterTestContext Context; + FileID ID = Context.createInMemoryFile("input.cc", Code); + EXPECT_TRUE(tooling::applyAllReplacements(Callback.getReplacements(), + Context.Rewrite)); + EXPECT_EQ(Expected, Context.getRewrittenText(ID)); +} + +TEST(RefactoringCallbacksTest, ReplacesStmtsWithString) { + std::string Code = "void f() { int i = 1; }"; + std::string Expected = "void f() { ; }"; + ReplaceStmtWithText Callback("id", ";"); + expectRewritten(Code, Expected, id("id", declarationStatement()), Callback); +} + +TEST(RefactoringCallbacksTest, ReplacesStmtsInCalledMacros) { + std::string Code = "#define A void f() { int i = 1; }\nA"; + std::string Expected = "#define A void f() { ; }\nA"; + ReplaceStmtWithText Callback("id", ";"); + expectRewritten(Code, Expected, id("id", declarationStatement()), Callback); +} + +TEST(RefactoringCallbacksTest, IgnoresStmtsInUncalledMacros) { + std::string Code = "#define A void f() { int i = 1; }"; + std::string Expected = "#define A void f() { int i = 1; }"; + ReplaceStmtWithText Callback("id", ";"); + expectRewritten(Code, Expected, id("id", declarationStatement()), Callback); +} + +TEST(RefactoringCallbacksTest, ReplacesInteger) { + std::string Code = "void f() { int i = 1; }"; + std::string Expected = "void f() { int i = 2; }"; + ReplaceStmtWithText Callback("id", "2"); + expectRewritten(Code, Expected, id("id", expression(integerLiteral())), + Callback); +} + +TEST(RefactoringCallbacksTest, ReplacesStmtWithStmt) { + std::string Code = "void f() { int i = false ? 1 : i * 2; }"; + std::string Expected = "void f() { int i = i * 2; }"; + ReplaceStmtWithStmt Callback("always-false", "should-be"); + expectRewritten(Code, Expected, + id("always-false", conditionalOperator( + hasCondition(boolLiteral(equals(false))), + hasFalseExpression(id("should-be", expression())))), + Callback); +} + +TEST(RefactoringCallbacksTest, ReplacesIfStmt) { + std::string Code = "bool a; void f() { if (a) f(); else a = true; }"; + std::string Expected = "bool a; void f() { f(); }"; + ReplaceIfStmtWithItsBody Callback("id", true); + expectRewritten(Code, Expected, + id("id", ifStmt( + hasCondition(implicitCast(hasSourceExpression( + declarationReference(to(variable(hasName("a"))))))))), + Callback); +} + +TEST(RefactoringCallbacksTest, RemovesEntireIfOnEmptyElse) { + std::string Code = "void f() { if (false) int i = 0; }"; + std::string Expected = "void f() { }"; + ReplaceIfStmtWithItsBody Callback("id", false); + expectRewritten(Code, Expected, + id("id", ifStmt(hasCondition(boolLiteral(equals(false))))), + Callback); +} + +} // end namespace ast_matchers +} // end namespace clang -- 2.40.0