]> granicus.if.org Git - clang/commitdiff
Change GetAddressCXXOfBaseClass to use CXXBasePaths for calculating base class offset...
authorAnders Carlsson <andersca@mac.com>
Tue, 6 Oct 2009 22:43:30 +0000 (22:43 +0000)
committerAnders Carlsson <andersca@mac.com>
Tue, 6 Oct 2009 22:43:30 +0000 (22:43 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@83426 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/CGCXX.cpp
lib/CodeGen/CGCXXClass.cpp
lib/CodeGen/CodeGenFunction.h
test/CodeGenCXX/virtual-base-cast.cpp [new file with mode: 0644]

index a119c5af932e0c19ab064523b681774ea1af4ccc..20b2bdcd5e188fe5542796facba6d0c1deffdde4 100644 (file)
@@ -782,8 +782,8 @@ public:
       return i->second;
     // FIXME: temporal botch, is this data here, by the time we need it?
 
-    // FIXME: Locate the containing virtual base first.
-    return 42;
+    assert(false && "FIXME: Locate the containing virtual base first");
+    return 0;
   }
 
   bool OverrideMethod(const CXXMethodDecl *MD, llvm::Constant *m,
@@ -888,18 +888,26 @@ public:
       const CXXRecordDecl *RD = i->first;
       int64_t Offset = i->second;
       for (method_iter mi = RD->method_begin(), me = RD->method_end(); mi != me;
-           ++mi)
-        if (mi->isVirtual()) {
-          const CXXMethodDecl *MD = *mi;
+           ++mi) {
+        if (!mi->isVirtual())
+          continue;
+
+        const CXXMethodDecl *MD = *mi;
+        llvm::Constant *m = 0;
+//        if (const CXXDestructorDecl *Dtor = dyn_cast<CXXDestructorDecl>(MD))
+//          m = wrap(CGM.GetAddrOfCXXDestructor(Dtor, Dtor_Complete));
+//        else {
           const FunctionProtoType *FPT = 
             MD->getType()->getAs<FunctionProtoType>();
           const llvm::Type *Ty =
             CGM.getTypes().GetFunctionType(CGM.getTypes().getFunctionInfo(MD),
                                            FPT->isVariadic());
           
-          llvm::Constant *m = wrap(CGM.GetAddrOfFunction(MD, Ty));
-          OverrideMethod(MD, m, MorallyVirtual, Offset);
-        }
+          m = wrap(CGM.GetAddrOfFunction(MD, Ty));
+//        }
+
+        OverrideMethod(MD, m, MorallyVirtual, Offset);
+      }
     }
   }
 
@@ -1322,6 +1330,36 @@ llvm::Constant *CodeGenModule::BuildCovariantThunk(const CXXMethodDecl *MD,
   return m;
 }
 
+llvm::Value *
+CodeGenFunction::GetVirtualCXXBaseClassOffset(llvm::Value *This,
+                                              const CXXRecordDecl *ClassDecl,
+                                           const CXXRecordDecl *BaseClassDecl) {
+  // FIXME: move to Context
+  if (vtableinfo == 0)
+    vtableinfo = new VtableInfo(CGM);
+  
+  const llvm::Type *Int8PtrTy = 
+    llvm::Type::getInt8Ty(VMContext)->getPointerTo();
+
+  llvm::Value *VTablePtr = Builder.CreateBitCast(This, 
+                                                 Int8PtrTy->getPointerTo());
+  VTablePtr = Builder.CreateLoad(VTablePtr, "vtable");
+
+  llvm::Value *VBaseOffsetPtr = 
+    Builder.CreateConstGEP1_64(VTablePtr, 
+                               vtableinfo->VBlookup(ClassDecl, BaseClassDecl),
+                               "vbase.offset.ptr");
+  const llvm::Type *PtrDiffTy = 
+    ConvertType(getContext().getPointerDiffType());
+  
+  VBaseOffsetPtr = Builder.CreateBitCast(VBaseOffsetPtr, 
+                                         PtrDiffTy->getPointerTo());
+                                         
+  llvm::Value *VBaseOffset = Builder.CreateLoad(VBaseOffsetPtr, "vbase.offset");
+  
+  return VBaseOffset;
+}
+
 llvm::Value *
 CodeGenFunction::BuildVirtualCall(const CXXMethodDecl *MD, llvm::Value *&This,
                                   const llvm::Type *Ty) {
index 9c8174bc22f71e1ccad46ae646ef2500b6ec256f..ff879f5786ffc7bc11f312d4f3b10f4e72a752f2 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "CodeGenFunction.h"
+#include "clang/AST/CXXInheritance.h"
 #include "clang/AST/RecordLayout.h"
+
 using namespace clang;
 using namespace CodeGen;
 
-static bool
-GetNestedPaths(llvm::SmallVectorImpl<const CXXRecordDecl *> &NestedBasePaths,
-               const CXXRecordDecl *ClassDecl,
-               const CXXRecordDecl *BaseClassDecl) {
-  for (CXXRecordDecl::base_class_const_iterator i = ClassDecl->bases_begin(),
-      e = ClassDecl->bases_end(); i != e; ++i) {
-    if (i->isVirtual())
-      continue;
-    const CXXRecordDecl *Base =
-      cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
-    if (Base == BaseClassDecl) {
-      NestedBasePaths.push_back(BaseClassDecl);
-      return true;
-    }
-  }
-  // BaseClassDecl not an immediate base of ClassDecl.
-  for (CXXRecordDecl::base_class_const_iterator i = ClassDecl->bases_begin(),
-       e = ClassDecl->bases_end(); i != e; ++i) {
-    if (i->isVirtual())
-      continue;
-    const CXXRecordDecl *Base =
-      cast<CXXRecordDecl>(i->getType()->getAs<RecordType>()->getDecl());
-    if (GetNestedPaths(NestedBasePaths, Base, BaseClassDecl)) {
-      NestedBasePaths.push_back(Base);
-      return true;
-    }
-  }
-  return false;
-}
+static uint64_t 
+ComputeNonVirtualBaseClassOffset(ASTContext &Context, CXXBasePaths &Paths,
+                                 unsigned Start) {
+  uint64_t Offset = 0;
 
-static uint64_t ComputeBaseClassOffset(ASTContext &Context,
-                                       const CXXRecordDecl *ClassDecl,
-                                       const CXXRecordDecl *BaseClassDecl) {
-    uint64_t Offset = 0;
+  const CXXBasePath &Path = Paths.front();
+  for (unsigned i = Start, e = Path.size(); i != e; ++i) {
+    const CXXBasePathElement& Element = Path[i];
 
-    llvm::SmallVector<const CXXRecordDecl *, 16> NestedBasePaths;
-    GetNestedPaths(NestedBasePaths, ClassDecl, BaseClassDecl);
-    assert(NestedBasePaths.size() > 0 &&
-           "AddressCXXOfBaseClass - inheritence path failed");
-    NestedBasePaths.push_back(ClassDecl);
+    // Get the layout.
+    const ASTRecordLayout &Layout = Context.getASTRecordLayout(Element.Class);
     
-    for (unsigned i = NestedBasePaths.size() - 1; i > 0; i--) {
-        const CXXRecordDecl *DerivedClass = NestedBasePaths[i];
-        const CXXRecordDecl *BaseClass = NestedBasePaths[i-1];
-        const ASTRecordLayout &Layout = 
-            Context.getASTRecordLayout(DerivedClass);
-        
-        Offset += Layout.getBaseClassOffset(BaseClass) / 8;
-    }
+    const CXXBaseSpecifier *BS = Element.Base;
+    assert(!BS->isVirtual() && "Should not see virtual bases here!");
     
-    return Offset;
+    const CXXRecordDecl *Base = 
+      cast<CXXRecordDecl>(BS->getType()->getAs<RecordType>()->getDecl());
+    
+    // Add the offset.
+    Offset += Layout.getBaseClassOffset(Base) / 8;
+  }
+
+  return Offset;
 }
 
 llvm::Constant *
@@ -75,12 +49,15 @@ CodeGenModule::GetCXXBaseClassOffset(const CXXRecordDecl *ClassDecl,
   if (ClassDecl == BaseClassDecl)
     return 0;
 
-  QualType BTy =
-    getContext().getCanonicalType(
-      getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(BaseClassDecl)));
+  CXXBasePaths Paths(/*FindAmbiguities=*/false,
+                     /*RecordPaths=*/true, /*DetectVirtual=*/false);
+  if (!const_cast<CXXRecordDecl *>(ClassDecl)->
+        isDerivedFrom(const_cast<CXXRecordDecl *>(BaseClassDecl), Paths)) {
+    assert(false && "Class must be derived from the passed in base class!");
+    return 0;
+  }
 
-  uint64_t Offset = ComputeBaseClassOffset(getContext(), 
-                                           ClassDecl, BaseClassDecl);
+  uint64_t Offset = ComputeNonVirtualBaseClassOffset(getContext(), Paths, 0);
   if (!Offset)
     return 0;
 
@@ -90,19 +67,63 @@ CodeGenModule::GetCXXBaseClassOffset(const CXXRecordDecl *ClassDecl,
   return llvm::ConstantInt::get(PtrDiffTy, Offset);
 }
 
+static llvm::Value *GetCXXBaseClassOffset(CodeGenFunction &CGF,
+                                          llvm::Value *BaseValue,
+                                          const CXXRecordDecl *ClassDecl,
+                                          const CXXRecordDecl *BaseClassDecl) {
+  CXXBasePaths Paths(/*FindAmbiguities=*/false,
+                     /*RecordPaths=*/true, /*DetectVirtual=*/true);
+  if (!const_cast<CXXRecordDecl *>(ClassDecl)->
+        isDerivedFrom(const_cast<CXXRecordDecl *>(BaseClassDecl), Paths)) {
+    assert(false && "Class must be derived from the passed in base class!");
+    return 0;
+  }
+
+  unsigned Start = 0;
+  llvm::Value *VirtualOffset = 0;
+  if (const RecordType *RT = Paths.getDetectedVirtual()) {
+    const CXXRecordDecl *VBase = cast<CXXRecordDecl>(RT->getDecl());
+    
+    VirtualOffset = 
+      CGF.GetVirtualCXXBaseClassOffset(BaseValue, ClassDecl, VBase);
+    
+    const CXXBasePath &Path = Paths.front();
+    unsigned e = Path.size();
+    for (Start = 0; Start != e; ++Start) {
+      const CXXBasePathElement& Element = Path[Start];
+      
+      if (Element.Class == VBase)
+        break;
+    }
+  }
+  
+  uint64_t Offset = 
+    ComputeNonVirtualBaseClassOffset(CGF.getContext(), Paths, Start);
+  
+  if (!Offset)
+    return VirtualOffset;
+  
+  const llvm::Type *PtrDiffTy = 
+    CGF.ConvertType(CGF.getContext().getPointerDiffType());
+  llvm::Value *NonVirtualOffset = llvm::ConstantInt::get(PtrDiffTy, Offset);
+  
+  if (VirtualOffset)
+    return CGF.Builder.CreateAdd(VirtualOffset, NonVirtualOffset);
+                    
+  return NonVirtualOffset;
+}
+
 llvm::Value *
 CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue,
                                           const CXXRecordDecl *ClassDecl,
                                           const CXXRecordDecl *BaseClassDecl,
                                           bool NullCheckValue) {
-  llvm::Constant *Offset = CGM.GetCXXBaseClassOffset(ClassDecl, BaseClassDecl);
-  
   QualType BTy =
     getContext().getCanonicalType(
       getContext().getTypeDeclType(const_cast<CXXRecordDecl*>(BaseClassDecl)));
   const llvm::Type *BasePtrTy = llvm::PointerType::getUnqual(ConvertType(BTy));
 
-  if (!Offset) {
+  if (ClassDecl == BaseClassDecl) {
     // Just cast back.
     return Builder.CreateBitCast(BaseValue, BasePtrTy);
   }
@@ -125,10 +146,15 @@ CodeGenFunction::GetAddressCXXOfBaseClass(llvm::Value *BaseValue,
   
   const llvm::Type *Int8PtrTy = 
     llvm::PointerType::getUnqual(llvm::Type::getInt8Ty(VMContext));
+
+  llvm::Value *Offset = 
+    GetCXXBaseClassOffset(*this, BaseValue, ClassDecl, BaseClassDecl);
   
-  // Apply the offset.
-  BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy);
-  BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr");
+  if (Offset) {
+    // Apply the offset.
+    BaseValue = Builder.CreateBitCast(BaseValue, Int8PtrTy);
+    BaseValue = Builder.CreateGEP(BaseValue, Offset, "add.ptr");
+  }
   
   // Cast back.
   BaseValue = Builder.CreateBitCast(BaseValue, BasePtrTy);
index 34b3860a48f8d3428b0330598164b579be19bdfd..42de9fb62e688043985da94802eff022f0b13c4e 100644 (file)
@@ -589,6 +589,11 @@ public:
                                         const CXXRecordDecl *BaseClassDecl,
                                         bool NullCheckValue);
   
+  llvm::Value *
+  GetVirtualCXXBaseClassOffset(llvm::Value *This,
+                               const CXXRecordDecl *ClassDecl,
+                               const CXXRecordDecl *BaseClassDecl);
+    
   void EmitClassAggrMemberwiseCopy(llvm::Value *DestValue,
                                    llvm::Value *SrcValue,
                                    const ArrayType *Array,
diff --git a/test/CodeGenCXX/virtual-base-cast.cpp b/test/CodeGenCXX/virtual-base-cast.cpp
new file mode 100644 (file)
index 0000000..9a728a8
--- /dev/null
@@ -0,0 +1,9 @@
+// RUN: clang-cc -emit-llvm-only %s
+
+struct A { virtual ~A(); };
+struct B : A { virtual ~B(); };
+struct C : virtual B { virtual ~C(); };
+
+void f(C *c) {
+  A* a = c;
+}
\ No newline at end of file