]> granicus.if.org Git - clang/commitdiff
During codegen of a virtual call we would extract any casts in the expression
authorRafael Espindola <rafael.espindola@gmail.com>
Tue, 26 Jun 2012 17:45:31 +0000 (17:45 +0000)
committerRafael Espindola <rafael.espindola@gmail.com>
Tue, 26 Jun 2012 17:45:31 +0000 (17:45 +0000)
to see if we had an underlying final class or method, but we would then
use the cast type to do the call, resulting in a direct call to the wrong
method.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@159212 91177308-0d34-0410-b5e6-96231b3b80d8

include/clang/AST/DeclCXX.h
include/clang/AST/Expr.h
lib/AST/DeclCXX.cpp
lib/AST/Expr.cpp
lib/CodeGen/CGExprCXX.cpp
lib/Sema/SemaExpr.cpp
test/CodeGenCXX/devirtualize-virtual-function-calls-final.cpp

index 6fe71c8bb1f2c42160357d4805a9a875dbf078ff..8fde3ed257e00c4c8de221f481f4ddd3fec62f50 100644 (file)
@@ -1630,7 +1630,20 @@ public:
   /// supplied by IR generation to either forward to the function call operator
   /// or clone the function call operator.
   bool isLambdaStaticInvoker() const;
-  
+
+  /// \brief Find the method in RD that corresponds to this one.
+  ///
+  /// Find if RD or one of the classes it inherits from override this method.
+  /// If so, return it. RD is assumed to be a base class of the class defining
+  /// this method (or be the class itself).
+  CXXMethodDecl *
+  getCorrespondingMethodInClass(const CXXRecordDecl *RD);
+
+  const CXXMethodDecl *
+  getCorrespondingMethodInClass(const CXXRecordDecl *RD) const {
+    return const_cast<CXXMethodDecl*>(this)->getCorrespondingMethodInClass(RD);
+  }
+
   // Implement isa/cast/dyncast/etc.
   static bool classof(const Decl *D) { return classofKind(D->getKind()); }
   static bool classof(const CXXMethodDecl *D) { return true; }
index cdf9a58e8ca532ff4da3221d41d2dda5f3abfc76..77684eee00c4ef7f99387f56a5a27d497aafb1a0 100644 (file)
@@ -665,6 +665,13 @@ public:
 
   static bool hasAnyTypeDependentArguments(llvm::ArrayRef<Expr *> Exprs);
 
+  /// \brief If we have class type (or pointer to class type), return the
+  /// class decl. Return NULL otherwise.
+  ///
+  /// If this expression is a cast, this method looks through it to find the
+  /// most derived decl that can be infered from the expression.
+  const CXXRecordDecl *getMostDerivedClassDeclForType() const;
+
   static bool classof(const Stmt *T) {
     return T->getStmtClass() >= firstExprConstant &&
            T->getStmtClass() <= lastExprConstant;
index 3d76e68ff78fc2dc5ca435194a56ba8358d019c7..a8aabb68d759c894e89876a034036e15d5c4eca6 100644 (file)
@@ -1269,6 +1269,55 @@ bool CXXRecordDecl::mayBeAbstract() const {
 
 void CXXMethodDecl::anchor() { }
 
+static bool recursivelyOverrides(const CXXMethodDecl *DerivedMD,
+                                 const CXXMethodDecl *BaseMD) {
+  for (CXXMethodDecl::method_iterator I = DerivedMD->begin_overridden_methods(),
+         E = DerivedMD->end_overridden_methods(); I != E; ++I) {
+    const CXXMethodDecl *MD = *I;
+    if (MD->getCanonicalDecl() == BaseMD->getCanonicalDecl())
+      return true;
+    if (recursivelyOverrides(MD, BaseMD))
+      return true;
+  }
+  return false;
+}
+
+CXXMethodDecl *
+CXXMethodDecl::getCorrespondingMethodInClass(const CXXRecordDecl *RD) {
+  if (this->getParent()->getCanonicalDecl() == RD->getCanonicalDecl())
+    return this;
+
+  // Lookup doesn't work for destructors, so handle them separately.
+  if (isa<CXXDestructorDecl>(this)) {
+    CXXMethodDecl *MD = RD->getDestructor();
+    if (recursivelyOverrides(MD, this))
+      return MD;
+    return NULL;
+  }
+
+  lookup_const_result Candidates = RD->lookup(getDeclName());
+  for (NamedDecl * const * I = Candidates.first; I != Candidates.second; ++I) {
+    CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(*I);
+    if (!MD)
+      continue;
+    if (recursivelyOverrides(MD, this))
+      return MD;
+  }
+
+  for (CXXRecordDecl::base_class_const_iterator I = RD->bases_begin(),
+         E = RD->bases_end(); I != E; ++I) {
+    const RecordType *RT = I->getType()->getAs<RecordType>();
+    if (!RT)
+      continue;
+    const CXXRecordDecl *Base = cast<CXXRecordDecl>(RT->getDecl());
+    CXXMethodDecl *T = this->getCorrespondingMethodInClass(Base);
+    if (T)
+      return T;
+  }
+
+  return NULL;
+}
+
 CXXMethodDecl *
 CXXMethodDecl::Create(ASTContext &C, CXXRecordDecl *RD,
                       SourceLocation StartLoc,
index 3de4b5771a252164eb696bceffd0e2b43c53320e..22d15be6a87c8908f2381ba42e42b900a2715d1c 100644 (file)
 #include <cstring>
 using namespace clang;
 
+const CXXRecordDecl *Expr::getMostDerivedClassDeclForType() const {
+  const Expr *E = this;
+
+  while (true) {
+    E = E->IgnoreParens();
+    if (const CastExpr *CE = dyn_cast<CastExpr>(E)) {
+      if (CE->getCastKind() == CK_DerivedToBase ||
+          CE->getCastKind() == CK_UncheckedDerivedToBase ||
+          CE->getCastKind() == CK_NoOp) {
+        E = CE->getSubExpr();
+        continue;
+      }
+    }
+
+    break;
+  }
+
+  QualType DerivedType = E->getType();
+  if (DerivedType->isDependentType())
+    return NULL;
+  if (const PointerType *PTy = DerivedType->getAs<PointerType>())
+    DerivedType = PTy->getPointeeType();
+
+  const RecordType *Ty = DerivedType->castAs<RecordType>();
+  if (!Ty)
+    return NULL;
+
+  Decl *D = Ty->getDecl();
+  return cast<CXXRecordDecl>(D);
+}
+
 /// isKnownToHaveBooleanValue - Return true if this is an integer expression
 /// that is known to return 0 or 1.  This happens for _Bool/bool expressions
 /// but also int expressions which are produced by things like comparisons in
index f1d03957140b30224541789ae58e14c258dd0d80..f35287d5406f8b977c34bc4fab92c2553c2af4a5 100644 (file)
@@ -56,30 +56,6 @@ RValue CodeGenFunction::EmitCXXMemberCall(const CXXMethodDecl *MD,
                   Callee, ReturnValue, Args, MD);
 }
 
-static const CXXRecordDecl *getMostDerivedClassDecl(const Expr *Base) {
-  const Expr *E = Base;
-  
-  while (true) {
-    E = E->IgnoreParens();
-    if (const CastExpr *CE = dyn_cast<CastExpr>(E)) {
-      if (CE->getCastKind() == CK_DerivedToBase || 
-          CE->getCastKind() == CK_UncheckedDerivedToBase ||
-          CE->getCastKind() == CK_NoOp) {
-        E = CE->getSubExpr();
-        continue;
-      }
-    }
-
-    break;
-  }
-
-  QualType DerivedType = E->getType();
-  if (const PointerType *PTy = DerivedType->getAs<PointerType>())
-    DerivedType = PTy->getPointeeType();
-
-  return cast<CXXRecordDecl>(DerivedType->castAs<RecordType>()->getDecl());
-}
-
 // FIXME: Ideally Expr::IgnoreParenNoopCasts should do this, but it doesn't do
 // quite what we want.
 static const Expr *skipNoOpCastsAndParens(const Expr *E) {
@@ -126,7 +102,8 @@ static bool canDevirtualizeMemberFunctionCalls(ASTContext &Context,
   //   b->f();
   // }
   //
-  const CXXRecordDecl *MostDerivedClassDecl = getMostDerivedClassDecl(Base);
+  const CXXRecordDecl *MostDerivedClassDecl =
+    Base->getMostDerivedClassDeclForType();
   if (MostDerivedClassDecl->hasAttr<FinalAttr>())
     return true;
 
@@ -247,10 +224,13 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
   //
   // We also don't emit a virtual call if the base expression has a record type
   // because then we know what the type is.
-  bool UseVirtualCall;
-  UseVirtualCall = MD->isVirtual() && !ME->hasQualifier()
-                   && !canDevirtualizeMemberFunctionCalls(getContext(),
-                                                          ME->getBase(), MD);
+  const Expr *Base = ME->getBase();
+  bool UseVirtualCall = MD->isVirtual() && !ME->hasQualifier()
+                        && !canDevirtualizeMemberFunctionCalls(getContext(),
+                                                               Base, MD);
+  const CXXRecordDecl *MostDerivedClassDecl =
+    Base->getMostDerivedClassDeclForType();
+
   llvm::Value *Callee;
   if (const CXXDestructorDecl *Dtor = dyn_cast<CXXDestructorDecl>(MD)) {
     if (UseVirtualCall) {
@@ -260,8 +240,13 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
           MD->isVirtual() &&
           ME->hasQualifier())
         Callee = BuildAppleKextVirtualCall(MD, ME->getQualifier(), Ty);
-      else
-        Callee = CGM.GetAddrOfFunction(GlobalDecl(Dtor, Dtor_Complete), Ty);
+      else {
+        const CXXMethodDecl *DM =
+          Dtor->getCorrespondingMethodInClass(MostDerivedClassDecl);
+        assert(DM);
+        const CXXDestructorDecl *DDtor = cast<CXXDestructorDecl>(DM);
+        Callee = CGM.GetAddrOfFunction(GlobalDecl(DDtor, Dtor_Complete), Ty);
+      }
     }
   } else if (const CXXConstructorDecl *Ctor =
                dyn_cast<CXXConstructorDecl>(MD)) {
@@ -273,8 +258,12 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE,
         MD->isVirtual() &&
         ME->hasQualifier())
       Callee = BuildAppleKextVirtualCall(MD, ME->getQualifier(), Ty);
-    else 
-      Callee = CGM.GetAddrOfFunction(MD, Ty);
+    else {
+      const CXXMethodDecl *DerivedMethod =
+        MD->getCorrespondingMethodInClass(MostDerivedClassDecl);
+      assert(DerivedMethod);
+      Callee = CGM.GetAddrOfFunction(DerivedMethod, Ty);
+    }
   }
 
   return EmitCXXMemberCall(MD, Callee, ReturnValue, This, /*VTT=*/0,
index 07f3c1d6c755e1e325652b5fe499bcdf17ffa8a1..4dd6f0b5524af31e52ca853ea06f07491f47d9f0 100644 (file)
@@ -10844,6 +10844,22 @@ static void MarkExprReferenced(Sema &SemaRef, SourceLocation Loc,
   }
 
   SemaRef.MarkAnyDeclReferenced(Loc, D);
+
+  // If this is a call to a method via a cast, also mark the method in the
+  // derived class used in case codegen can devirtualize the call.
+  const MemberExpr *ME = dyn_cast<MemberExpr>(E);
+  if (!ME)
+    return;
+  CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(ME->getMemberDecl());
+  if (!MD)
+    return;
+  const Expr *Base = ME->getBase();
+  const CXXRecordDecl *MostDerivedClassDecl
+    = Base->getMostDerivedClassDeclForType();
+  if (!MostDerivedClassDecl)
+    return;
+  CXXMethodDecl *DM = MD->getCorrespondingMethodInClass(MostDerivedClassDecl);
+  SemaRef.MarkAnyDeclReferenced(Loc, DM);
 } 
 
 /// \brief Perform reference-marking and odr-use handling for a DeclRefExpr.
index 3de75ed3db541b6b17ec67148a7b9c238a6267fb..f7b10f647bf49a7708c1a46f33b78d70401a660b 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 %s -emit-llvm -o - | FileCheck %s
+// RUN: %clang_cc1 -std=c++11 %s -emit-llvm -o - | FileCheck %s
 
 namespace Test1 {
   struct A {
@@ -49,3 +49,61 @@ namespace Test3 {
     return static_cast<B*>(v)->f();
   }
 }
+
+namespace Test4 {
+  struct A {
+    virtual void f();
+  };
+
+  struct B final : A {
+    virtual void f();
+  };
+
+  // CHECK: define void @_ZN5Test41fEPNS_1BE
+  void f(B* d) {
+    // CHECK: call void @_ZN5Test41B1fEv
+    static_cast<A*>(d)->f();
+  }
+}
+
+namespace Test5 {
+  struct A {
+    virtual void f();
+  };
+
+  struct B : A {
+    virtual void f();
+  };
+
+  struct C final : B {
+  };
+
+  // CHECK: define void @_ZN5Test51fEPNS_1CE
+  void f(C* d) {
+    // CHECK: call void @_ZN5Test51B1fEv
+    static_cast<A*>(d)->f();
+  }
+}
+
+namespace Test6 {
+  struct A {
+    virtual ~A();
+  };
+
+  struct B : public A {
+    virtual ~B();
+  };
+
+  struct C {
+    virtual ~C();
+  };
+
+  struct D final : public C, public B {
+  };
+
+  // CHECK: define void @_ZN5Test61fEPNS_1DE
+  void f(D* d) {
+    // CHECK: call void @_ZN5Test61DD1Ev
+    static_cast<A*>(d)->~A();
+  }
+}