]> granicus.if.org Git - clang/commitdiff
Implement virtual dispatch. :-) This is self-consistent with clang, but not yet
authorMike Stump <mrs@apple.com>
Wed, 26 Aug 2009 01:54:35 +0000 (01:54 +0000)
committerMike Stump <mrs@apple.com>
Wed, 26 Aug 2009 01:54:35 +0000 (01:54 +0000)
necessarily perfectly consistent with gcc.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@80064 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/CGCXX.cpp
lib/CodeGen/CodeGenFunction.h
test/CodeGenCXX/virt.cpp

index 3efda13db8959417889fddd41aa800578e1f88a8..4bf6a49774eba7f2351f506de8ff4ac90f456377 100644 (file)
@@ -200,15 +200,9 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE) {
 
   const FunctionProtoType *FPT = MD->getType()->getAsFunctionProtoType();
 
-  if (MD->isVirtual()) {
-    ErrorUnsupported(CE, "virtual dispatch");
-  }
-
   const llvm::Type *Ty = 
     CGM.getTypes().GetFunctionType(CGM.getTypes().getFunctionInfo(MD), 
                                    FPT->isVariadic());
-  llvm::Constant *Callee = CGM.GetAddrOfFunction(GlobalDecl(MD), Ty);
-  
   llvm::Value *This;
   
   if (ME->isArrow())
@@ -217,6 +211,12 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE) {
     LValue BaseLV = EmitLValue(ME->getBase());
     This = BaseLV.getAddress();
   }
+
+  llvm::Value *Callee;
+  if (MD->isVirtual())
+    Callee = BuildVirtualCall(MD, This, Ty);
+  else
+    Callee = CGM.GetAddrOfFunction(GlobalDecl(MD), Ty);
   
   return EmitCXXMemberCall(MD, Callee, This, 
                            CE->arg_begin(), CE->arg_end());
@@ -826,6 +826,10 @@ llvm::Constant *CodeGenModule::GenerateRtti(const CXXRecordDecl *RD) {
 }
 
 class VtableBuilder {
+public:
+  /// Index_t - Vtable index type.
+  typedef uint64_t Index_t;
+private:
   std::vector<llvm::Constant *> &methods;
   llvm::Type *Ptr8Ty;
   /// Class - The most derived class that this vtable is being built for.
@@ -840,7 +844,7 @@ class VtableBuilder {
   CodeGenModule &CGM;  // Per-module state.
   /// Index - Maps a method decl into a vtable index.  Useful for virtual
   /// dispatch codegen.
-  llvm::DenseMap<const CXXMethodDecl *, int32_t> Index;
+  llvm::DenseMap<const CXXMethodDecl *, Index_t> Index;
   typedef CXXRecordDecl::method_iterator method_iter;
 public:
   VtableBuilder(std::vector<llvm::Constant *> &meth,
@@ -852,6 +856,7 @@ public:
     Ptr8Ty = llvm::PointerType::get(llvm::Type::getInt8Ty(VMContext), 0);
   }
 
+  llvm::DenseMap<const CXXMethodDecl *, Index_t> &getIndex() { return Index; }
   llvm::Constant *GenerateVcall(const CXXMethodDecl *MD,
                                 const CXXRecordDecl *RD,
                                 bool VBoundary,
@@ -932,17 +937,17 @@ public:
     SeenVBase.clear();
   }
 
-  inline uint32_t nottoobig(uint64_t t) {
-    assert(t < (uint32_t)-1ULL || "vtable too big");
+  inline Index_t nottoobig(uint64_t t) {
+    assert(t < (Index_t)-1ULL || "vtable too big");
     return t;
   }
 #if 0
-  inline uint32_t nottoobig(uint32_t t) {
+  inline Index_t nottoobig(Index_t t) {
     return t;
   }
 #endif
 
-  void AddMethod(const CXXMethodDecl *MD, int32_t FirstIndex) {
+  void AddMethod(const CXXMethodDecl *MD, Index_t AddressPoint) {
     typedef CXXMethodDecl::method_iterator meth_iter;
 
     llvm::Constant *m;
@@ -963,34 +968,34 @@ public:
       om = CGM.GetAddrOfFunction(GlobalDecl(OMD), Ptr8Ty);
       om = llvm::ConstantExpr::getBitCast(om, Ptr8Ty);
 
-      for (int32_t i = FirstIndex, e = nottoobig(methods.size()); i != e; ++i) {
+      for (Index_t i = AddressPoint, e = methods.size();
+           i != e; ++i) {
         // FIXME: begin_overridden_methods might be too lax, covariance */
         if (methods[i] == om) {
           methods[i] = m;
-          Index[MD] = i;
+          Index[MD] = i - AddressPoint;
           return;
         }
       }
     }
 
     // else allocate a new slot.
-    Index[MD] = methods.size();
+    Index[MD] = methods.size() - AddressPoint;
     methods.push_back(m);
   }
 
-  void GenerateMethods(const CXXRecordDecl *RD, int32_t FirstIndex) {
+  void GenerateMethods(const CXXRecordDecl *RD, Index_t AddressPoint) {
     for (method_iter mi = RD->method_begin(), me = RD->method_end(); mi != me;
          ++mi)
       if (mi->isVirtual())
-        AddMethod(*mi, FirstIndex);
+        AddMethod(*mi, AddressPoint);
   }
 
   int64_t GenerateVtableForBase(const CXXRecordDecl *RD,
                                 bool forPrimary,
                                 bool VBoundary,
                                 int64_t Offset,
-                                bool ForVirtualBase,
-                                int32_t FirstIndex) {
+                                bool ForVirtualBase) {
     llvm::Constant *m = llvm::Constant::getNullValue(Ptr8Ty);
     int64_t AddressPoint=0;
 
@@ -1023,8 +1028,9 @@ public:
       if (PrimaryBaseWasVirtual)
         IndirectPrimary.insert(PrimaryBase);
       Top = false;
-      AddressPoint = GenerateVtableForBase(PrimaryBase, true, PrimaryBaseWasVirtual|VBoundary,
-                                           Offset, PrimaryBaseWasVirtual, FirstIndex);
+      AddressPoint = GenerateVtableForBase(PrimaryBase, true,
+                                           PrimaryBaseWasVirtual|VBoundary,
+                                           Offset, PrimaryBaseWasVirtual);
     }
 
     if (Top) {
@@ -1041,7 +1047,7 @@ public:
     }
 
     // And add the virtuals for the class to the primary vtable.
-    GenerateMethods(RD, FirstIndex);
+    GenerateMethods(RD, AddressPoint);
 
     // and then the non-virtual bases.
     for (CXXRecordDecl::base_class_const_iterator i = RD->bases_begin(),
@@ -1053,8 +1059,7 @@ public:
       if (Base != PrimaryBase || PrimaryBaseWasVirtual) {
         uint64_t o = Offset + Layout.getBaseClassOffset(Base);
         StartNewTable();
-        FirstIndex = methods.size();
-        GenerateVtableForBase(Base, true, false, o, false, FirstIndex);
+        GenerateVtableForBase(Base, true, false, o, false);
       }
     }
     return AddressPoint;
@@ -1071,8 +1076,7 @@ public:
         IndirectPrimary.insert(Base);
         StartNewTable();
         int64_t BaseOffset = BLayout.getVBaseClassOffset(Base);
-        int32_t FirstIndex = methods.size();
-        GenerateVtableForBase(Base, false, true, BaseOffset, true, FirstIndex);
+        GenerateVtableForBase(Base, false, true, BaseOffset, true);
       }
       if (Base->getNumVBases())
         GenerateVtableForVBases(Base, Class);
@@ -1080,6 +1084,43 @@ public:
   }
 };
 
+class VtableInfo {
+public:
+  typedef VtableBuilder::Index_t Index_t;
+private:
+  CodeGenModule &CGM;  // Per-module state.
+  /// Index_t - Vtable index type.
+  typedef llvm::DenseMap<const CXXMethodDecl *, Index_t> ElTy;
+  typedef llvm::DenseMap<const CXXRecordDecl *, ElTy *> MapTy;
+  // FIXME: Move to Context.
+  static MapTy IndexFor;
+public:
+  VtableInfo(CodeGenModule &cgm) : CGM(cgm) { }
+  void register_index(const CXXRecordDecl *RD, const ElTy &e) {
+    assert(IndexFor.find(RD) == IndexFor.end() || "Don't compute vtbl twice");
+    // We own a copy of this, it will go away shortly.
+    new ElTy (e);
+    IndexFor[RD] = new ElTy (e);
+  }
+  Index_t lookup(const CXXMethodDecl *MD) {
+    const CXXRecordDecl *RD = MD->getParent();
+    MapTy::iterator I = IndexFor.find(RD);
+    if (I == IndexFor.end()) {
+      std::vector<llvm::Constant *> methods;
+      VtableBuilder b(methods, RD, CGM);
+      b.GenerateVtableForBase(RD, true, false, 0, false);
+      b.GenerateVtableForVBases(RD, RD);
+      register_index(RD, b.getIndex());
+      I = IndexFor.find(RD);
+    }
+    assert(I->second->find(MD)!=I->second->end() || "Can't find vtable index");
+    return (*I->second)[MD];
+  }
+};
+
+// FIXME: Move to Context.
+VtableInfo::MapTy VtableInfo::IndexFor;
+
 llvm::Value *CodeGenFunction::GenerateVtable(const CXXRecordDecl *RD) {
   llvm::SmallString<256> OutName;
   llvm::raw_svector_ostream Out(OutName);
@@ -1095,7 +1136,7 @@ llvm::Value *CodeGenFunction::GenerateVtable(const CXXRecordDecl *RD) {
   VtableBuilder b(methods, RD, CGM);
 
   // First comes the vtables for all the non-virtual bases...
-  Offset = b.GenerateVtableForBase(RD, true, false, 0, false, 0);
+  Offset = b.GenerateVtableForBase(RD, true, false, 0, false);
 
   // then the vtables for all the virtual bases.
   b.GenerateVtableForVBases(RD, RD);
@@ -1112,6 +1153,31 @@ llvm::Value *CodeGenFunction::GenerateVtable(const CXXRecordDecl *RD) {
   return vtable;
 }
 
+// FIXME: move to Context
+static VtableInfo *vtableinfo;
+
+llvm::Value *
+CodeGenFunction::BuildVirtualCall(const CXXMethodDecl *MD, llvm::Value *&This,
+                                  const llvm::Type *Ty) {
+  // FIXME: If we know the dynamic type, we don't have to do a virtual dispatch.
+  
+  // FIXME: move to Context
+  if (vtableinfo == 0)
+    vtableinfo = new VtableInfo(CGM);
+
+  VtableInfo::Index_t Idx = vtableinfo->lookup(MD);
+
+  Ty = llvm::PointerType::get(Ty, 0);
+  Ty = llvm::PointerType::get(Ty, 0);
+  Ty = llvm::PointerType::get(Ty, 0);
+  llvm::Value *vtbl = Builder.CreateBitCast(This, Ty);
+  vtbl = Builder.CreateLoad(vtbl);
+  llvm::Value *vfn = Builder.CreateConstInBoundsGEP1_64(vtbl,
+                                                        Idx, "vfn");
+  vfn = Builder.CreateLoad(vfn);
+  return vfn;
+}
+
 /// EmitClassAggrMemberwiseCopy - This routine generates code to copy a class
 /// array of objects from SrcValue to DestValue. Copying can be either a bitwise
 /// copy or via a copy constructor call.
index 4fc6c7c773f36925df54bef4f0706241cca75916..e8f0cc5de5b79c3ab5ca69095d85e2502f923f5e 100644 (file)
@@ -826,6 +826,8 @@ public:
                   const Decl *TargetDecl = 0);
   RValue EmitCallExpr(const CallExpr *E);
   
+  llvm::Value *BuildVirtualCall(const CXXMethodDecl *MD, llvm::Value *&This,
+                                const llvm::Type *Ty);
   RValue EmitCXXMemberCall(const CXXMethodDecl *MD,
                            llvm::Value *Callee,
                            llvm::Value *This,
index 89411dea0166053eca11d3042b58b78e0b6d77e5..4583e0aa7cb264d54f38e3d1fd0690e2f744f350 100644 (file)
@@ -91,6 +91,71 @@ int main() {
 // CHECK-LP64: movl $1, 12(%rax)
 // CHECK-LP64: movl $2, 8(%rax)
 
+struct test12_A {
+  virtual void foo0() { }
+  virtual void foo() { }
+} *test12_pa;
+
+struct test12_B : public test12_A {
+  virtual void foo() { }
+} *test12_pb;
+
+struct test12_D : public test12_B {
+} *test12_pd;
+void test12_foo() {
+  test12_pa->foo0();
+  test12_pb->foo0();
+  test12_pd->foo0();
+  test12_pa->foo();
+  test12_pb->foo();
+  test12_pd->foo();
+}
+
+// CHECK-LPOPT32:__Z10test12_foov:
+// CHECK-LPOPT32: movl _test12_pa, %eax
+// CHECK-LPOPT32-NEXT: movl (%eax), %ecx
+// CHECK-LPOPT32-NEXT: movl %eax, (%esp)
+// CHECK-LPOPT32-NEXT: call *(%ecx)
+// CHECK-LPOPT32-NEXT: movl _test12_pb, %eax
+// CHECK-LPOPT32-NEXT: movl (%eax), %ecx
+// CHECK-LPOPT32-NEXT: movl %eax, (%esp)
+// CHECK-LPOPT32-NEXT: call *(%ecx)
+// CHECK-LPOPT32-NEXT: movl _test12_pd, %eax
+// CHECK-LPOPT32-NEXT: movl (%eax), %ecx
+// CHECK-LPOPT32-NEXT: movl %eax, (%esp)
+// CHECK-LPOPT32-NEXT: call *(%ecx)
+// CHECK-LPOPT32-NEXT: movl _test12_pa, %eax
+// CHECK-LPOPT32-NEXT: movl (%eax), %ecx
+// CHECK-LPOPT32-NEXT: movl %eax, (%esp)
+// CHECK-LPOPT32-NEXT: call *4(%ecx)
+// CHECK-LPOPT32-NEXT: movl _test12_pb, %eax
+// CHECK-LPOPT32-NEXT: movl (%eax), %ecx
+// CHECK-LPOPT32-NEXT: movl %eax, (%esp)
+// CHECK-LPOPT32-NEXT: call *4(%ecx)
+// CHECK-LPOPT32-NEXT: movl _test12_pd, %eax
+// CHECK-LPOPT32-NEXT: movl (%eax), %ecx
+// CHECK-LPOPT32-NEXT: movl %eax, (%esp)
+// CHECK-LPOPT32-NEXT: call *4(%ecx)
+
+// CHECK-LPOPT64:__Z10test12_foov:
+// CHECK-LPOPT64: movq _test12_pa(%rip), %rdi
+// CHECK-LPOPT64-NEXT: movq (%rdi), %rax
+// CHECK-LPOPT64-NEXT: call *(%rax)
+// CHECK-LPOPT64-NEXT: movq _test12_pb(%rip), %rdi
+// CHECK-LPOPT64-NEXT: movq (%rdi), %rax
+// CHECK-LPOPT64-NEXT: call *(%rax)
+// CHECK-LPOPT64-NEXT: movq _test12_pd(%rip), %rdi
+// CHECK-LPOPT64-NEXT: movq (%rdi), %rax
+// CHECK-LPOPT64-NEXT: call *(%rax)
+// CHECK-LPOPT64-NEXT: movq _test12_pa(%rip), %rdi
+// CHECK-LPOPT64-NEXT: movq (%rdi), %rax
+// CHECK-LPOPT64-NEXT: call *8(%rax)
+// CHECK-LPOPT64-NEXT: movq _test12_pb(%rip), %rdi
+// CHECK-LPOPT64-NEXT: movq (%rdi), %rax
+// CHECK-LPOPT64-NEXT: call *8(%rax)
+// CHECK-LPOPT64-NEXT: movq _test12_pd(%rip), %rdi
+// CHECK-LPOPT64-NEXT: movq (%rdi), %rax
+// CHECK-LPOPT64-NEXT: call *8(%rax)
 
 struct test6_B2 { virtual void funcB2(); char b[1000]; };
 struct test6_B1 : virtual test6_B2 { virtual void funcB1(); };
@@ -115,7 +180,7 @@ struct test3_B3 { virtual void funcB3(); };
 struct test3_B2 : virtual test3_B3 { virtual void funcB2(); };
 struct test3_B1 : virtual test3_B2 { virtual void funcB1(); };
 
-struct test3_D  : virtual test3_B1 {
+struct test3_D : virtual test3_B1 {
   virtual void funcD() { }
 };
 
@@ -652,7 +717,6 @@ struct test11_D : test11_B {
 // CHECK-LP64-NEXT: .quad __ZN8test11_D2D2Ev
 
 
-
 // CHECK-LP64: __ZTV1B:
 // CHECK-LP64-NEXT: .space 8
 // CHECK-LP64-NEXT: .quad __ZTI1B