]> granicus.if.org Git - clang/commitdiff
Tighten up blocks rewriter to handle casts and some other interesting cases.
authorSteve Naroff <snaroff@apple.com>
Wed, 15 Oct 2008 18:38:58 +0000 (18:38 +0000)
committerSteve Naroff <snaroff@apple.com>
Wed, 15 Oct 2008 18:38:58 +0000 (18:38 +0000)
This fixes <rdar://problem/6289007> clang block rewriter: ^ in cast is not rewritten.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@57591 91177308-0d34-0410-b5e6-96231b3b80d8

Driver/RewriteBlocks.cpp
test/Rewriter/block-test.c

index 05cbb704d1251fb951965b14dab3d12d18bc1732..2e7ce0559f76b33cc2a014584ca0222a46afb47b 100644 (file)
@@ -140,8 +140,12 @@ public:
   void RewriteCategoryDecl(ObjCCategoryDecl *CatDecl);
   void RewriteProtocolDecl(ObjCProtocolDecl *PDecl);
   void RewriteMethodDecl(ObjCMethodDecl *MDecl);
+
+  void RewriteFunctionTypeProto(QualType funcType, NamedDecl *D);
+  void CheckFunctionPointerDecl(QualType dType, NamedDecl *ND);
+  void RewriteCastExpr(CastExpr *CE);
   
-  bool BlockPointerTypeTakesAnyBlockArguments(QualType QT);
+  bool PointerTypeTakesAnyBlockArguments(QualType QT);
   void GetExtentOfArgList(const char *Name, const char *&LParen, const char *&RParen);
 };
   
@@ -738,6 +742,28 @@ void RewriteBlocks::RewriteBlockDeclRefExpr(BlockDeclRefExpr *BDRE) {
   InsertText(BDRE->getLocStart(), "*", 1);
 }
 
+void RewriteBlocks::RewriteCastExpr(CastExpr *CE) {
+  SourceLocation LocStart = CE->getLocStart();
+  SourceLocation LocEnd = CE->getLocEnd();
+  
+  const char *startBuf = SM->getCharacterData(LocStart);
+  const char *endBuf = SM->getCharacterData(LocEnd);
+  
+  // advance the location to startArgList.
+  const char *argPtr = startBuf;
+  
+  while (*argPtr++ && (argPtr < endBuf)) {
+    switch (*argPtr) {
+      case '^': 
+        // Replace the '^' with '*'.
+        LocStart = LocStart.getFileLocWithOffset(argPtr-startBuf);
+        ReplaceText(LocStart, 1, "*", 1);
+        break;
+    }
+  }
+  return;
+}
+
 void RewriteBlocks::RewriteBlockPointerFunctionArgs(FunctionDecl *FD) {
   SourceLocation DeclLoc = FD->getLocation();
   unsigned parenCount = 0;
@@ -773,10 +799,17 @@ void RewriteBlocks::RewriteBlockPointerFunctionArgs(FunctionDecl *FD) {
   return;
 }
 
-bool RewriteBlocks::BlockPointerTypeTakesAnyBlockArguments(QualType QT) {
-  const BlockPointerType *BPT = QT->getAsBlockPointerType();
-  assert(BPT && "BlockPointerTypeTakeAnyBlockArguments(): not a block pointer type");
-  const FunctionTypeProto *FTP = BPT->getPointeeType()->getAsFunctionTypeProto();
+bool RewriteBlocks::PointerTypeTakesAnyBlockArguments(QualType QT) {
+  const FunctionTypeProto *FTP;
+  const PointerType *PT = QT->getAsPointerType();
+  if (PT) {
+    FTP = PT->getPointeeType()->getAsFunctionTypeProto();
+    assert(FTP && "BlockPointerTypeTakeAnyBlockArguments(): not a function pointer type");
+  } else {
+    const BlockPointerType *BPT = QT->getAsBlockPointerType();
+    assert(BPT && "BlockPointerTypeTakeAnyBlockArguments(): not a block pointer type");
+    FTP = BPT->getPointeeType()->getAsFunctionTypeProto();
+  }
   if (FTP) {
     for (FunctionTypeProto::arg_type_iterator I = FTP->arg_type_begin(), 
          E = FTP->arg_type_end(); I != E; ++I)
@@ -829,13 +862,15 @@ void RewriteBlocks::RewriteBlockPointerDecl(NamedDecl *ND) {
   // scan backward (from the decl location) for the end of the previous decl.
   while (*startBuf != '^' && *startBuf != ';' && startBuf != MainFileStart)
     startBuf--;
-  assert((*startBuf == '^') && 
-         "RewriteBlockPointerDecl() scan error: no caret");
-  // Replace the '^' with '*', computing a negative offset.
-  DeclLoc = DeclLoc.getFileLocWithOffset(startBuf-endBuf);
-  ReplaceText(DeclLoc, 1, "*", 1);
-  
-  if (BlockPointerTypeTakesAnyBlockArguments(DeclT)) {
+    
+  // *startBuf != '^' if we are dealing with a pointer to function that
+  // may take block argument types (which will be handled below).
+  if (*startBuf == '^') {
+    // Replace the '^' with '*', computing a negative offset.
+    DeclLoc = DeclLoc.getFileLocWithOffset(startBuf-endBuf);
+    ReplaceText(DeclLoc, 1, "*", 1);
+  }
+  if (PointerTypeTakesAnyBlockArguments(DeclT)) {
     // Replace the '^' with '*' for arguments.
     DeclLoc = ND->getLocation();
     startBuf = SM->getCharacterData(DeclLoc);
@@ -981,6 +1016,9 @@ Stmt *RewriteBlocks::RewriteFunctionBody(Stmt *S) {
     if (CE->getCallee()->getType()->isBlockPointerType())
       RewriteBlockCall(CE);
   }
+  if (CastExpr *CE = dyn_cast<CastExpr>(S)) {
+    RewriteCastExpr(CE);
+  }
   if (DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
     for (DeclStmt::decl_iterator DI = DS->decl_begin(), DE = DS->decl_end();
          DI != DE; ++DI) {
@@ -989,10 +1027,14 @@ Stmt *RewriteBlocks::RewriteFunctionBody(Stmt *S) {
       if (ValueDecl *ND = dyn_cast<ValueDecl>(SD)) {
         if (isBlockPointerType(ND->getType()))
           RewriteBlockPointerDecl(ND);
+        else if (ND->getType()->isFunctionPointerType()) 
+          CheckFunctionPointerDecl(ND->getType(), ND);
       }
       if (TypedefDecl *TD = dyn_cast<TypedefDecl>(SD)) {
         if (isBlockPointerType(TD->getUnderlyingType()))
           RewriteBlockPointerDecl(TD);
+        else if (TD->getUnderlyingType()->isFunctionPointerType()) 
+          CheckFunctionPointerDecl(TD->getUnderlyingType(), TD);
       }
     }
   }
@@ -1005,25 +1047,33 @@ Stmt *RewriteBlocks::RewriteFunctionBody(Stmt *S) {
   return S;
 }
 
+void RewriteBlocks::RewriteFunctionTypeProto(QualType funcType, NamedDecl *D) {    
+  if (FunctionTypeProto *fproto = dyn_cast<FunctionTypeProto>(funcType)) {
+    for (FunctionTypeProto::arg_type_iterator I = fproto->arg_type_begin(), 
+         E = fproto->arg_type_end(); I && (I != E); ++I)
+      if (isBlockPointerType(*I)) {
+        // All the args are checked/rewritten. Don't call twice!
+        RewriteBlockPointerDecl(D);
+        break;
+      }
+  }
+}
+
+void RewriteBlocks::CheckFunctionPointerDecl(QualType funcType, NamedDecl *ND) {
+  const PointerType *PT = funcType->getAsPointerType();
+  if (PT && PointerTypeTakesAnyBlockArguments(funcType))
+    RewriteFunctionTypeProto(PT->getPointeeType(), ND);
+}
+
 /// HandleDeclInMainFile - This is called for each top-level decl defined in the
 /// main file of the input.
 void RewriteBlocks::HandleDeclInMainFile(Decl *D) {
   if (FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
-  
     // Since function prototypes don't have ParmDecl's, we check the function
     // prototype. This enables us to rewrite function declarations and
     // definitions using the same code.
-    QualType funcType = FD->getType();
+    RewriteFunctionTypeProto(FD->getType(), FD);
     
-    if (FunctionTypeProto *fproto = dyn_cast<FunctionTypeProto>(funcType)) {
-      for (FunctionTypeProto::arg_type_iterator I = fproto->arg_type_begin(), 
-           E = fproto->arg_type_end(); I && (I != E); ++I)
-        if (isBlockPointerType(*I)) {
-          // All the args are checked/rewritten. Don't call twice!
-          RewriteBlockPointerDecl(FD);
-          break;
-        }
-    }
     if (Stmt *Body = FD->getBody()) {
       CurFunctionDef = FD;
       FD->setBody(RewriteFunctionBody(Body));
@@ -1058,6 +1108,15 @@ void RewriteBlocks::HandleDeclInMainFile(Decl *D) {
           // Do the rewrite, using S.size() which contains the rewritten size.
           ReplaceText(CBE->getLocStart(), S.size(), Init.c_str(), Init.size());
           SynthesizeBlockLiterals(VD->getTypeSpecStartLoc(), VD->getName());
+        } else if (CastExpr *CE = dyn_cast<CastExpr>(VD->getInit())) {
+          RewriteCastExpr(CE);
+        }
+      }
+    } else if (VD->getType()->isFunctionPointerType()) {
+      CheckFunctionPointerDecl(VD->getType(), VD);
+      if (VD->getInit()) {
+        if (CastExpr *CE = dyn_cast<CastExpr>(VD->getInit())) {
+          RewriteCastExpr(CE);
         }
       }
     }
@@ -1066,6 +1125,8 @@ void RewriteBlocks::HandleDeclInMainFile(Decl *D) {
   if (TypedefDecl *TD = dyn_cast<TypedefDecl>(D)) {
     if (isBlockPointerType(TD->getUnderlyingType()))
       RewriteBlockPointerDecl(TD);
+    else if (TD->getUnderlyingType()->isFunctionPointerType()) 
+      CheckFunctionPointerDecl(TD->getUnderlyingType(), TD);
     return;
   }
   if (RecordDecl *RD = dyn_cast<RecordDecl>(D)) {
index 0a6bde0886ad368518ebfdd6d8ba5409ba637cd7..82b63a09f9d1654e1653b091cb8338b0e27f4a64 100644 (file)
@@ -1,5 +1,16 @@
 // RUN: clang -rewrite-blocks %s -o -
 
+static int (^block)(const void *, const void *) = (int (^)(const void *, const void *))0;
+static int (*func)(int (^block)(void *, void *)) = (int (*)(int (^block)(void *, void *)))0;
+
+typedef int (^block_T)(const void *, const void *);
+typedef int (*func_T)(int (^block)(void *, void *));
+
+void foo(const void *a, const void *b, void *c) {
+    int (^block)(const void *, const void *) = (int (^)(const void *, const void *))c;
+    int (*func)(int (^block)(void *, void *)) = (int (*)(int (^block)(void *, void *)))c;
+}
+
 typedef void (^test_block_t)();
 
 int main(int argc, char **argv) {