From: Anders Carlsson Date: Sat, 3 Oct 2009 19:43:08 +0000 (+0000) Subject: Implement code generation of member function pointer calls. Fixes PR5121. X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=375c31c4673f83f925de221752cf801c2fbbb246;p=clang Implement code generation of member function pointer calls. Fixes PR5121. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@83271 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/CodeGen/CGCXX.cpp b/lib/CodeGen/CGCXX.cpp index 1ea60eff56..2d5c62e3e2 100644 --- a/lib/CodeGen/CGCXX.cpp +++ b/lib/CodeGen/CGCXX.cpp @@ -199,6 +199,9 @@ RValue CodeGenFunction::EmitCXXMemberCall(const CXXMethodDecl *MD, } RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE) { + if (isa(CE->getCallee())) + return EmitCXXMemberPointerCallExpr(CE); + const MemberExpr *ME = cast(CE->getCallee()); const CXXMethodDecl *MD = cast(ME->getMemberDecl()); @@ -241,6 +244,112 @@ RValue CodeGenFunction::EmitCXXMemberCallExpr(const CXXMemberCallExpr *CE) { CE->arg_begin(), CE->arg_end()); } +RValue +CodeGenFunction::EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E) { + const BinaryOperator *BO = cast(E->getCallee()); + const DeclRefExpr *BaseExpr = cast(BO->getLHS()); + const DeclRefExpr *MemFn = cast(BO->getRHS()); + + const MemberPointerType *MPT = MemFn->getType()->getAs(); + const FunctionProtoType *FPT = + MPT->getPointeeType()->getAs(); + const CXXRecordDecl *RD = + cast(cast(MPT->getClass())->getDecl()); + + const llvm::FunctionType *FTy = + CGM.getTypes().GetFunctionType(CGM.getTypes().getFunctionInfo(RD, FPT), + FPT->isVariadic()); + + const llvm::Type *Int8PtrTy = + llvm::Type::getInt8Ty(VMContext)->getPointerTo(); + + // Get the member function pointer. + llvm::Value *MemFnPtr = + CreateTempAlloca(ConvertType(MemFn->getType()), "mem.fn"); + EmitAggExpr(MemFn, MemFnPtr, /*VolatileDest=*/false); + + // Emit the 'this' pointer. + llvm::Value *This; + + if (BO->getOpcode() == BinaryOperator::PtrMemI) + This = EmitScalarExpr(BaseExpr); + else + This = EmitLValue(BaseExpr).getAddress(); + + // Adjust it. + llvm::Value *Adj = Builder.CreateStructGEP(MemFnPtr, 1); + Adj = Builder.CreateLoad(Adj, "mem.fn.adj"); + + llvm::Value *Ptr = Builder.CreateBitCast(This, Int8PtrTy, "ptr"); + Ptr = Builder.CreateGEP(Ptr, Adj, "adj"); + + This = Builder.CreateBitCast(Ptr, This->getType(), "this"); + + llvm::Value *FnPtr = Builder.CreateStructGEP(MemFnPtr, 0, "mem.fn.ptr"); + + const llvm::Type *PtrDiffTy = ConvertType(getContext().getPointerDiffType()); + + llvm::Value *FnAsInt = Builder.CreateLoad(FnPtr, "fn"); + + // If the LSB in the function pointer is 1, the function pointer points to + // a virtual function. + llvm::Value *IsVirtual + = Builder.CreateAnd(FnAsInt, llvm::ConstantInt::get(PtrDiffTy, 1), + "and"); + + IsVirtual = Builder.CreateTrunc(IsVirtual, + llvm::Type::getInt1Ty(VMContext)); + + llvm::BasicBlock *FnVirtual = createBasicBlock("fn.virtual"); + llvm::BasicBlock *FnNonVirtual = createBasicBlock("fn.nonvirtual"); + llvm::BasicBlock *FnEnd = createBasicBlock("fn.end"); + + Builder.CreateCondBr(IsVirtual, FnVirtual, FnNonVirtual); + EmitBlock(FnVirtual); + + const llvm::Type *VTableTy = + FTy->getPointerTo()->getPointerTo()->getPointerTo(); + + llvm::Value *VTable = Builder.CreateBitCast(This, VTableTy); + VTable = Builder.CreateLoad(VTable); + + VTable = Builder.CreateGEP(VTable, FnAsInt, "fn"); + + // Since the function pointer is 1 plus the virtual table offset, we + // subtract 1 by using a GEP. + VTable = Builder.CreateConstGEP1_64(VTable, -1); + + llvm::Value *VirtualFn = Builder.CreateLoad(VTable, "virtualfn"); + + EmitBranch(FnEnd); + EmitBlock(FnNonVirtual); + + // If the function is not virtual, just load the pointer. + llvm::Value *NonVirtualFn = Builder.CreateLoad(FnPtr, "fn"); + NonVirtualFn = Builder.CreateIntToPtr(NonVirtualFn, FTy->getPointerTo()); + + EmitBlock(FnEnd); + + llvm::PHINode *Callee = Builder.CreatePHI(FTy->getPointerTo()); + Callee->reserveOperandSpace(2); + Callee->addIncoming(VirtualFn, FnVirtual); + Callee->addIncoming(NonVirtualFn, FnNonVirtual); + + CallArgList Args; + + QualType ThisType = + getContext().getPointerType(getContext().getTagDeclType(RD)); + + // Push the this ptr. + Args.push_back(std::make_pair(RValue::get(This), ThisType)); + + // And the rest of the call args + EmitCallArgs(Args, FPT, E->arg_begin(), E->arg_end()); + QualType ResultType = BO->getType()->getAs()->getResultType(); + return EmitCall(CGM.getTypes().getFunctionInfo(ResultType, Args), + Callee, Args, 0); +} + RValue CodeGenFunction::EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, const CXXMethodDecl *MD) { diff --git a/lib/CodeGen/CGCall.cpp b/lib/CodeGen/CGCall.cpp index 0b9b3fc270..2a1843eb6c 100644 --- a/lib/CodeGen/CGCall.cpp +++ b/lib/CodeGen/CGCall.cpp @@ -63,6 +63,21 @@ static unsigned getCallingConventionForDecl(const Decl *D) { return llvm::CallingConv::C; } +const CGFunctionInfo &CodeGenTypes::getFunctionInfo(const CXXRecordDecl *RD, + const FunctionProtoType *FTP) { + llvm::SmallVector ArgTys; + + // Add the 'this' pointer. + ArgTys.push_back(Context.getPointerType(Context.getTagDeclType(RD))); + + for (unsigned i = 0, e = FTP->getNumArgs(); i != e; ++i) + ArgTys.push_back(FTP->getArgType(i)); + + // FIXME: Set calling convention correctly, it needs to be associated with the + // type somehow. + return getFunctionInfo(FTP->getResultType(), ArgTys, 0); +} + const CGFunctionInfo &CodeGenTypes::getFunctionInfo(const CXXMethodDecl *MD) { llvm::SmallVector ArgTys; // Add the 'this' pointer unless this is a static method. diff --git a/lib/CodeGen/CodeGenFunction.h b/lib/CodeGen/CodeGenFunction.h index 09a20e2227..34b3860a48 100644 --- a/lib/CodeGen/CodeGenFunction.h +++ b/lib/CodeGen/CodeGenFunction.h @@ -864,10 +864,12 @@ public: CallExpr::const_arg_iterator ArgBeg, CallExpr::const_arg_iterator ArgEnd); RValue EmitCXXMemberCallExpr(const CXXMemberCallExpr *E); + RValue EmitCXXMemberPointerCallExpr(const CXXMemberCallExpr *E); RValue EmitCXXOperatorMemberCallExpr(const CXXOperatorCallExpr *E, const CXXMethodDecl *MD); + RValue EmitBuiltinExpr(const FunctionDecl *FD, unsigned BuiltinID, const CallExpr *E); diff --git a/lib/CodeGen/CodeGenTypes.h b/lib/CodeGen/CodeGenTypes.h index 53df106820..ad71e0ad05 100644 --- a/lib/CodeGen/CodeGenTypes.h +++ b/lib/CodeGen/CodeGenTypes.h @@ -181,7 +181,11 @@ public: const CGFunctionInfo &getFunctionInfo(const FunctionDecl *FD); const CGFunctionInfo &getFunctionInfo(const CXXMethodDecl *MD); const CGFunctionInfo &getFunctionInfo(const ObjCMethodDecl *MD); - + + // getFunctionInfo - Get the function info for a member function. + const CGFunctionInfo &getFunctionInfo(const CXXRecordDecl *RD, + const FunctionProtoType *FTP); + /// getFunctionInfo - Get the function info for a function described by a /// return type and argument types. If the calling convention is not /// specified, the "C" calling convention will be used. diff --git a/test/CodeGenCXX/member-function-pointers.cpp b/test/CodeGenCXX/member-function-pointers.cpp index 57e2e7f2d6..9727a9dabd 100644 --- a/test/CodeGenCXX/member-function-pointers.cpp +++ b/test/CodeGenCXX/member-function-pointers.cpp @@ -49,3 +49,8 @@ void f2() { // CHECK: store i64 0, i64* [[pa2adj]] void (A::*pa3)() = &A::vf; } + +void f3(A *a, A &ar) { + (a->*pa)(); + (ar.*pa)(); +}