From 3d102c8ddd12ab268a8484ee21f828ef5cd8f2c9 Mon Sep 17 00:00:00 2001
From: Alex Lorenz <arphaman@gmail.com>
Date: Wed, 1 Nov 2017 00:07:12 +0000
Subject: [PATCH] [refactor][selection] code ranges can be selected in objc
 methods

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@317054 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Tooling/Refactoring/ASTSelection.cpp | 11 ++--
 unittests/Tooling/ASTSelectionTest.cpp   | 76 ++++++++++++++++++++++++
 2 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/lib/Tooling/Refactoring/ASTSelection.cpp b/lib/Tooling/Refactoring/ASTSelection.cpp
index 6ac432622c..71a0d44be1 100644
--- a/lib/Tooling/Refactoring/ASTSelection.cpp
+++ b/lib/Tooling/Refactoring/ASTSelection.cpp
@@ -347,6 +347,11 @@ CodeRangeASTSelection::create(SourceRange SelectionRange,
                                /*AreChildrenSelected=*/true);
 }
 
+static bool isFunctionLikeDeclaration(const Decl *D) {
+  // FIXME (Alex L): Test for BlockDecl.
+  return isa<FunctionDecl>(D) || isa<ObjCMethodDecl>(D);
+}
+
 bool CodeRangeASTSelection::isInFunctionLikeBodyOfCode() const {
   bool IsPrevCompound = false;
   // Scan through the parents (bottom-to-top) and check if the selection is
@@ -355,8 +360,7 @@ bool CodeRangeASTSelection::isInFunctionLikeBodyOfCode() const {
   for (const auto &Parent : llvm::reverse(Parents)) {
     const DynTypedNode &Node = Parent.get().Node;
     if (const auto *D = Node.get<Decl>()) {
-      // FIXME (Alex L): Test for BlockDecl && ObjCMethodDecl.
-      if (isa<FunctionDecl>(D))
+      if (isFunctionLikeDeclaration(D))
         return IsPrevCompound;
       // FIXME (Alex L): We should return false on top-level decls in functions
       // e.g. we don't want to extract:
@@ -372,8 +376,7 @@ const Decl *CodeRangeASTSelection::getFunctionLikeNearestParent() const {
   for (const auto &Parent : llvm::reverse(Parents)) {
     const DynTypedNode &Node = Parent.get().Node;
     if (const auto *D = Node.get<Decl>()) {
-      // FIXME (Alex L): Test for BlockDecl && ObjCMethodDecl.
-      if (isa<FunctionDecl>(D))
+      if (isFunctionLikeDeclaration(D))
         return D;
     }
   }
diff --git a/unittests/Tooling/ASTSelectionTest.cpp b/unittests/Tooling/ASTSelectionTest.cpp
index 94435d49a8..1435334d6c 100644
--- a/unittests/Tooling/ASTSelectionTest.cpp
+++ b/unittests/Tooling/ASTSelectionTest.cpp
@@ -896,4 +896,80 @@ void f(int x, int y) {
       });
 }
 
+TEST(ASTSelectionFinder, SimpleCodeRangeASTSelectionInObjCMethod) {
+  StringRef Source = R"(@interface I @end
+@implementation I
+- (void) f:(int)x with:(int) y {
+  int z = x;
+  [self f: 2 with: 3];
+  if (x == 0) {
+    return;
+  }
+  x = 1;
+  return;
+}
+- (void)f2 {
+  int m = 0;
+}
+@end
+)";
+  // Range that spans multiple methods is an invalid code range.
+  findSelectedASTNodesWithRange(
+      Source, {9, 2}, FileRange{{9, 2}, {13, 1}},
+      [](SourceRange SelectionRange, Optional<SelectedASTNode> Node) {
+        EXPECT_TRUE(Node);
+        Optional<CodeRangeASTSelection> SelectedCode =
+            CodeRangeASTSelection::create(SelectionRange, std::move(*Node));
+        EXPECT_FALSE(SelectedCode);
+      },
+      SelectionFinderVisitor::Lang_OBJC);
+  // Just 'z = x;':
+  findSelectedASTNodesWithRange(
+      Source, {4, 2}, FileRange{{4, 2}, {4, 13}},
+      [](SourceRange SelectionRange, Optional<SelectedASTNode> Node) {
+        EXPECT_TRUE(Node);
+        Optional<CodeRangeASTSelection> SelectedCode =
+            CodeRangeASTSelection::create(SelectionRange, std::move(*Node));
+        EXPECT_TRUE(SelectedCode);
+        EXPECT_EQ(SelectedCode->size(), 1u);
+        EXPECT_TRUE(isa<DeclStmt>((*SelectedCode)[0]));
+        ArrayRef<SelectedASTNode::ReferenceType> Parents =
+            SelectedCode->getParents();
+        EXPECT_EQ(Parents.size(), 4u);
+        EXPECT_TRUE(
+            isa<TranslationUnitDecl>(Parents[0].get().Node.get<Decl>()));
+        // 'I' @implementation.
+        EXPECT_TRUE(isa<ObjCImplDecl>(Parents[1].get().Node.get<Decl>()));
+        // Function 'f' definition.
+        EXPECT_TRUE(isa<ObjCMethodDecl>(Parents[2].get().Node.get<Decl>()));
+        // Function body of function 'F'.
+        EXPECT_TRUE(isa<CompoundStmt>(Parents[3].get().Node.get<Stmt>()));
+      },
+      SelectionFinderVisitor::Lang_OBJC);
+  // From '[self f: 2 with: 3]' until just before 'x = 1;':
+  findSelectedASTNodesWithRange(
+      Source, {5, 2}, FileRange{{5, 2}, {9, 1}},
+      [](SourceRange SelectionRange, Optional<SelectedASTNode> Node) {
+        EXPECT_TRUE(Node);
+        Optional<CodeRangeASTSelection> SelectedCode =
+            CodeRangeASTSelection::create(SelectionRange, std::move(*Node));
+        EXPECT_TRUE(SelectedCode);
+        EXPECT_EQ(SelectedCode->size(), 2u);
+        EXPECT_TRUE(isa<ObjCMessageExpr>((*SelectedCode)[0]));
+        EXPECT_TRUE(isa<IfStmt>((*SelectedCode)[1]));
+        ArrayRef<SelectedASTNode::ReferenceType> Parents =
+            SelectedCode->getParents();
+        EXPECT_EQ(Parents.size(), 4u);
+        EXPECT_TRUE(
+            isa<TranslationUnitDecl>(Parents[0].get().Node.get<Decl>()));
+        // 'I' @implementation.
+        EXPECT_TRUE(isa<ObjCImplDecl>(Parents[1].get().Node.get<Decl>()));
+        // Function 'f' definition.
+        EXPECT_TRUE(isa<ObjCMethodDecl>(Parents[2].get().Node.get<Decl>()));
+        // Function body of function 'F'.
+        EXPECT_TRUE(isa<CompoundStmt>(Parents[3].get().Node.get<Stmt>()));
+      },
+      SelectionFinderVisitor::Lang_OBJC);
+}
+
 } // end anonymous namespace
-- 
2.40.0