From: Lang Hames Date: Fri, 11 Nov 2016 21:42:09 +0000 (+0000) Subject: [ORC] Re-apply 286620 with fixes for the ErrorSuccess class. X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=085827f84324b95e41ede1a9507e7d2fb8d505a9;p=llvm [ORC] Re-apply 286620 with fixes for the ErrorSuccess class. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@286639 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h b/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h index c95532e8db3..718b99e4b24 100644 --- a/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/RemoteJITUtils.h @@ -14,7 +14,7 @@ #ifndef LLVM_TOOLS_LLI_REMOTEJITUTILS_H #define LLVM_TOOLS_LLI_REMOTEJITUTILS_H -#include "llvm/ExecutionEngine/Orc/RPCByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" #include @@ -25,7 +25,7 @@ #endif /// RPC channel that reads from and writes from file descriptors. -class FDRPCChannel final : public llvm::orc::remote::RPCByteChannel { +class FDRPCChannel final : public llvm::orc::rpc::RawByteChannel { public: FDRPCChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} diff --git a/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp b/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp index 9c21098971a..f5a06cf2bf4 100644 --- a/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp +++ b/examples/Kaleidoscope/BuildingAJIT/Chapter5/toy.cpp @@ -1265,8 +1265,8 @@ int main(int argc, char *argv[]) { BinopPrecedence['*'] = 40; // highest. auto TCPChannel = connect(); - MyRemote Remote = ExitOnErr(MyRemote::Create(*TCPChannel)); - TheJIT = llvm::make_unique(Remote); + auto Remote = ExitOnErr(MyRemote::Create(*TCPChannel)); + TheJIT = llvm::make_unique(*Remote); // Automatically inject a definition for 'printExprResult'. FunctionProtos["printExprResult"] = @@ -1288,7 +1288,7 @@ int main(int argc, char *argv[]) { TheJIT = nullptr; // Send a terminate message to the remote to tell it to exit cleanly. - ExitOnErr(Remote.terminateSession()); + ExitOnErr(Remote->terminateSession()); return 0; } diff --git a/include/llvm/ExecutionEngine/Orc/OrcError.h b/include/llvm/ExecutionEngine/Orc/OrcError.h index 1b3f25fae16..8841aa77f62 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -29,6 +29,7 @@ enum class OrcErrorCode : int { RemoteIndirectStubsOwnerIdAlreadyInUse, UnexpectedRPCCall, UnexpectedRPCResponse, + UnknownRPCFunction }; Error orcError(OrcErrorCode ErrCode); diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index d549fc31deb..5b2f8921fef 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -8,7 +8,7 @@ //===----------------------------------------------------------------------===// // // This file defines the OrcRemoteTargetClient class and helpers. This class -// can be used to communicate over an RPCByteChannel with an +// can be used to communicate over an RawByteChannel with an // OrcRemoteTargetServer instance to support remote-JITing. // //===----------------------------------------------------------------------===// @@ -36,23 +36,6 @@ namespace remote { template class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI { public: - // FIXME: Remove move/copy ops once MSVC supports synthesizing move ops. - - OrcRemoteTargetClient(const OrcRemoteTargetClient &) = delete; - OrcRemoteTargetClient &operator=(const OrcRemoteTargetClient &) = delete; - - OrcRemoteTargetClient(OrcRemoteTargetClient &&Other) - : Channel(Other.Channel), ExistingError(std::move(Other.ExistingError)), - RemoteTargetTriple(std::move(Other.RemoteTargetTriple)), - RemotePointerSize(std::move(Other.RemotePointerSize)), - RemotePageSize(std::move(Other.RemotePageSize)), - RemoteTrampolineSize(std::move(Other.RemoteTrampolineSize)), - RemoteIndirectStubSize(std::move(Other.RemoteIndirectStubSize)), - AllocatorIds(std::move(Other.AllocatorIds)), - IndirectStubOwnerIds(std::move(Other.IndirectStubOwnerIds)), - CallbackManager(std::move(Other.CallbackManager)) {} - - OrcRemoteTargetClient &operator=(OrcRemoteTargetClient &&) = delete; /// Remote memory manager. class RCMemoryManager : public RuntimeDyld::MemoryManager { @@ -62,18 +45,10 @@ public: DEBUG(dbgs() << "Created remote allocator " << Id << "\n"); } - RCMemoryManager(RCMemoryManager &&Other) - : Client(std::move(Other.Client)), Id(std::move(Other.Id)), - Unmapped(std::move(Other.Unmapped)), - Unfinalized(std::move(Other.Unfinalized)) {} - - RCMemoryManager operator=(RCMemoryManager &&Other) { - Client = std::move(Other.Client); - Id = std::move(Other.Id); - Unmapped = std::move(Other.Unmapped); - Unfinalized = std::move(Other.Unfinalized); - return *this; - } + RCMemoryManager(const RCMemoryManager&) = delete; + RCMemoryManager& operator=(const RCMemoryManager&) = delete; + RCMemoryManager(RCMemoryManager&&) = default; + RCMemoryManager& operator=(RCMemoryManager&&) = default; ~RCMemoryManager() override { Client.destroyRemoteAllocator(Id); @@ -367,18 +342,10 @@ public: Alloc(uint64_t Size, unsigned Align) : Size(Size), Align(Align), Contents(new char[Size + Align - 1]) {} - Alloc(Alloc &&Other) - : Size(std::move(Other.Size)), Align(std::move(Other.Align)), - Contents(std::move(Other.Contents)), - RemoteAddr(std::move(Other.RemoteAddr)) {} - - Alloc &operator=(Alloc &&Other) { - Size = std::move(Other.Size); - Align = std::move(Other.Align); - Contents = std::move(Other.Contents); - RemoteAddr = std::move(Other.RemoteAddr); - return *this; - } + Alloc(const Alloc&) = delete; + Alloc& operator=(const Alloc&) = delete; + Alloc(Alloc&&) = default; + Alloc& operator=(Alloc&&) = default; uint64_t getSize() const { return Size; } @@ -405,24 +372,10 @@ public: struct ObjectAllocs { ObjectAllocs() = default; - - ObjectAllocs(ObjectAllocs &&Other) - : RemoteCodeAddr(std::move(Other.RemoteCodeAddr)), - RemoteRODataAddr(std::move(Other.RemoteRODataAddr)), - RemoteRWDataAddr(std::move(Other.RemoteRWDataAddr)), - CodeAllocs(std::move(Other.CodeAllocs)), - RODataAllocs(std::move(Other.RODataAllocs)), - RWDataAllocs(std::move(Other.RWDataAllocs)) {} - - ObjectAllocs &operator=(ObjectAllocs &&Other) { - RemoteCodeAddr = std::move(Other.RemoteCodeAddr); - RemoteRODataAddr = std::move(Other.RemoteRODataAddr); - RemoteRWDataAddr = std::move(Other.RemoteRWDataAddr); - CodeAllocs = std::move(Other.CodeAllocs); - RODataAllocs = std::move(Other.RODataAllocs); - RWDataAllocs = std::move(Other.RWDataAllocs); - return *this; - } + ObjectAllocs(const ObjectAllocs &) = delete; + ObjectAllocs& operator=(const ObjectAllocs &) = delete; + ObjectAllocs(ObjectAllocs&&) = default; + ObjectAllocs& operator=(ObjectAllocs&&) = default; JITTargetAddress RemoteCodeAddr = 0; JITTargetAddress RemoteRODataAddr = 0; @@ -588,23 +541,21 @@ public: /// Create an OrcRemoteTargetClient. /// Channel is the ChannelT instance to communicate on. It is assumed that /// the channel is ready to be read from and written to. - static Expected Create(ChannelT &Channel) { + static Expected> + Create(ChannelT &Channel) { Error Err = Error::success(); - OrcRemoteTargetClient H(Channel, Err); + std::unique_ptr + Client(new OrcRemoteTargetClient(Channel, Err)); if (Err) return std::move(Err); - return Expected(std::move(H)); + return std::move(Client); } /// Call the int(void) function at the given address in the target and return /// its result. Expected callIntVoid(JITTargetAddress Addr) { DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) << "\n"); - - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling(Channel, Listen, Addr); + return callB(Addr); } /// Call the int(int, char*[]) function at the given address in the target and @@ -613,11 +564,7 @@ public: const std::vector &Args) { DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x", Addr) << "\n"); - - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling(Channel, Listen, Addr, Args); + return callB(Addr, Args); } /// Call the void() function at the given address in the target and wait for @@ -625,11 +572,7 @@ public: Error callVoidVoid(JITTargetAddress Addr) { DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr) << "\n"); - - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return listenForCompileRequests(C, Id); - }; - return callSTHandling(Channel, Listen, Addr); + return callB(Addr); } /// Create an RCMemoryManager which will allocate its memory on the remote @@ -638,7 +581,7 @@ public: assert(!MM && "MemoryManager should be null before creation."); auto Id = AllocatorIds.getNext(); - if (auto Err = callST(Channel, Id)) + if (auto Err = callB(Id)) return Err; MM = llvm::make_unique(*this, Id); return Error::success(); @@ -649,7 +592,7 @@ public: Error createIndirectStubsManager(std::unique_ptr &I) { assert(!I && "Indirect stubs manager should be null before creation."); auto Id = IndirectStubOwnerIds.getNext(); - if (auto Err = callST(Channel, Id)) + if (auto Err = callB(Id)) return Err; I = llvm::make_unique(*this, Id); return Error::success(); @@ -662,7 +605,7 @@ public: return std::move(ExistingError); // Emit the resolver block on the JIT server. - if (auto Err = callST(Channel)) + if (auto Err = callB()) return std::move(Err); // Create the callback manager. @@ -679,18 +622,28 @@ public: if (ExistingError) return std::move(ExistingError); - return callST(Channel, Name); + return callB(Name); } /// Get the triple for the remote target. const std::string &getTargetTriple() const { return RemoteTargetTriple; } - Error terminateSession() { return callST(Channel); } + Error terminateSession() { return callB(); } private: - OrcRemoteTargetClient(ChannelT &Channel, Error &Err) : Channel(Channel) { + + OrcRemoteTargetClient(ChannelT &Channel, Error &Err) + : OrcRemoteTargetRPCAPI(Channel) { ErrorAsOutParameter EAO(&Err); - if (auto RIOrErr = callST(Channel)) { + + addHandler( + [this](JITTargetAddress Addr) -> JITTargetAddress { + if (CallbackManager) + return CallbackManager->executeCompileCallback(Addr); + return 0; + }); + + if (auto RIOrErr = callB()) { std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize, RemoteTrampolineSize, RemoteIndirectStubSize) = *RIOrErr; Err = Error::success(); @@ -700,11 +653,11 @@ private: } Error deregisterEHFrames(JITTargetAddress Addr, uint32_t Size) { - return callST(Channel, Addr, Size); + return callB(Addr, Size); } void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { - if (auto Err = callST(Channel, Id)) { + if (auto Err = callB(Id)) { // FIXME: This will be triggered by a removeModuleSet call: Propagate // error return up through that. llvm_unreachable("Failed to destroy remote allocator."); @@ -714,12 +667,12 @@ private: Error destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { IndirectStubOwnerIds.release(Id); - return callST(Channel, Id); + return callB(Id); } Expected> emitIndirectStubs(ResourceIdMgr::ResourceId Id, uint32_t NumStubsRequired) { - return callST(Channel, Id, NumStubsRequired); + return callB(Id, NumStubsRequired); } Expected> emitTrampolineBlock() { @@ -727,7 +680,7 @@ private: if (ExistingError) return std::move(ExistingError); - return callST(Channel); + return callB(); } uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; } @@ -736,42 +689,17 @@ private: uint32_t getTrampolineSize() const { return RemoteTrampolineSize; } - Error listenForCompileRequests(RPCByteChannel &C, uint32_t &Id) { - assert(CallbackManager && - "No calback manager. enableCompileCallbacks must be called first"); - - // Check for an 'out-of-band' error, e.g. from an MM destructor. - if (ExistingError) - return std::move(ExistingError); - - // FIXME: CompileCallback could be an anonymous lambda defined at the use - // site below, but that triggers a GCC 4.7 ICE. When we move off - // GCC 4.7, tidy this up. - auto CompileCallback = - [this](JITTargetAddress Addr) -> Expected { - return this->CallbackManager->executeCompileCallback(Addr); - }; - - if (Id == RequestCompileId) { - if (auto Err = handle(C, CompileCallback)) - return Err; - return Error::success(); - } - // else - return orcError(OrcErrorCode::UnexpectedRPCCall); - } - Expected> readMem(char *Dst, JITTargetAddress Src, uint64_t Size) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) return std::move(ExistingError); - return callST(Channel, Src, Size); + return callB(Src, Size); } Error registerEHFrames(JITTargetAddress &RAddr, uint32_t Size) { - return callST(Channel, RAddr, Size); + return callB(RAddr, Size); } Expected reserveMem(ResourceIdMgr::ResourceId Id, @@ -781,12 +709,12 @@ private: if (ExistingError) return std::move(ExistingError); - return callST(Channel, Id, Size, Align); + return callB(Id, Size, Align); } Error setProtections(ResourceIdMgr::ResourceId Id, JITTargetAddress RemoteSegAddr, unsigned ProtFlags) { - return callST(Channel, Id, RemoteSegAddr, ProtFlags); + return callB(Id, RemoteSegAddr, ProtFlags); } Error writeMem(JITTargetAddress Addr, const char *Src, uint64_t Size) { @@ -794,7 +722,7 @@ private: if (ExistingError) return std::move(ExistingError); - return callST(Channel, DirectBufferWriter(Src, Addr, Size)); + return callB(DirectBufferWriter(Src, Addr, Size)); } Error writePointer(JITTargetAddress Addr, JITTargetAddress PtrVal) { @@ -802,12 +730,11 @@ private: if (ExistingError) return std::move(ExistingError); - return callST(Channel, Addr, PtrVal); + return callB(Addr, PtrVal); } static Error doNothing() { return Error::success(); } - ChannelT &Channel; Error ExistingError = Error::success(); std::string RemoteTargetTriple; uint32_t RemotePointerSize = 0; diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index 33d6b604c61..413e286a347 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -16,7 +16,7 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H #define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H -#include "RPCByteChannel.h" +#include "RawByteChannel.h" #include "RPCUtils.h" #include "llvm/ExecutionEngine/JITSymbol.h" @@ -40,13 +40,24 @@ private: uint64_t Size; }; +} // end namespace remote + +namespace rpc { + template <> -class SerializationTraits { +class RPCTypeName { public: + static const char *getName() { return "DirectBufferWriter"; } +}; - static const char* getName() { return "DirectBufferWriter"; } +template +class SerializationTraits:: + value>::type> { +public: - static Error serialize(RPCByteChannel &C, const DirectBufferWriter &DBW) { + static Error serialize(ChannelT &C, const remote::DirectBufferWriter &DBW) { if (auto EC = serializeSeq(C, DBW.getDst())) return EC; if (auto EC = serializeSeq(C, DBW.getSize())) @@ -54,7 +65,7 @@ public: return C.appendBytes(DBW.getSrc(), DBW.getSize()); } - static Error deserialize(RPCByteChannel &C, DirectBufferWriter &DBW) { + static Error deserialize(ChannelT &C, remote::DirectBufferWriter &DBW) { JITTargetAddress Dst; if (auto EC = deserializeSeq(C, Dst)) return EC; @@ -63,13 +74,18 @@ public: return EC; char *Addr = reinterpret_cast(static_cast(Dst)); - DBW = DirectBufferWriter(0, Dst, Size); + DBW = remote::DirectBufferWriter(0, Dst, Size); return C.readBytes(Addr, Size); } }; -class OrcRemoteTargetRPCAPI : public RPC { +} // end namespace rpc + +namespace remote { + +class OrcRemoteTargetRPCAPI + : public rpc::SingleThreadedRPC { protected: class ResourceIdMgr { public: @@ -93,119 +109,162 @@ protected: public: // FIXME: Remove constructors once MSVC supports synthesizing move-ops. - OrcRemoteTargetRPCAPI() = default; - OrcRemoteTargetRPCAPI(const OrcRemoteTargetRPCAPI &) = delete; - OrcRemoteTargetRPCAPI &operator=(const OrcRemoteTargetRPCAPI &) = delete; - - OrcRemoteTargetRPCAPI(OrcRemoteTargetRPCAPI &&) {} - OrcRemoteTargetRPCAPI &operator=(OrcRemoteTargetRPCAPI &&) { return *this; } - - enum JITFuncId : uint32_t { - InvalidId = RPCFunctionIdTraits::InvalidId, - CallIntVoidId = RPCFunctionIdTraits::FirstValidId, - CallMainId, - CallVoidVoidId, - CreateRemoteAllocatorId, - CreateIndirectStubsOwnerId, - DeregisterEHFramesId, - DestroyRemoteAllocatorId, - DestroyIndirectStubsOwnerId, - EmitIndirectStubsId, - EmitResolverBlockId, - EmitTrampolineBlockId, - GetSymbolAddressId, - GetRemoteInfoId, - ReadMemId, - RegisterEHFramesId, - ReserveMemId, - RequestCompileId, - SetProtectionsId, - TerminateSessionId, - WriteMemId, - WritePtrId - }; - - static const char *getJITFuncIdName(JITFuncId Id); - - typedef Function CallIntVoid; - - typedef Function Args)> - CallMain; - - typedef Function CallVoidVoid; - - typedef Function - CreateRemoteAllocator; - - typedef Function - CreateIndirectStubsOwner; - - typedef Function - DeregisterEHFrames; - - typedef Function - DestroyRemoteAllocator; - - typedef Function - DestroyIndirectStubsOwner; + OrcRemoteTargetRPCAPI(rpc::RawByteChannel &C) + : rpc::SingleThreadedRPC(C, true) {} + + class CallIntVoid : public rpc::Function { + public: + static const char* getName() { return "CallIntVoid"; } + }; + + class CallMain + : public rpc::Function Args)> { + public: + static const char* getName() { return "CallMain"; } + }; + + class CallVoidVoid : public rpc::Function { + public: + static const char* getName() { return "CallVoidVoid"; } + }; + + class CreateRemoteAllocator + : public rpc::Function { + public: + static const char* getName() { return "CreateRemoteAllocator"; } + }; + + class CreateIndirectStubsOwner + : public rpc::Function { + public: + static const char* getName() { return "CreateIndirectStubsOwner"; } + }; + + class DeregisterEHFrames + : public rpc::Function { + public: + static const char* getName() { return "DeregisterEHFrames"; } + }; + + class DestroyRemoteAllocator + : public rpc::Function { + public: + static const char* getName() { return "DestroyRemoteAllocator"; } + }; + + class DestroyIndirectStubsOwner + : public rpc::Function { + public: + static const char* getName() { return "DestroyIndirectStubsOwner"; } + }; /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted). - typedef Function( - ResourceIdMgr::ResourceId StubsOwnerID, - uint32_t NumStubsRequired)> - EmitIndirectStubs; + class EmitIndirectStubs + : public rpc::Function( + ResourceIdMgr::ResourceId StubsOwnerID, + uint32_t NumStubsRequired)> { + public: + static const char* getName() { return "EmitIndirectStubs"; } + }; - typedef Function EmitResolverBlock; + class EmitResolverBlock : public rpc::Function { + public: + static const char* getName() { return "EmitResolverBlock"; } + }; /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines). - typedef Function()> - EmitTrampolineBlock; + class EmitTrampolineBlock + : public rpc::Function()> { + public: + static const char* getName() { return "EmitTrampolineBlock"; } + }; - typedef Function - GetSymbolAddress; + class GetSymbolAddress + : public rpc::Function { + public: + static const char* getName() { return "GetSymbolAddress"; } + }; /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize, /// IndirectStubsSize). - typedef Function()> - GetRemoteInfo; + class GetRemoteInfo + : public rpc::Function()> { + public: + static const char* getName() { return "GetRemoteInfo"; } + }; - typedef Function(JITTargetAddress Src, uint64_t Size)> - ReadMem; + class ReadMem + : public rpc::Function(JITTargetAddress Src, + uint64_t Size)> { + public: + static const char* getName() { return "ReadMem"; } + }; - typedef Function - RegisterEHFrames; + class RegisterEHFrames + : public rpc::Function { + public: + static const char* getName() { return "RegisterEHFrames"; } + }; - typedef Function - ReserveMem; + class ReserveMem + : public rpc::Function { + public: + static const char* getName() { return "ReserveMem"; } + }; - typedef Function - RequestCompile; + class RequestCompile + : public rpc::Function { + public: + static const char* getName() { return "RequestCompile"; } + }; + + class SetProtections + : public rpc::Function { + public: + static const char* getName() { return "SetProtections"; } + }; - typedef Function - SetProtections; + class TerminateSession : public rpc::Function { + public: + static const char* getName() { return "TerminateSession"; } + }; - typedef Function TerminateSession; + class WriteMem : public rpc::Function { + public: + static const char* getName() { return "WriteMem"; } + }; - typedef Function WriteMem; + class WritePtr + : public rpc::Function { + public: + static const char* getName() { return "WritePtr"; } + }; - typedef Function - WritePtr; }; } // end namespace remote diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index e3dfaf77566..bda4cd15342 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -41,94 +41,51 @@ public: OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup, EHFrameRegistrationFtor EHFramesRegister, EHFrameRegistrationFtor EHFramesDeregister) - : Channel(Channel), SymbolLookup(std::move(SymbolLookup)), + : OrcRemoteTargetRPCAPI(Channel), SymbolLookup(std::move(SymbolLookup)), EHFramesRegister(std::move(EHFramesRegister)), - EHFramesDeregister(std::move(EHFramesDeregister)) {} + EHFramesDeregister(std::move(EHFramesDeregister)), + TerminateFlag(false) { + + using ThisT = typename std::remove_reference::type; + addHandler(*this, &ThisT::handleCallIntVoid); + addHandler(*this, &ThisT::handleCallMain); + addHandler(*this, &ThisT::handleCallVoidVoid); + addHandler(*this, + &ThisT::handleCreateRemoteAllocator); + addHandler(*this, + &ThisT::handleCreateIndirectStubsOwner); + addHandler(*this, &ThisT::handleDeregisterEHFrames); + addHandler(*this, + &ThisT::handleDestroyRemoteAllocator); + addHandler(*this, + &ThisT::handleDestroyIndirectStubsOwner); + addHandler(*this, &ThisT::handleEmitIndirectStubs); + addHandler(*this, &ThisT::handleEmitResolverBlock); + addHandler(*this, &ThisT::handleEmitTrampolineBlock); + addHandler(*this, &ThisT::handleGetSymbolAddress); + addHandler(*this, &ThisT::handleGetRemoteInfo); + addHandler(*this, &ThisT::handleReadMem); + addHandler(*this, &ThisT::handleRegisterEHFrames); + addHandler(*this, &ThisT::handleReserveMem); + addHandler(*this, &ThisT::handleSetProtections); + addHandler(*this, &ThisT::handleTerminateSession); + addHandler(*this, &ThisT::handleWriteMem); + addHandler(*this, &ThisT::handleWritePtr); + } // FIXME: Remove move/copy ops once MSVC supports synthesizing move ops. OrcRemoteTargetServer(const OrcRemoteTargetServer &) = delete; OrcRemoteTargetServer &operator=(const OrcRemoteTargetServer &) = delete; - OrcRemoteTargetServer(OrcRemoteTargetServer &&Other) - : Channel(Other.Channel), SymbolLookup(std::move(Other.SymbolLookup)), - EHFramesRegister(std::move(Other.EHFramesRegister)), - EHFramesDeregister(std::move(Other.EHFramesDeregister)) {} - + OrcRemoteTargetServer(OrcRemoteTargetServer &&Other) = default; OrcRemoteTargetServer &operator=(OrcRemoteTargetServer &&) = delete; - Error handleKnownFunction(JITFuncId Id) { - typedef OrcRemoteTargetServer ThisT; - - DEBUG(dbgs() << "Handling known proc: " << getJITFuncIdName(Id) << "\n"); - - switch (Id) { - case CallIntVoidId: - return handle(Channel, *this, &ThisT::handleCallIntVoid); - case CallMainId: - return handle(Channel, *this, &ThisT::handleCallMain); - case CallVoidVoidId: - return handle(Channel, *this, &ThisT::handleCallVoidVoid); - case CreateRemoteAllocatorId: - return handle(Channel, *this, - &ThisT::handleCreateRemoteAllocator); - case CreateIndirectStubsOwnerId: - return handle( - Channel, *this, &ThisT::handleCreateIndirectStubsOwner); - case DeregisterEHFramesId: - return handle(Channel, *this, - &ThisT::handleDeregisterEHFrames); - case DestroyRemoteAllocatorId: - return handle( - Channel, *this, &ThisT::handleDestroyRemoteAllocator); - case DestroyIndirectStubsOwnerId: - return handle( - Channel, *this, &ThisT::handleDestroyIndirectStubsOwner); - case EmitIndirectStubsId: - return handle(Channel, *this, - &ThisT::handleEmitIndirectStubs); - case EmitResolverBlockId: - return handle(Channel, *this, - &ThisT::handleEmitResolverBlock); - case EmitTrampolineBlockId: - return handle(Channel, *this, - &ThisT::handleEmitTrampolineBlock); - case GetSymbolAddressId: - return handle(Channel, *this, - &ThisT::handleGetSymbolAddress); - case GetRemoteInfoId: - return handle(Channel, *this, &ThisT::handleGetRemoteInfo); - case ReadMemId: - return handle(Channel, *this, &ThisT::handleReadMem); - case RegisterEHFramesId: - return handle(Channel, *this, - &ThisT::handleRegisterEHFrames); - case ReserveMemId: - return handle(Channel, *this, &ThisT::handleReserveMem); - case SetProtectionsId: - return handle(Channel, *this, - &ThisT::handleSetProtections); - case WriteMemId: - return handle(Channel, *this, &ThisT::handleWriteMem); - case WritePtrId: - return handle(Channel, *this, &ThisT::handleWritePtr); - default: - return orcError(OrcErrorCode::UnexpectedRPCCall); - } - - llvm_unreachable("Unhandled JIT RPC procedure Id."); - } Expected requestCompile(JITTargetAddress TrampolineAddr) { - auto Listen = [&](RPCByteChannel &C, uint32_t Id) { - return handleKnownFunction(static_cast(Id)); - }; - - return callSTHandling(Channel, Listen, TrampolineAddr); + return callB(TrampolineAddr); } - Error handleTerminateSession() { - return handle(Channel, []() { return Error::success(); }); - } + bool receivedTerminate() const { return TerminateFlag; } private: struct Allocator { @@ -365,15 +322,16 @@ private: IndirectStubSize); } - Expected> handleReadMem(JITTargetAddress RSrc, uint64_t Size) { - char *Src = reinterpret_cast(static_cast(RSrc)); + Expected> handleReadMem(JITTargetAddress RSrc, + uint64_t Size) { + uint8_t *Src = reinterpret_cast(static_cast(RSrc)); DEBUG(dbgs() << " Reading " << Size << " bytes from " << format("0x%016x", RSrc) << "\n"); - std::vector Buffer; + std::vector Buffer; Buffer.resize(Size); - for (char *P = Src; Size != 0; --Size) + for (uint8_t *P = Src; Size != 0; --Size) Buffer.push_back(*P++); return Buffer; @@ -421,6 +379,11 @@ private: return Allocator.setProtections(LocalAddr, Flags); } + Error handleTerminateSession() { + TerminateFlag = true; + return Error::success(); + } + Error handleWriteMem(DirectBufferWriter DBW) { DEBUG(dbgs() << " Writing " << DBW.getSize() << " bytes to " << format("0x%016x", DBW.getDst()) << "\n"); @@ -436,7 +399,6 @@ private: return Error::success(); } - ChannelT &Channel; SymbolLookupFtor SymbolLookup; EHFrameRegistrationFtor EHFramesRegister, EHFramesDeregister; std::map Allocators; @@ -444,6 +406,7 @@ private: std::map IndirectStubsOwners; sys::OwningMemoryBlock ResolverBlock; std::vector TrampolineBlocks; + bool TerminateFlag; }; } // end namespace remote diff --git a/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h b/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h deleted file mode 100644 index c8cb42d5374..00000000000 --- a/include/llvm/ExecutionEngine/Orc/RPCByteChannel.h +++ /dev/null @@ -1,231 +0,0 @@ -//===- llvm/ExecutionEngine/Orc/RPCByteChannel.h ----------------*- C++ -*-===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H -#define LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H - -#include "OrcError.h" -#include "RPCSerialization.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Endian.h" -#include "llvm/Support/Error.h" -#include -#include -#include -#include -#include -#include -#include - -namespace llvm { -namespace orc { -namespace remote { - -/// Interface for byte-streams to be used with RPC. -class RPCByteChannel { -public: - virtual ~RPCByteChannel() {} - - /// Read Size bytes from the stream into *Dst. - virtual Error readBytes(char *Dst, unsigned Size) = 0; - - /// Read size bytes from *Src and append them to the stream. - virtual Error appendBytes(const char *Src, unsigned Size) = 0; - - /// Flush the stream if possible. - virtual Error send() = 0; - - /// Get the lock for stream reading. - std::mutex &getReadLock() { return readLock; } - - /// Get the lock for stream writing. - std::mutex &getWriteLock() { return writeLock; } - -private: - std::mutex readLock, writeLock; -}; - -/// Notify the channel that we're starting a message send. -/// Locks the channel for writing. -inline Error startSendMessage(RPCByteChannel &C) { - C.getWriteLock().lock(); - return Error::success(); -} - -/// Notify the channel that we're ending a message send. -/// Unlocks the channel for writing. -inline Error endSendMessage(RPCByteChannel &C) { - C.getWriteLock().unlock(); - return Error::success(); -} - -/// Notify the channel that we're starting a message receive. -/// Locks the channel for reading. -inline Error startReceiveMessage(RPCByteChannel &C) { - C.getReadLock().lock(); - return Error::success(); -} - -/// Notify the channel that we're ending a message receive. -/// Unlocks the channel for reading. -inline Error endReceiveMessage(RPCByteChannel &C) { - C.getReadLock().unlock(); - return Error::success(); -} - -template ::value>:: - type> -class RPCByteChannelPrimitiveSerialization { -public: - static Error serialize(ChannelT &C, T V) { - support::endian::byte_swap(V); - return C.appendBytes(reinterpret_cast(&V), sizeof(T)); - }; - - static Error deserialize(ChannelT &C, T &V) { - if (auto Err = C.readBytes(reinterpret_cast(&V), sizeof(T))) - return Err; - support::endian::byte_swap(V); - return Error::success(); - }; -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "uint64_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "int64_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "uint32_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "int32_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "uint16_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "int16_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "uint8_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "int8_t"; } -}; - -template -class SerializationTraits - : public RPCByteChannelPrimitiveSerialization { -public: - static const char* getName() { return "char"; } - - static Error serialize(RPCByteChannel &C, char V) { - return serializeSeq(C, static_cast(V)); - }; - - static Error deserialize(RPCByteChannel &C, char &V) { - uint8_t VV; - if (auto Err = deserializeSeq(C, VV)) - return Err; - V = static_cast(V); - return Error::success(); - }; -}; - -template -class SerializationTraits::value>:: - type> { -public: - static const char* getName() { return "bool"; } - - static Error serialize(ChannelT &C, bool V) { - return C.appendBytes(reinterpret_cast(&V), 1); - } - - static Error deserialize(ChannelT &C, bool &V) { - return C.readBytes(reinterpret_cast(&V), 1); - } -}; - -template -class SerializationTraits::value>:: - type> { -public: - static const char* getName() { return "std::string"; } - - static Error serialize(RPCByteChannel &C, StringRef S) { - if (auto Err = SerializationTraits:: - serialize(C, static_cast(S.size()))) - return Err; - return C.appendBytes((const char *)S.bytes_begin(), S.size()); - } - - /// RPC channel serialization for std::strings. - static Error serialize(RPCByteChannel &C, const std::string &S) { - return serialize(C, StringRef(S)); - } - - /// RPC channel deserialization for std::strings. - static Error deserialize(RPCByteChannel &C, std::string &S) { - uint64_t Count = 0; - if (auto Err = SerializationTraits:: - deserialize(C, Count)) - return Err; - S.resize(Count); - return C.readBytes(&S[0], Count); - } -}; - -} // end namespace remote -} // end namespace orc -} // end namespace llvm - -#endif // LLVM_EXECUTIONENGINE_ORC_RPCBYTECHANNEL_H diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 0e9f5157f29..d1503e91b4f 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -17,7 +17,164 @@ namespace llvm { namespace orc { -namespace remote { +namespace rpc { + +template +class RPCTypeName; + +/// TypeNameSequence is a utility for rendering sequences of types to a string +/// by rendering each type, separated by ", ". +template class RPCTypeNameSequence {}; + +/// Render an empty TypeNameSequence to an ostream. +template +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &V) { + return OS; +} + +/// Render a TypeNameSequence of a single type to an ostream. +template +OStream &operator<<(OStream &OS, const RPCTypeNameSequence &V) { + OS << RPCTypeName::getName(); + return OS; +} + +/// Render a TypeNameSequence of more than one type to an ostream. +template +OStream& +operator<<(OStream &OS, const RPCTypeNameSequence &V) { + OS << RPCTypeName::getName() << ", " + << RPCTypeNameSequence(); + return OS; +} + +template <> +class RPCTypeName { +public: + static const char* getName() { return "void"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int8_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint8_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int16_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint16_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int32_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint32_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "int64_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "uint64_t"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "bool"; } +}; + +template <> +class RPCTypeName { +public: + static const char* getName() { return "std::string"; } +}; + +template +class RPCTypeName> { +public: + static const char* getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence() + << ">"; + return Name.data(); + } +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex RPCTypeName>::NameMutex; +template +std::string RPCTypeName>::Name; + +template +class RPCTypeName> { +public: + static const char* getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "std::tuple<" + << RPCTypeNameSequence() << ">"; + return Name.data(); + } +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex RPCTypeName>::NameMutex; +template +std::string RPCTypeName>::Name; + +template +class RPCTypeName> { +public: + static const char*getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "std::vector<" << RPCTypeName::getName() + << ">"; + return Name.data(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex RPCTypeName>::NameMutex; +template +std::string RPCTypeName>::Name; + /// The SerializationTraits class describes how to serialize and /// deserialize an instance of type T to/from an abstract channel of type @@ -51,71 +208,92 @@ namespace remote { /// } /// /// @endcode -template +template class SerializationTraits {}; -/// TypeNameSequence is a utility for rendering sequences of types to a string -/// by rendering each type, separated by ", ". -template class TypeNameSequence {}; +template +class SequenceTraits { +public: + static Error emitSeparator(ChannelT &C) { return Error::success(); } + static Error consumeSeparator(ChannelT &C) { return Error::success(); } +}; -/// Render a TypeNameSequence of a single type to an ostream. -template -OStream &operator<<(OStream &OS, const TypeNameSequence &V) { - OS << SerializationTraits::getName(); - return OS; -} +/// Utility class for serializing sequences of values of varying types. +/// Specializations of this class contain 'serialize' and 'deserialize' methods +/// for the given channel. The ArgTs... list will determine the "over-the-wire" +/// types to be serialized. The serialize and deserialize methods take a list +/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., +/// but may be different types from ArgTs, provided that for each CArgT there +/// is a SerializationTraits specialization +/// SerializeTraits with methods that can serialize the +/// caller argument to over-the-wire value. +template +class SequenceSerialization; -/// Render a TypeNameSequence of more than one type to an ostream. -template -OStream & -operator<<(OStream &OS, - const TypeNameSequence &V) { - OS << SerializationTraits::getName() << ", " - << TypeNameSequence(); - return OS; -} +template +class SequenceSerialization { +public: + static Error serialize(ChannelT &C) { return Error::success(); } + static Error deserialize(ChannelT &C) { return Error::success(); } +}; -/// RPC channel serialization for a variadic list of arguments. -template -Error serializeSeq(ChannelT &C, const T &Arg, const Ts &... Args) { - if (auto Err = SerializationTraits::serialize(C, Arg)) - return Err; - return serializeSeq(C, Args...); -} +template +class SequenceSerialization { +public: -/// RPC channel serialization for an (empty) variadic list of arguments. -template Error serializeSeq(ChannelT &C) { - return Error::success(); -} + template + static Error serialize(ChannelT &C, const CArgT &CArg) { + return SerializationTraits::serialize(C, CArg); + } + + template + static Error deserialize(ChannelT &C, CArgT &CArg) { + return SerializationTraits::deserialize(C, CArg); + } +}; + +template +class SequenceSerialization { +public: -/// RPC channel deserialization for a variadic list of arguments. -template -Error deserializeSeq(ChannelT &C, T &Arg, Ts &... Args) { - if (auto Err = SerializationTraits::deserialize(C, Arg)) - return Err; - return deserializeSeq(C, Args...); + template + static Error serialize(ChannelT &C, const CArgT &CArg, + const CArgTs&... CArgs) { + if (auto Err = + SerializationTraits::serialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits::emitSeparator(C)) + return Err; + return SequenceSerialization::serialize(C, CArgs...); + } + + template + static Error deserialize(ChannelT &C, CArgT &CArg, + CArgTs&... CArgs) { + if (auto Err = + SerializationTraits::deserialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits::consumeSeparator(C)) + return Err; + return SequenceSerialization::deserialize(C, CArgs...); + } +}; + +template +Error serializeSeq(ChannelT &C, const ArgTs &... Args) { + return SequenceSerialization::serialize(C, Args...); } -/// RPC channel serialization for an (empty) variadic list of arguments. -template Error deserializeSeq(ChannelT &C) { - return Error::success(); +template +Error deserializeSeq(ChannelT &C, ArgTs &... Args) { + return SequenceSerialization::deserialize(C, Args...); } /// SerializationTraits default specialization for std::pair. template class SerializationTraits> { public: - static const char *getName() { - std::lock_guard Lock(NameMutex); - if (Name.empty()) - Name = (std::ostringstream() - << "std::pair<" << TypeNameSequence() << ">") - .str(); - - return Name.data(); - } - static Error serialize(ChannelT &C, const std::pair &V) { return serializeSeq(C, V.first, V.second); } @@ -123,31 +301,12 @@ public: static Error deserialize(ChannelT &C, std::pair &V) { return deserializeSeq(C, V.first, V.second); } - -private: - static std::mutex NameMutex; - static std::string Name; }; -template -std::mutex SerializationTraits>::NameMutex; - -template -std::string SerializationTraits>::Name; - /// SerializationTraits default specialization for std::tuple. template class SerializationTraits> { public: - static const char *getName() { - std::lock_guard Lock(NameMutex); - if (Name.empty()) - Name = (std::ostringstream() - << "std::tuple<" << TypeNameSequence() << ">") - .str(); - - return Name.data(); - } /// RPC channel serialization for std::tuple. static Error serialize(ChannelT &C, const std::tuple &V) { @@ -173,68 +332,41 @@ private: llvm::index_sequence _) { return deserializeSeq(C, std::get(V)...); } - - static std::mutex NameMutex; - static std::string Name; }; -template -std::mutex SerializationTraits>::NameMutex; - -template -std::string SerializationTraits>::Name; - /// SerializationTraits default specialization for std::vector. template class SerializationTraits> { public: - static const char *getName() { - std::lock_guard Lock(NameMutex); - if (Name.empty()) - Name = (std::ostringstream() << "std::vector<" - << TypeNameSequence() << ">") - .str(); - return Name.data(); - } + /// Serialize a std::vector from std::vector. static Error serialize(ChannelT &C, const std::vector &V) { - if (auto Err = SerializationTraits::serialize( - C, static_cast(V.size()))) + if (auto Err = serializeSeq(C, static_cast(V.size()))) return Err; for (const auto &E : V) - if (auto Err = SerializationTraits::serialize(C, E)) + if (auto Err = serializeSeq(C, E)) return Err; return Error::success(); } + /// Deserialize a std::vector to a std::vector. static Error deserialize(ChannelT &C, std::vector &V) { uint64_t Count = 0; - if (auto Err = - SerializationTraits::deserialize(C, Count)) + if (auto Err = deserializeSeq(C, Count)) return Err; V.resize(Count); for (auto &E : V) - if (auto Err = SerializationTraits::deserialize(C, E)) + if (auto Err = deserializeSeq(C, E)) return Err; return Error::success(); } - -private: - static std::mutex NameMutex; - static std::string Name; }; -template -std::mutex SerializationTraits>::NameMutex; - -template -std::string SerializationTraits>::Name; - -} // end namespace remote +} // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 436c037e920..2ff27efd72d 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -1,4 +1,4 @@ -//===----- RPCUTils.h - Basic tilities for building RPC APIs ----*- C++ -*-===// +//===------- RPCUTils.h - Utilities for building RPC APIs -------*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -7,7 +7,11 @@ // //===----------------------------------------------------------------------===// // -// Basic utilities for building RPC APIs. +// Utilities to support construction of simple RPC APIs. +// +// The RPC utilities aim for ease of use (minimal conceptual overhead) for C++ +// programmers, high performance, low memory overhead, and efficient use of the +// communications channel. // //===----------------------------------------------------------------------===// @@ -15,10 +19,12 @@ #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H #include +#include #include #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/ExecutionEngine/Orc/RPCSerialization.h" #ifdef _MSC_VER // concrt.h depends on eh.h for __uncaught_exception declaration @@ -39,32 +45,92 @@ namespace llvm { namespace orc { -namespace remote { +namespace rpc { -/// Describes reserved RPC Function Ids. -/// -/// The default implementation will serve for integer and enum function id -/// types. If you want to use a custom type as your FunctionId you can -/// specialize this class and provide unique values for InvalidId, -/// ResponseId and FirstValidId. +template +class Function; -template class RPCFunctionIdTraits { +// RPC Function class. +// DerivedFunc should be a user defined class with a static 'getName()' method +// returning a const char* representing the function's name. +template +class Function { public: - static const T InvalidId = static_cast(0); - static const T ResponseId = static_cast(1); - static const T FirstValidId = static_cast(2); + + /// User defined function type. + using Type = RetT(ArgTs...); + + /// Return type. + using ReturnType = RetT; + + /// Returns the full function prototype as a string. + static const char *getPrototype() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) + << RPCTypeName::getName() << " " << DerivedFunc::getName() + << "(" << llvm::orc::rpc::RPCTypeNameSequence() << ")"; + return Name.data(); + } +private: + static std::mutex NameMutex; + static std::string Name; }; -// Base class containing utilities that require partial specialization. -// These cannot be included in RPC, as template class members cannot be -// partially specialized. -class RPCBase { -protected: - // FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation - // supports classes without default constructors. +template +std::mutex Function::NameMutex; + +template +std::string Function::Name; + +/// Allocates RPC function ids during autonegotiation. +/// Specializations of this class must provide four members: +/// +/// static T getInvalidId(): +/// Should return a reserved id that will be used to represent missing +/// functions during autonegotiation. +/// +/// static T getResponseId(): +/// Should return a reserved id that will be used to send function responses +/// (return values). +/// +/// static T getNegotiateId(): +/// Should return a reserved id for the negotiate function, which will be used +/// to negotiate ids for user defined functions. +/// +/// template T allocate(): +/// Allocate a unique id for function Func. +template +class RPCFunctionIdAllocator; + +/// This specialization of RPCFunctionIdAllocator provides a default +/// implementation for integral types. +template +class RPCFunctionIdAllocator::value + >::type> { +public: + + static T getInvalidId() { return T(0); } + static T getResponseId() { return T(1); } + static T getNegotiateId() { return T(2); } + + template + T allocate(){ return NextId++; } +private: + T NextId = 3; +}; + +namespace detail { + +// FIXME: Remove MSVCPError/MSVCPExpected once MSVC's future implementation +// supports classes without default constructors. #ifdef _MSC_VER +namespace msvc_hacks { + // Work around MSVC's future implementation's use of default constructors: // A default constructed value in the promise will be overwritten when the // real error is set - so the default constructed Error has to be checked @@ -86,7 +152,7 @@ protected: MSVCPError(Error Err) : Error(std::move(Err)) {} }; - // Likewise for Expected: + // Work around MSVC's future implementation, similar to MSVCPError. template class MSVCPExpected : public Expected { public: @@ -123,488 +189,531 @@ protected: nullptr) : Expected(std::move(Other)) {} }; +} // end namespace msvc_hacks + #endif // _MSC_VER - // RPC Function description type. - // - // This class provides the information and operations needed to support the - // RPC primitive operations (call, expect, etc) for a given function. It - // is specialized for void and non-void functions to deal with the differences - // betwen the two. Both specializations have the same interface: - // - // Id - The function's unique identifier. - // ErrorReturn - The return type for blocking calls. - // readResult - Deserialize a result from a channel. - // abandon - Abandon a promised result. - // respond - Retun a result on the channel. - template - class FunctionHelper {}; - - // RPC Function description specialization for non-void functions. - template - class FunctionHelper { - public: - static_assert(FuncId != RPCFunctionIdTraits::InvalidId && - FuncId != RPCFunctionIdTraits::ResponseId, - "Cannot define custom function with InvalidId or ResponseId. " - "Please use RPCFunctionTraits::FirstValidId."); +// ResultTraits provides typedefs and utilities specific to the return type +// of functions. +template +class ResultTraits { +public: + + // The return type wrapped in llvm::Expected. + using ErrorReturnType = Expected; + +#ifdef _MSC_VER + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise>; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future>; +#else + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise; - static const FunctionIdT Id = FuncId; + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future; +#endif - typedef Expected ErrorReturn; + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType(RetT()); + } + + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType RetOrErr) { + consumeError(RetOrErr.takeError()); + } +}; + +// ResultTraits specialization for void functions. +template <> +class ResultTraits { +public: + + // For void functions, ErrorReturnType is llvm::Error. + using ErrorReturnType = Error; - // FIXME: Ditch PErrorReturn (replace it with plain ErrorReturn) once MSVC's - // std::future implementation supports types without default - // constructors. #ifdef _MSC_VER - typedef MSVCPExpected PErrorReturn; + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future; #else - typedef Expected PErrorReturn; + // The ErrorReturnType wrapped in a std::promise. + using ReturnPromiseType = std::promise; + + // The ErrorReturnType wrapped in a std::future. + using ReturnFutureType = std::future; #endif - template - static Error readResult(ChannelT &C, std::promise &P) { - RetT Val; - auto Err = deserializeSeq(C, Val); - auto Err2 = endReceiveMessage(C); - Err = joinErrors(std::move(Err), std::move(Err2)); - if (Err) - return Err; + // Create a 'blank' value of the ErrorReturnType, ready and safe to + // overwrite. + static ErrorReturnType createBlankErrorReturnValue() { + return ErrorReturnType::success(); + } - P.set_value(std::move(Val)); - return Error::success(); - } + // Consume an abandoned ErrorReturnType. + static void consumeAbandoned(ErrorReturnType Err) { + consumeError(std::move(Err)); + } +}; - static void abandon(std::promise &P) { - P.set_value( - make_error("RPC function call failed to return", - inconvertibleErrorCode())); - } +// ResultTraits is equivalent to ResultTraits. This allows +// handlers for void RPC functions to return either void (in which case they +// implicitly succeed) or Error (in which case their error return is +// propagated). See usage in HandlerTraits::runHandlerHelper. +template <> +class ResultTraits : public ResultTraits {}; + +// ResultTraits> is equivalent to ResultTraits. This allows +// handlers for RPC functions returning a T to return either a T (in which +// case they implicitly succeed) or Expected (in which case their error +// return is propagated). See usage in HandlerTraits::runHandlerHelper. +template +class ResultTraits> : public ResultTraits {}; + +// Send a response of the given wire return type (WireRetT) over the +// channel, with the given sequence number. +template +static Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Expected ResultOrErr) { + // If this was an error bail out. + // FIXME: Send an "error" message to the client if this is not a channel + // failure? + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = SerializationTraits:: + serialize(C, *ResultOrErr)) + return Err; + + // Close the response message. + return C.endSendMessage(); +} + +// Send an empty response message on the given channel to indicate that +// the handler ran. +template +static Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + return C.endSendMessage(); +} + +// This template class provides utilities related to RPC function handlers. +// The base case applies to non-function types (the template class is +// specialized for function types) and inherits from the appropriate +// speciilization for the given non-function type's call operator. +template +class HandlerTraits + : public HandlerTraits::type::operator())> {}; + +// Traits for handlers with a given function type. +template +class HandlerTraits { +public: - static void consumeAbandoned(std::future &P) { - consumeError(P.get().takeError()); - } + // Function type of the handler. + using Type = RetT(ArgTs...); - template - static Error respond(ChannelT &C, SequenceNumberT SeqNo, - ErrorReturn &Result) { - FunctionIdT ResponseId = RPCFunctionIdTraits::ResponseId; + // Return type of the handler. + using ReturnType = RetT; - // If the handler returned an error then bail out with that. - if (!Result) - return Result.takeError(); + // A std::tuple wrapping the handler arguments. + using ArgStorage = + std::tuple< + typename std::decay< + typename std::remove_reference::type>::type...>; - // Otherwise open a new message on the channel and send the result. - if (auto Err = startSendMessage(C)) - return Err; - if (auto Err = serializeSeq(C, ResponseId, SeqNo, *Result)) - return Err; - return endSendMessage(C); - } - }; + // Call the given handler with the given arguments. + template + static typename ResultTraits::ErrorReturnType + runHandler(HandlerT &Handler, ArgStorage &Args) { + return runHandlerHelper(Handler, Args, + llvm::index_sequence_for()); + } - // RPC Function description specialization for void functions. - template - class FunctionHelper { - public: - static_assert(FuncId != RPCFunctionIdTraits::InvalidId && - FuncId != RPCFunctionIdTraits::ResponseId, - "Cannot define custom function with InvalidId or ResponseId. " - "Please use RPCFunctionTraits::FirstValidId."); + // Serialize arguments to the channel. + template + static Error serializeArgs(ChannelT &C, const CArgTs... CArgs) { + return SequenceSerialization::serialize(C, CArgs...); + } - static const FunctionIdT Id = FuncId; + // Deserialize arguments from the channel. + template + static Error deserializeArgs(ChannelT &C, std::tuple &Args) { + return deserializeArgsHelper(C, Args, + llvm::index_sequence_for()); + } - typedef Error ErrorReturn; +private: - // FIXME: Ditch PErrorReturn (replace it with plain ErrorReturn) once MSVC's - // std::future implementation supports types without default - // constructors. -#ifdef _MSC_VER - typedef MSVCPError PErrorReturn; -#else - typedef Error PErrorReturn; -#endif + // For non-void user handlers: unwrap the args tuple and call the handler, + // returning the result. + template + static typename std::enable_if< + !std::is_void::value, + typename ResultTraits::ErrorReturnType>::type + runHandlerHelper(HandlerT &Handler, ArgStorage &Args, + llvm::index_sequence) { + return Handler(std::move(std::get(Args))...); + } - template - static Error readResult(ChannelT &C, std::promise &P) { - // Void functions don't have anything to deserialize, so we're good. - P.set_value(Error::success()); - return endReceiveMessage(C); - } + // For void user handlers: unwrap the args tuple and call the handler, then + // return Error::success(). + template + static typename std::enable_if< + std::is_void::value, + typename ResultTraits::ErrorReturnType>::type + runHandlerHelper(HandlerT &Handler, ArgStorage &Args, + llvm::index_sequence) { + Handler(std::move(std::get(Args))...); + return ResultTraits::ErrorReturnType::success(); + } - static void abandon(std::promise &P) { - P.set_value( - make_error("RPC function call failed to return", - inconvertibleErrorCode())); - } + template + static + Error deserializeArgsHelper(ChannelT &C, std::tuple &Args, + llvm::index_sequence _) { + return SequenceSerialization:: + deserialize(C, std::get(Args)...); + } - static void consumeAbandoned(std::future &P) { - consumeError(P.get()); - } +}; - template - static Error respond(ChannelT &C, SequenceNumberT SeqNo, - ErrorReturn &Result) { - const FunctionIdT ResponseId = - RPCFunctionIdTraits::ResponseId; +// Handler traits for class methods (especially call operators for lambdas). +template +class HandlerTraits + : public HandlerTraits {}; - // If the handler returned an error then bail out with that. - if (Result) - return std::move(Result); +// Handler traits for const class methods (especially call operators for +// lambdas). +template +class HandlerTraits + : public HandlerTraits {}; - // Otherwise open a new message on the channel and send the result. - if (auto Err = startSendMessage(C)) - return Err; - if (auto Err = serializeSeq(C, ResponseId, SeqNo)) - return Err; - return endSendMessage(C); - } - }; +// Utility to peel the Expected wrapper off a response handler error type. +template +class UnwrapResponseHandlerArg; - // Helper for the call primitive. - template - class CallHelper; +template +class UnwrapResponseHandlerArg)> { +public: + using ArgType = ArgT; +}; - template - class CallHelper> { - public: - static Error call(ChannelT &C, SequenceNumberT SeqNo, - const ArgTs &... Args) { - if (auto Err = startSendMessage(C)) - return Err; - if (auto Err = serializeSeq(C, FuncId, SeqNo, Args...)) - return Err; - return endSendMessage(C); +template +class UnwrapResponseHandlerArg)> { +public: + using ArgType = ArgT; +}; + + +// ResponseHandler represents a handler for a not-yet-received function call +// result. +template +class ResponseHandler { +public: + virtual ~ResponseHandler() {} + + // Reads the function result off the wire and acts on it. The meaning of + // "act" will depend on how this method is implemented in any given + // ResponseHandler subclass but could, for example, mean running a + // user-specified handler or setting a promise value. + virtual Error handleResponse(ChannelT &C) = 0; + + // Abandons this outstanding result. + virtual void abandon() = 0; + + // Create an error instance representing an abandoned response. + static Error createAbandonedResponseError() { + return make_error("RPC function call failed to return", + inconvertibleErrorCode()); + } +}; + +// ResponseHandler subclass for RPC functions with non-void returns. +template +class ResponseHandlerImpl : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) + : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using ArgType = typename UnwrapResponseHandlerArg< + typename HandlerTraits::Type>::ArgType; + ArgType Result; + if (auto Err = SerializationTraits:: + deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Result); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); } - }; + } - // Helper for handle primitive. - template - class HandlerHelper; +private: + HandlerT Handler; +}; - template - class HandlerHelper> { - public: - template - static Error handle(ChannelT &C, HandlerT Handler) { - return readAndHandle(C, Handler, llvm::index_sequence_for()); +// ResponseHandler subclass for RPC functions with void returns. +template +class ResponseHandlerImpl + : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) + : Handler(std::move(Handler)) {} + + // Handle the result (no actual value, just a notification that the function + // has completed on the remote end) by calling the user-defined handler with + // Error::success(). + Error handleResponse(ChannelT &C) override { + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(Error::success()); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); } + } - private: - typedef FunctionHelper Func; - - template - static Error readAndHandle(ChannelT &C, HandlerT Handler, - llvm::index_sequence _) { - std::tuple RPCArgs; - SequenceNumberT SeqNo; - // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning - // for RPCArgs. Void cast RPCArgs to work around this for now. - // FIXME: Remove this workaround once we can assume a working GCC version. - (void)RPCArgs; - if (auto Err = deserializeSeq(C, SeqNo, std::get(RPCArgs)...)) - return Err; +private: + HandlerT Handler; +}; - // We've deserialized the arguments, so unlock the channel for reading - // before we call the handler. This allows recursive RPC calls. - if (auto Err = endReceiveMessage(C)) - return Err; +// Create a ResponseHandler from a given user handler. +template +std::unique_ptr> +createResponseHandler(HandlerT H) { + return llvm::make_unique< + ResponseHandlerImpl>(std::move(H)); +} + +// Helper for wrapping member functions up as functors. This is useful for +// installing methods as result handlers. +template +class MemberFnWrapper { +public: + using MethodT = RetT(ClassT::*)(ArgTs...); + MemberFnWrapper(ClassT &Instance, MethodT Method) + : Instance(Instance), Method(Method) {} + RetT operator()(ArgTs &&... Args) { + return (Instance.*Method)(std::move(Args)...); + } +private: + ClassT &Instance; + MethodT Method; +}; - // Run the handler and get the result. - auto Result = Handler(std::get(RPCArgs)...); +// Helper that provides a Functor for deserializing arguments. +template class ReadArgs { +public: + Error operator()() { return Error::success(); } +}; - // Return the result to the client. - return Func::template respond(C, SeqNo, - Result); - } - }; +template +class ReadArgs : public ReadArgs { +public: + ReadArgs(ArgT &Arg, ArgTs &... Args) + : ReadArgs(Args...), Arg(Arg) {} - // Helper for wrapping member functions up as functors. - template - class MemberFnWrapper { - public: - typedef RetT (ClassT::*MethodT)(ArgTs...); - MemberFnWrapper(ClassT &Instance, MethodT Method) - : Instance(Instance), Method(Method) {} - RetT operator()(ArgTs &... Args) { return (Instance.*Method)(Args...); } - - private: - ClassT &Instance; - MethodT Method; - }; + Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { + this->Arg = std::move(ArgVal); + return ReadArgs::operator()(ArgVals...); + } +private: + ArgT &Arg; +}; - // Helper that provides a Functor for deserializing arguments. - template class ReadArgs { - public: - Error operator()() { return Error::success(); } - }; +// Manage sequence numbers. +template +class SequenceNumberManager { +public: + // Reset, making all sequence numbers available. + void reset() { + std::lock_guard Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } - template - class ReadArgs : public ReadArgs { - public: - ReadArgs(ArgT &Arg, ArgTs &... Args) - : ReadArgs(Args...), Arg(Arg) {} + // Get the next available sequence number. Will re-use numbers that have + // been released. + SequenceNumberT getSequenceNumber() { + std::lock_guard Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } - Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { - this->Arg = std::move(ArgVal); - return ReadArgs::operator()(ArgVals...); - } + // Release a sequence number, making it available for re-use. + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } - private: - ArgT &Arg; - }; +private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector FreeSequenceNumbers; }; /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a procedure -/// identifier type that must be serializable on ChannelT. +/// RPCChannel interface (see RPCChannel.h), FunctionIdT is a procedure +/// identifier type that must be serializable on ChannelT, and SequenceNumberT +/// is an integral type that will be used to number in-flight function calls. /// /// These utilities support the construction of very primitive RPC utilities. /// Their intent is to ensure correct serialization and deserialization of /// procedure arguments, and to keep the client and server's view of the API in /// sync. -/// -/// These utilities do not support return values. These can be handled by -/// declaring a corresponding '.*Response' procedure and expecting it after a -/// call). They also do not support versioning: the client and server *must* be -/// compiled with the same procedure definitions. -/// -/// -/// -/// Overview (see comments individual types/methods for details): -/// -/// Function : -/// -/// associates a unique serializable id with an argument list. -/// -/// -/// call(Channel, Args...) : -/// -/// Calls the remote procedure 'Func' by serializing Func's id followed by its -/// arguments and sending the resulting bytes to 'Channel'. -/// -/// -/// handle(Channel, : -/// -/// Handles a call to 'Func' by deserializing its arguments and calling the -/// given functor. This assumes that the id for 'Func' has already been -/// deserialized. -/// -/// expect(Channel, : -/// -/// The same as 'handle', except that the procedure id should not have been -/// read yet. Expect will deserialize the id and assert that it matches Func's -/// id. If it does not, and unexpected RPC call error is returned. -template -class RPC : public RPCBase { -public: - /// RPC default constructor. - RPC() = default; +template +class RPCBase { +protected: - /// RPC instances cannot be copied. - RPC(RPC &&) = default; - RPC &operator=(RPC &&) = default; + class OrcRPCInvalid : public Function { + public: + static const char *getName() { return "__orc_rpc$invalid"; } + }; - /// Utility class for defining/referring to RPC procedures. - /// - /// Typedefs of this utility are used when calling/handling remote procedures. - /// - /// FuncId should be a unique value of FunctionIdT (i.e. not used with any - /// other Function typedef in the RPC API being defined. - /// - /// the template argument Ts... gives the argument list for the remote - /// procedure. - /// - /// E.g. - /// - /// typedef Function<0, bool> Func1; - /// typedef Function<1, std::string, std::vector> Func2; - /// - /// if (auto Err = call(Channel, true)) - /// /* handle Err */; - /// - /// if (auto Err = expect(Channel, - /// [](std::string &S, std::vector &V) { - /// // Stuff. - /// return Error::success(); - /// }) - /// /* handle Err */; - /// - template - using Function = FunctionHelper; + class OrcRPCResponse : public Function { + public: + static const char *getName() { return "__orc_rpc$response"; } + }; - /// Return type for non-blocking call primitives. - template - using NonBlockingCallResult = std::future; + class OrcRPCNegotiate + : public Function { + public: + static const char *getName() { return "__orc_rpc$negotiate"; } + }; - /// Return type for non-blocking call-with-seq primitives. - template - using NonBlockingCallWithSeqResult = - std::pair, SequenceNumberT>; +public: - /// Call Func on Channel C. Does not block, does not call send. Returns a pair - /// of a future result and the sequence number assigned to the result. - /// - /// This utility function is primarily used for single-threaded mode support, - /// where the sequence number can be used to wait for the corresponding - /// result. In multi-threaded mode the appendCallNB method, which does not - /// return the sequence numeber, should be preferred. - template - Expected> - appendCallNBWithSeq(ChannelT &C, const ArgTs &... Args) { - auto SeqNo = SequenceNumberMgr.getSequenceNumber(); - std::promise Promise; - auto Result = Promise.get_future(); - OutstandingResults[SeqNo] = - createOutstandingResult(std::move(Promise)); - - if (auto Err = CallHelper::call(C, SeqNo, - Args...)) { - abandonOutstandingResults(); - Func::consumeAbandoned(Result); - return std::move(Err); - } else - return NonBlockingCallWithSeqResult(std::move(Result), SeqNo); + /// Construct an RPC instance on a channel. + RPCBase(ChannelT &C, bool LazyAutoNegotiation) + : C(C), LazyAutoNegotiation(LazyAutoNegotiation) { + // Hold ResponseId in a special variable, since we expect Response to be + // called relatively frequently, and want to avoid the map lookup. + ResponseId = FnIdAllocator.getResponseId(); + RemoteFunctionIds[OrcRPCResponse::getPrototype()] = ResponseId; + + // Register the negotiate function id and handler. + auto NegotiateId = FnIdAllocator.getNegotiateId(); + RemoteFunctionIds[OrcRPCNegotiate::getPrototype()] = NegotiateId; + Handlers[NegotiateId] = + wrapHandler([this](const std::string &Name) { + return handleNegotiate(Name); + }, LaunchPolicy()); } - /// The same as appendCallNBWithSeq, except that it calls C.send() to - /// flush the channel after serializing the call. - template - Expected> - callNBWithSeq(ChannelT &C, const ArgTs &... Args) { - auto Result = appendCallNBWithSeq(C, Args...); - if (!Result) - return Result; - if (auto Err = C.send()) { - abandonOutstandingResults(); - Func::consumeAbandoned(Result->first); - return std::move(Err); + /// Append a call Func, does not call send on the channel. + /// The first argument specifies a user-defined handler to be run when the + /// function returns. The handler should take an Expected, + /// or an Error (if Func::ReturnType is void). The handler will be called + /// with an error if the return value is abandoned due to a channel error. + template + Error appendCallAsync(HandlerT Handler, const ArgTs &... Args) { + // Look up the function ID. + FunctionIdT FnId; + if (auto FnIdOrErr = getRemoteFunctionId()) + FnId = *FnIdOrErr; + else { + // This isn't a channel error so we don't want to abandon other pending + // responses, but we still need to run the user handler with an error to + // let them know the call failed. + if (auto Err = Handler(orcError(OrcErrorCode::UnknownRPCFunction))) + report_fatal_error(std::move(Err)); + return FnIdOrErr.takeError(); } - return Result; - } - /// Serialize Args... to channel C, but do not call send. - /// Returns an error if serialization fails, otherwise returns a - /// std::future> (or a future for void functions). - template - Expected> appendCallNB(ChannelT &C, - const ArgTs &... Args) { - auto FutureResAndSeqOrErr = appendCallNBWithSeq(C, Args...); - if (FutureResAndSeqOrErr) - return std::move(FutureResAndSeqOrErr->first); - return FutureResAndSeqOrErr.takeError(); - } + // Allocate a sequence number. + auto SeqNo = SequenceNumberMgr.getSequenceNumber(); + assert(!PendingResponses.count(SeqNo) && + "Sequence number already allocated"); + + // Install the user handler. + PendingResponses[SeqNo] = + detail::createResponseHandler( + std::move(Handler)); + + // Open the function call message. + if (auto Err = C.startSendMessage(FnId, SeqNo)) { + abandonPendingResponses(); + return joinErrors(std::move(Err), C.endSendMessage()); + } - /// The same as appendCallNB, except that it calls C.send to flush the - /// channel after serializing the call. - template - Expected> callNB(ChannelT &C, - const ArgTs &... Args) { - auto FutureResAndSeqOrErr = callNBWithSeq(C, Args...); - if (FutureResAndSeqOrErr) - return std::move(FutureResAndSeqOrErr->first); - return FutureResAndSeqOrErr.takeError(); - } + // Serialize the call arguments. + if (auto Err = + detail::HandlerTraits:: + serializeArgs(C, Args...)) { + abandonPendingResponses(); + return joinErrors(std::move(Err), C.endSendMessage()); + } - /// Call Func on Channel C. Blocks waiting for a result. Returns an Error - /// for void functions or an Expected for functions returning a T. - /// - /// This function is for use in threaded code where another thread is - /// handling responses and incoming calls. - template - typename Func::ErrorReturn callB(ChannelT &C, const ArgTs &... Args) { - if (auto FutureResOrErr = callNBWithSeq(C, Args...)) { - if (auto Err = C.send()) { - abandonOutstandingResults(); - Func::consumeAbandoned(FutureResOrErr->first); - return std::move(Err); - } - return FutureResOrErr->first.get(); - } else - return FutureResOrErr.takeError(); - } + // Close the function call messagee. + if (auto Err = C.endSendMessage()) { + abandonPendingResponses(); + return std::move(Err); + } - /// Call Func on Channel C. Block waiting for a result. While blocked, run - /// HandleOther to handle incoming calls (Response calls will be handled - /// implicitly before calling HandleOther). Returns an Error for void - /// functions or an Expected for functions returning a T. - /// - /// This function is for use in single threaded mode when the calling thread - /// must act as both sender and receiver. - template - typename Func::ErrorReturn - callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) { - if (auto ResultAndSeqNoOrErr = callNBWithSeq(C, Args...)) { - auto &ResultAndSeqNo = *ResultAndSeqNoOrErr; - if (auto Err = waitForResult(C, ResultAndSeqNo.second, HandleOther)) - return std::move(Err); - return ResultAndSeqNo.first.get(); - } else - return ResultAndSeqNoOrErr.takeError(); + return Error::success(); } - /// Call Func on Channel C. Block waiting for a result. Returns an Error for - /// void functions or an Expected for functions returning a T. - template - typename Func::ErrorReturn callST(ChannelT &C, const ArgTs &... Args) { - return callSTHandling(C, handleNone, Args...); - } - /// Start receiving a new function call. - /// - /// Calls startReceiveMessage on the channel, then deserializes a FunctionId - /// into Id. - Error startReceivingFunction(ChannelT &C, FunctionIdT &Id) { - if (auto Err = startReceiveMessage(C)) + template + Error callAsync(HandlerT Handler, const ArgTs &... Args) { + if (auto Err = appendCallAsync(std::move(Handler), Args...)) return Err; - - return deserializeSeq(C, Id); + return C.send(); } - /// Deserialize args for Func from C and call Handler. The signature of - /// handler must conform to 'Error(Args...)' where Args... matches - /// the arguments used in the Func typedef. - template - static Error handle(ChannelT &C, HandlerT Handler) { - return HandlerHelper::handle(C, Handler); - } - - /// Helper version of 'handle' for calling member functions. - template - static Error handle(ChannelT &C, ClassT &Instance, - RetT (ClassT::*HandlerMethod)(ArgTs...)) { - return handle( - C, MemberFnWrapper(Instance, HandlerMethod)); - } - - /// Deserialize a FunctionIdT from C and verify it matches the id for Func. - /// If the id does match, deserialize the arguments and call the handler - /// (similarly to handle). - /// If the id does not match, return an unexpect RPC call error and do not - /// deserialize any further bytes. - template - Error expect(ChannelT &C, HandlerT Handler) { - FunctionIdT FuncId; - if (auto Err = startReceivingFunction(C, FuncId)) - return std::move(Err); - if (FuncId != Func::Id) - return orcError(OrcErrorCode::UnexpectedRPCCall); - return handle(C, Handler); - } + /// Handle one incoming call. + Error handleOne() { + FunctionIdT FnId; + SequenceNumberT SeqNo; + if (auto Err = C.startReceiveMessage(FnId, SeqNo)) + return Err; + if (FnId == ResponseId) + return handleResponse(SeqNo); + auto I = Handlers.find(FnId); + if (I != Handlers.end()) + return I->second(C, SeqNo); - /// Helper version of expect for calling member functions. - template - static Error expect(ChannelT &C, ClassT &Instance, - Error (ClassT::*HandlerMethod)(ArgTs...)) { - return expect( - C, MemberFnWrapper(Instance, HandlerMethod)); + // else: No handler found. Report error to client? + return orcError(OrcErrorCode::UnexpectedRPCCall); } /// Helper for handling setter procedures - this method returns a functor that @@ -621,160 +730,417 @@ public: /// /* Handle Args */ ; /// template - static ReadArgs readArgs(ArgTs &... Args) { - return ReadArgs(Args...); + static detail::ReadArgs readArgs(ArgTs &... Args) { + return detail::ReadArgs(Args...); } - /// Read a response from Channel. - /// This should be called from the receive loop to retrieve results. - Error handleResponse(ChannelT &C, SequenceNumberT *SeqNoRet = nullptr) { - SequenceNumberT SeqNo; - if (auto Err = deserializeSeq(C, SeqNo)) { - abandonOutstandingResults(); - return Err; - } +protected: + // The LaunchPolicy type allows a launch policy to be specified when adding + // a function handler. See addHandlerImpl. + using LaunchPolicy = std::function)>; + + /// Add the given handler to the handler map and make it available for + /// autonegotiation and execution. + template + void addHandlerImpl(HandlerT Handler, LaunchPolicy Launch) { + FunctionIdT NewFnId = FnIdAllocator.template allocate(); + LocalFunctionIds[Func::getPrototype()] = NewFnId; + Handlers[NewFnId] = wrapHandler(std::move(Handler), + std::move(Launch)); + } - if (SeqNoRet) - *SeqNoRet = SeqNo; + // Abandon all outstanding results. + void abandonPendingResponses() { + for (auto &KV : PendingResponses) + KV.second->abandon(); + PendingResponses.clear(); + SequenceNumberMgr.reset(); + } - auto I = OutstandingResults.find(SeqNo); - if (I == OutstandingResults.end()) { - abandonOutstandingResults(); + Error handleResponse(SequenceNumberT SeqNo) { + auto I = PendingResponses.find(SeqNo); + if (I == PendingResponses.end()) { + abandonPendingResponses(); return orcError(OrcErrorCode::UnexpectedRPCResponse); } - if (auto Err = I->second->readResult(C)) { - abandonOutstandingResults(); - // FIXME: Release sequence numbers? + auto PRHandler = std::move(I->second); + PendingResponses.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + + if (auto Err = PRHandler->handleResponse(C)) { + abandonPendingResponses(); + SequenceNumberMgr.reset(); return Err; } - OutstandingResults.erase(I); - SequenceNumberMgr.releaseSequenceNumber(SeqNo); - return Error::success(); } - // Loop waiting for a result with the given sequence number. - // This can be used as a receive loop if the user doesn't have a default. - template - Error waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo, - HandleOtherFtor &HandleOther = handleNone) { - bool GotTgtResult = false; + FunctionIdT handleNegotiate(const std::string &Name) { + auto I = LocalFunctionIds.find(Name); + if (I == LocalFunctionIds.end()) + return FnIdAllocator.getInvalidId(); + return I->second; + } - while (!GotTgtResult) { - FunctionIdT Id = RPCFunctionIdTraits::InvalidId; - if (auto Err = startReceivingFunction(C, Id)) - return Err; - if (Id == RPCFunctionIdTraits::ResponseId) { - SequenceNumberT SeqNo; - if (auto Err = handleResponse(C, &SeqNo)) - return Err; - GotTgtResult = (SeqNo == TgtSeqNo); - } else if (auto Err = HandleOther(C, Id)) - return Err; + // Find the remote FunctionId for the given function, which must be in the + // RemoteFunctionIds map. + template + Expected getRemoteFunctionId() { + // Try to find the id for the given function. + auto I = RemoteFunctionIds.find(Func::getPrototype()); + + // If we have it in the map, return it. + if (I != RemoteFunctionIds.end()) + return I->second; + + // Otherwise, if we have auto-negotiation enabled, try to negotiate it. + if (LazyAutoNegotiation) { + auto &Impl = static_cast(*this); + if (auto RemoteIdOrErr = + Impl.template callB(Func::getPrototype())) { + auto &RemoteId = *RemoteIdOrErr; + + // If autonegotiation indicates that the remote end doesn't support this + // function, return an unknown function error. + if (RemoteId == FnIdAllocator.getInvalidId()) + return orcError(OrcErrorCode::UnknownRPCFunction); + + // Autonegotiation succeeded and returned a valid id. Update the map and + // return the id. + RemoteFunctionIds[Func::getPrototype()] = RemoteId; + return RemoteId; + } else { + // Autonegotiation failed. Return the error. + return RemoteIdOrErr.takeError(); + } } - return Error::success(); + // No key was available in the map and autonegotiation wasn't enabled. + // Return an unknown function error. + return orcError(OrcErrorCode::UnknownRPCFunction); } - // Default handler for 'other' (non-response) functions when waiting for a - // result from the channel. - static Error handleNone(ChannelT &, FunctionIdT) { - return orcError(OrcErrorCode::UnexpectedRPCCall); - }; + using WrappedHandlerFn = std::function; + + // Wrap the given user handler in the necessary argument-deserialization code, + // result-serialization code, and call to the launch policy (if present). + template + WrappedHandlerFn wrapHandler(HandlerT Handler, LaunchPolicy Launch) { + return + [this, Handler, Launch](ChannelT &Channel, SequenceNumberT SeqNo) -> Error { + // Start by deserializing the arguments. + auto Args = + std::make_shared::ArgStorage>(); + if (auto Err = detail::HandlerTraits:: + deserializeArgs(Channel, *Args)) + return Err; + + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)Args; + + // End receieve message, unlocking the channel for reading. + if (auto Err = Channel.endReceiveMessage()) + return Err; + + // Build the handler/responder. + auto Responder = + [this, Handler, Args, &Channel, SeqNo]() mutable -> Error { + using HTraits = detail::HandlerTraits; + using FuncReturn = typename Func::ReturnType; + return detail::respond(Channel, ResponseId, SeqNo, + HTraits::runHandler(Handler, + *Args)); + }; + + // If there is an explicit launch policy then use it to launch the + // handler. + if (Launch) + return Launch(std::move(Responder)); + + // Otherwise run the handler on the listener thread. + return Responder(); + }; + } + + ChannelT &C; + + bool LazyAutoNegotiation; + RPCFunctionIdAllocator FnIdAllocator; + + FunctionIdT ResponseId; + std::map LocalFunctionIds; + std::map RemoteFunctionIds; + + std::map Handlers; + + detail::SequenceNumberManager SequenceNumberMgr; + std::map>> + PendingResponses; +}; + +} // end namespace detail + + +template +class MultiThreadedRPC + : public detail::RPCBase, + ChannelT, FunctionIdT, SequenceNumberT> { private: - // Manage sequence numbers. - class SequenceNumberManager { - public: - SequenceNumberManager() = default; + using BaseClass = + detail::RPCBase, + ChannelT, FunctionIdT, SequenceNumberT>; - SequenceNumberManager(const SequenceNumberManager &) = delete; - SequenceNumberManager &operator=(const SequenceNumberManager &) = delete; +public: - SequenceNumberManager(SequenceNumberManager &&Other) - : NextSequenceNumber(std::move(Other.NextSequenceNumber)), - FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {} + MultiThreadedRPC(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} - SequenceNumberManager &operator=(SequenceNumberManager &&Other) { - NextSequenceNumber = std::move(Other.NextSequenceNumber); - FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers); - return *this; - } + /// The LaunchPolicy type allows a launch policy to be specified when adding + /// a function handler. See addHandler. + using LaunchPolicy = typename BaseClass::LaunchPolicy; - void reset() { - std::lock_guard Lock(SeqNoLock); - NextSequenceNumber = 0; - FreeSequenceNumbers.clear(); - } + /// Add a handler for the given RPC function. + /// This installs the given handler functor for the given RPC Function, and + /// makes the RPC function available for negotiation/calling from the remote. + /// + /// The optional LaunchPolicy argument can be used to control how the handler + /// is run when called: + /// + /// * If no LaunchPolicy is given, the handler code will be run on the RPC + /// handler thread that is reading from the channel. This handler cannot + /// make blocking RPC calls (since it would be blocking the thread used to + /// get the result), but can make non-blocking calls. + /// + /// * If a LaunchPolicy is given, the user's handler will be wrapped in a + /// call to serialize and send the result, and the resulting functor (with + /// type 'Error()' will be passed to the LaunchPolicy. The user can then + /// choose to add the wrapped handler to a work queue, spawn a new thread, + /// or anything else. + template + void addHandler(HandlerT Handler, LaunchPolicy Launch = LaunchPolicy()) { + return this->template addHandlerImpl(std::move(Handler), + std::move(Launch)); + } + + /// Negotiate a function id for Func with the other end of the channel. + template + Error negotiateFunction() { + using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + + if (auto RemoteIdOrErr = callB(Func::getPrototype())) { + this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + return Error::success(); + } else + return RemoteIdOrErr.takeError(); + } - SequenceNumberT getSequenceNumber() { - std::lock_guard Lock(SeqNoLock); - if (FreeSequenceNumbers.empty()) - return NextSequenceNumber++; - auto SequenceNumber = FreeSequenceNumbers.back(); - FreeSequenceNumbers.pop_back(); - return SequenceNumber; + /// Convenience method for negotiating multiple functions at once. + template + Error negotiateFunctions() { + return negotiateFunction(); + } + + /// Convenience method for negotiating multiple functions at once. + template + Error negotiateFunctions() { + if (auto Err = negotiateFunction()) + return Err; + return negotiateFunctions(); + } + + /// Return type for non-blocking call primitives. + template + using NonBlockingCallResult = + typename detail::ResultTraits::ReturnFutureType; + + /// Call Func on Channel C. Does not block, does not call send. Returns a pair + /// of a future result and the sequence number assigned to the result. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallNB method, which does not + /// return the sequence numeber, should be preferred. + template + Expected> + appendCallNB(const ArgTs &... Args) { + using RTraits = detail::ResultTraits; + using ErrorReturn = typename RTraits::ErrorReturnType; + using ErrorReturnPromise = typename RTraits::ReturnPromiseType; + + // FIXME: Stack allocate and move this into the handler once LLVM builds + // with C++14. + auto Promise = std::make_shared(); + auto FutureResult = Promise->get_future(); + + if (auto Err = this->template appendCallAsync( + [Promise](ErrorReturn RetOrErr) { + Promise->set_value(std::move(RetOrErr)); + return Error::success(); + }, Args...)) { + this->abandonPendingResponses(); + RTraits::consumeAbandoned(FutureResult.get()); + return std::move(Err); } + return std::move(FutureResult); + } - void releaseSequenceNumber(SequenceNumberT SequenceNumber) { - std::lock_guard Lock(SeqNoLock); - FreeSequenceNumbers.push_back(SequenceNumber); + /// The same as appendCallNBWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. + template + Expected> + callNB(const ArgTs &... Args) { + auto Result = appendCallNB(Args...); + if (!Result) + return Result; + if (auto Err = this->C.send()) { + this->abandonPendingResponses(); + detail::ResultTraits:: + consumeAbandoned(std::move(Result->get())); + return std::move(Err); } + return Result; + } - private: - std::mutex SeqNoLock; - SequenceNumberT NextSequenceNumber = 0; - std::vector FreeSequenceNumbers; - }; + /// Call Func on Channel C. Blocks waiting for a result. Returns an Error + /// for void functions or an Expected for functions returning a T. + /// + /// This function is for use in threaded code where another thread is + /// handling responses and incoming calls. + template + typename detail::ResultTraits::ErrorReturnType + callB(const ArgTs &... Args) { + if (auto FutureResOrErr = callNB(Args...)) { + if (auto Err = this->C.send()) { + this->abandonPendingResponses(); + detail::ResultTraits:: + consumeAbandoned(std::move(FutureResOrErr->get())); + return std::move(Err); + } + return FutureResOrErr->get(); + } else + return FutureResOrErr.takeError(); + } - // Base class for results that haven't been returned from the other end of the - // RPC connection yet. - class OutstandingResult { - public: - virtual ~OutstandingResult() {} - virtual Error readResult(ChannelT &C) = 0; - virtual void abandon() = 0; - }; + /// Handle incoming RPC calls. + Error handlerLoop() { + while (true) + if (auto Err = this->handleOne()) + return Err; + return Error::success(); + } - // Outstanding results for a specific function. - template - class OutstandingResultImpl : public OutstandingResult { - private: - public: - OutstandingResultImpl(std::promise &&P) - : P(std::move(P)) {} +}; - Error readResult(ChannelT &C) override { return Func::readResult(C, P); } +template +class SingleThreadedRPC + : public detail::RPCBase, + ChannelT, FunctionIdT, + SequenceNumberT> { +private: - void abandon() override { Func::abandon(P); } + using BaseClass = detail::RPCBase, + ChannelT, FunctionIdT, SequenceNumberT>; - private: - std::promise P; - }; + using LaunchPolicy = typename BaseClass::LaunchPolicy; + +public: + + SingleThreadedRPC(ChannelT &C, bool LazyAutoNegotiation) + : BaseClass(C, LazyAutoNegotiation) {} - // Create an outstanding result for the given function. + template + void addHandler(HandlerT Handler) { + return this->template addHandlerImpl(std::move(Handler), + LaunchPolicy()); + } + + template + void addHandler(ClassT &Object, RetT (ClassT::*Method)(ArgTs...)) { + addHandler( + detail::MemberFnWrapper(Object, Method)); + } + + /// Negotiate a function id for Func with the other end of the channel. template - std::unique_ptr - createOutstandingResult(std::promise &&P) { - return llvm::make_unique>(std::move(P)); + Error negotiateFunction() { + using OrcRPCNegotiate = typename BaseClass::OrcRPCNegotiate; + + if (auto RemoteIdOrErr = callB(Func::getPrototype())) { + this->RemoteFunctionIds[Func::getPrototype()] = *RemoteIdOrErr; + return Error::success(); + } else + return RemoteIdOrErr.takeError(); } - // Abandon all outstanding results. - void abandonOutstandingResults() { - for (auto &KV : OutstandingResults) - KV.second->abandon(); - OutstandingResults.clear(); - SequenceNumberMgr.reset(); + /// Convenience method for negotiating multiple functions at once. + template + Error negotiateFunctions() { + return negotiateFunction(); + } + + /// Convenience method for negotiating multiple functions at once. + template + Error negotiateFunctions() { + if (auto Err = negotiateFunction()) + return Err; + return negotiateFunctions(); } - SequenceNumberManager SequenceNumberMgr; - std::map> - OutstandingResults; + template + typename detail::ResultTraits::ErrorReturnType + callB(const ArgTs &... Args) { + bool ReceivedResponse = false; + using ResultType = + typename detail::ResultTraits::ErrorReturnType; + auto Result = detail::ResultTraits::createBlankErrorReturnValue(); + + // We have to 'Check' result (which we know is in a success state at this + // point) so that it can be overwritten in the async handler. + (void)!!Result; + + if (auto Err = this->template appendCallAsync( + [&](ResultType R) { + Result = std::move(R); + ReceivedResponse = true; + return Error::success(); + }, Args...)) { + this->abandonPendingResponses(); + detail::ResultTraits:: + consumeAbandoned(std::move(Result)); + return std::move(Err); + } + + while (!ReceivedResponse) { + if (auto Err = this->handleOne()) { + this->abandonPendingResponses(); + detail::ResultTraits:: + consumeAbandoned(std::move(Result)); + return std::move(Err); + } + } + + return Result; + } + + //using detail::RPCBase::handleOne; + }; -} // end namespace remote +} // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/include/llvm/ExecutionEngine/Orc/RawByteChannel.h b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h new file mode 100644 index 00000000000..c80074ffd7f --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/RawByteChannel.h @@ -0,0 +1,182 @@ +//===- llvm/ExecutionEngine/Orc/RawByteChannel.h ----------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H + +#include "OrcError.h" +#include "RPCSerialization.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/Error.h" +#include +#include +#include +#include +#include +#include +#include + +namespace llvm { +namespace orc { +namespace rpc { + +/// Interface for byte-streams to be used with RPC. +class RawByteChannel { +public: + virtual ~RawByteChannel() {} + + /// Read Size bytes from the stream into *Dst. + virtual Error readBytes(char *Dst, unsigned Size) = 0; + + /// Read size bytes from *Src and append them to the stream. + virtual Error appendBytes(const char *Src, unsigned Size) = 0; + + /// Flush the stream if possible. + virtual Error send() = 0; + + /// Notify the channel that we're starting a message send. + /// Locks the channel for writing. + template + Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) { + if (auto Err = serializeSeq(*this, FnId, SeqNo)) + return Err; + writeLock.lock(); + return Error::success(); + } + + /// Notify the channel that we're ending a message send. + /// Unlocks the channel for writing. + Error endSendMessage() { + writeLock.unlock(); + return Error::success(); + } + + /// Notify the channel that we're starting a message receive. + /// Locks the channel for reading. + template + Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) { + readLock.lock(); + return deserializeSeq(*this, FnId, SeqNo); + } + + /// Notify the channel that we're ending a message receive. + /// Unlocks the channel for reading. + Error endReceiveMessage() { + readLock.unlock(); + return Error::success(); + } + + /// Get the lock for stream reading. + std::mutex &getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex &getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; +}; + +template +class SerializationTraits::value && + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value)>::type> { +public: + static Error serialize(ChannelT &C, T V) { + support::endian::byte_swap(V); + return C.appendBytes(reinterpret_cast(&V), sizeof(T)); + }; + + static Error deserialize(ChannelT &C, T &V) { + if (auto Err = C.readBytes(reinterpret_cast(&V), sizeof(T))) + return Err; + support::endian::byte_swap(V); + return Error::success(); + }; +}; + +template +class SerializationTraits::value>:: + type> { +public: + static Error serialize(ChannelT &C, bool V) { + return C.appendBytes(reinterpret_cast(&V), 1); + } + + static Error deserialize(ChannelT &C, bool &V) { + return C.readBytes(reinterpret_cast(&V), 1); + } +}; + +template +class SerializationTraits::value>:: + type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, StringRef S) { + if (auto Err = serializeSeq(C, static_cast(S.size()))) + return Err; + return C.appendBytes((const char *)S.data(), S.size()); + } +}; + +template +class SerializationTraits::value>:: + type> { +public: + static Error serialize(RawByteChannel &C, const char *S) { + return SerializationTraits:: + serialize(C, S); + } +}; + +template +class SerializationTraits::value>:: + type> { +public: + /// RPC channel serialization for std::strings. + static Error serialize(RawByteChannel &C, const std::string &S) { + return SerializationTraits:: + serialize(C, S); + } + + /// RPC channel deserialization for std::strings. + static Error deserialize(RawByteChannel &C, std::string &S) { + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + S.resize(Count); + return C.readBytes(&S[0], Count); + } +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RAWBYTECHANNEL_H diff --git a/lib/ExecutionEngine/Orc/CMakeLists.txt b/lib/ExecutionEngine/Orc/CMakeLists.txt index 76720a7c52e..685e882e4a8 100644 --- a/lib/ExecutionEngine/Orc/CMakeLists.txt +++ b/lib/ExecutionEngine/Orc/CMakeLists.txt @@ -6,7 +6,6 @@ add_llvm_library(LLVMOrcJIT OrcCBindings.cpp OrcError.cpp OrcMCJITReplacement.cpp - OrcRemoteTargetRPCAPI.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc diff --git a/lib/ExecutionEngine/Orc/OrcError.cpp b/lib/ExecutionEngine/Orc/OrcError.cpp index 64472f9ba37..48dcd442266 100644 --- a/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/lib/ExecutionEngine/Orc/OrcError.cpp @@ -43,6 +43,8 @@ public: return "Unexpected RPC call"; case OrcErrorCode::UnexpectedRPCResponse: return "Unexpected RPC response"; + case OrcErrorCode::UnknownRPCFunction: + return "Unknown RPC function"; } llvm_unreachable("Unhandled error code"); } diff --git a/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp b/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp deleted file mode 100644 index d1a021aee3a..00000000000 --- a/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp +++ /dev/null @@ -1,53 +0,0 @@ -//===------- OrcRemoteTargetRPCAPI.cpp - ORC Remote API utilities ---------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// - -#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" - -namespace llvm { -namespace orc { -namespace remote { - -#define FUNCNAME(X) \ - case X ## Id: \ - return #X - -const char *OrcRemoteTargetRPCAPI::getJITFuncIdName(JITFuncId Id) { - switch (Id) { - case InvalidId: - return "*** Invalid JITFuncId ***"; - FUNCNAME(CallIntVoid); - FUNCNAME(CallMain); - FUNCNAME(CallVoidVoid); - FUNCNAME(CreateRemoteAllocator); - FUNCNAME(CreateIndirectStubsOwner); - FUNCNAME(DeregisterEHFrames); - FUNCNAME(DestroyRemoteAllocator); - FUNCNAME(DestroyIndirectStubsOwner); - FUNCNAME(EmitIndirectStubs); - FUNCNAME(EmitResolverBlock); - FUNCNAME(EmitTrampolineBlock); - FUNCNAME(GetSymbolAddress); - FUNCNAME(GetRemoteInfo); - FUNCNAME(ReadMem); - FUNCNAME(RegisterEHFrames); - FUNCNAME(ReserveMem); - FUNCNAME(RequestCompile); - FUNCNAME(SetProtections); - FUNCNAME(TerminateSession); - FUNCNAME(WriteMem); - FUNCNAME(WritePtr); - }; - return nullptr; -} - -#undef FUNCNAME - -} // end namespace remote -} // end namespace orc -} // end namespace llvm diff --git a/tools/lli/ChildTarget/ChildTarget.cpp b/tools/lli/ChildTarget/ChildTarget.cpp index f6d2413655e..77b1d47a946 100644 --- a/tools/lli/ChildTarget/ChildTarget.cpp +++ b/tools/lli/ChildTarget/ChildTarget.cpp @@ -53,23 +53,12 @@ int main(int argc, char *argv[]) { RTDyldMemoryManager::deregisterEHFramesInProcess(Addr, Size); }; - FDRPCChannel Channel(InFD, OutFD); - typedef remote::OrcRemoteTargetServer JITServer; + FDRawChannel Channel(InFD, OutFD); + typedef remote::OrcRemoteTargetServer JITServer; JITServer Server(Channel, SymbolLookup, RegisterEHFrames, DeregisterEHFrames); - while (1) { - uint32_t RawId; - ExitOnErr(Server.startReceivingFunction(Channel, RawId)); - auto Id = static_cast(RawId); - switch (Id) { - case JITServer::TerminateSessionId: - ExitOnErr(Server.handleTerminateSession()); - return 0; - default: - ExitOnErr(Server.handleKnownFunction(Id)); - break; - } - } + while (!Server.receivedTerminate()) + ExitOnErr(Server.handleOne()); close(InFD); close(OutFD); diff --git a/tools/lli/RemoteJITUtils.h b/tools/lli/RemoteJITUtils.h index d47716cb880..89a51420256 100644 --- a/tools/lli/RemoteJITUtils.h +++ b/tools/lli/RemoteJITUtils.h @@ -14,7 +14,7 @@ #ifndef LLVM_TOOLS_LLI_REMOTEJITUTILS_H #define LLVM_TOOLS_LLI_REMOTEJITUTILS_H -#include "llvm/ExecutionEngine/Orc/RPCByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/RTDyldMemoryManager.h" #include @@ -25,9 +25,9 @@ #endif /// RPC channel that reads from and writes from file descriptors. -class FDRPCChannel final : public llvm::orc::remote::RPCByteChannel { +class FDRawChannel final : public llvm::orc::rpc::RawByteChannel { public: - FDRPCChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} + FDRawChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} llvm::Error readBytes(char *Dst, unsigned Size) override { assert(Dst && "Attempt to read into null."); @@ -72,11 +72,12 @@ private: }; // launch the remote process (see lli.cpp) and return a channel to it. -std::unique_ptr launchRemote(); +std::unique_ptr launchRemote(); namespace llvm { -// ForwardingMM - Adapter to connect MCJIT to Orc's Remote memory manager. +// ForwardingMM - Adapter to connect MCJIT to Orc's Remote8 +// memory manager. class ForwardingMemoryManager : public llvm::RTDyldMemoryManager { public: void setMemMgr(std::unique_ptr MemMgr) { diff --git a/tools/lli/lli.cpp b/tools/lli/lli.cpp index 9dbe658beff..836a94037d7 100644 --- a/tools/lli/lli.cpp +++ b/tools/lli/lli.cpp @@ -654,20 +654,20 @@ int main(int argc, char **argv, char * const *envp) { // MCJIT itself. FIXME. // Lanch the remote process and get a channel to it. - std::unique_ptr C = launchRemote(); + std::unique_ptr C = launchRemote(); if (!C) { errs() << "Failed to launch remote JIT.\n"; exit(1); } // Create a remote target client running over the channel. - typedef orc::remote::OrcRemoteTargetClient + typedef orc::remote::OrcRemoteTargetClient MyRemote; - MyRemote R = ExitOnErr(MyRemote::Create(*C)); + auto R = ExitOnErr(MyRemote::Create(*C)); // Create a remote memory manager. std::unique_ptr RemoteMM; - ExitOnErr(R.createRemoteMemoryManager(RemoteMM)); + ExitOnErr(R->createRemoteMemoryManager(RemoteMM)); // Forward MCJIT's memory manager calls to the remote memory manager. static_cast(RTDyldMM)->setMemMgr( @@ -678,7 +678,7 @@ int main(int argc, char **argv, char * const *envp) { orc::createLambdaResolver( [](const std::string &Name) { return nullptr; }, [&](const std::string &Name) { - if (auto Addr = ExitOnErr(R.getSymbolAddress(Name))) + if (auto Addr = ExitOnErr(R->getSymbolAddress(Name))) return JITSymbol(Addr, JITSymbolFlags::Exported); return JITSymbol(nullptr); } @@ -691,7 +691,7 @@ int main(int argc, char **argv, char * const *envp) { EE->finalizeObject(); DEBUG(dbgs() << "Executing '" << EntryFn->getName() << "' at 0x" << format("%llx", Entry) << "\n"); - Result = ExitOnErr(R.callIntVoid(Entry)); + Result = ExitOnErr(R->callIntVoid(Entry)); // Like static constructors, the remote target MCJIT support doesn't handle // this yet. It could. FIXME. @@ -702,13 +702,13 @@ int main(int argc, char **argv, char * const *envp) { EE.reset(); // Signal the remote target that we're done JITing. - ExitOnErr(R.terminateSession()); + ExitOnErr(R->terminateSession()); } return Result; } -std::unique_ptr launchRemote() { +std::unique_ptr launchRemote() { #ifndef LLVM_ON_UNIX llvm_unreachable("launchRemote not supported on non-Unix platforms"); #else @@ -758,6 +758,6 @@ std::unique_ptr launchRemote() { close(PipeFD[1][1]); // Return an RPC channel connected to our end of the pipes. - return llvm::make_unique(PipeFD[1][0], PipeFD[0][1]); + return llvm::make_unique(PipeFD[1][0], PipeFD[0][1]); #endif } diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 259a75a203f..4d703c78a0e 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ExecutionEngine/Orc/RPCByteChannel.h" +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/Orc/RPCUtils.h" #include "gtest/gtest.h" @@ -15,7 +15,7 @@ using namespace llvm; using namespace llvm::orc; -using namespace llvm::orc::remote; +using namespace llvm::orc::rpc; class Queue : public std::queue { public: @@ -25,7 +25,7 @@ private: std::mutex Lock; }; -class QueueChannel : public RPCByteChannel { +class QueueChannel : public RawByteChannel { public: QueueChannel(Queue &InQueue, Queue &OutQueue) : InQueue(InQueue), OutQueue(OutQueue) {} @@ -61,126 +61,190 @@ private: Queue &OutQueue; }; -class DummyRPC : public testing::Test, public RPC { +class DummyRPCAPI { public: - enum FuncId : uint32_t { - VoidBoolId = RPCFunctionIdTraits::FirstValidId, - IntIntId, - AllTheTypesId + + class VoidBool : public Function { + public: + static const char* getName() { return "VoidBool"; } + }; + + class IntInt : public Function { + public: + static const char* getName() { return "IntInt"; } + }; + + class AllTheTypes + : public Function)> { + public: + static const char* getName() { return "AllTheTypes"; } }; +}; - typedef Function VoidBool; - typedef Function IntInt; - typedef Function)> - AllTheTypes; +class DummyRPCEndpoint : public DummyRPCAPI, + public SingleThreadedRPC { +public: + DummyRPCEndpoint(Queue &Q1, Queue &Q2) + : SingleThreadedRPC(C, true), C(Q1, Q2) {} +private: + QueueChannel C; }; -TEST_F(DummyRPC, TestAsyncVoidBool) { +TEST(DummyRPC, TestAsyncVoidBool) { Queue Q1, Q2; - QueueChannel C1(Q1, Q2); - QueueChannel C2(Q2, Q1); + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); - // Make an async call. - auto ResOrErr = callNBWithSeq(C1, true); - EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; + std::thread ServerThread([&]() { + Server.addHandler( + [](bool B) { + EXPECT_EQ(B, true) + << "Server void(bool) received unexpected result"; + }); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the VoidBool call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); { - // Expect a call to Proc1. - auto EC = expect(C2, [&](bool &B) { - EXPECT_EQ(B, true) << "Bool serialization broken"; - return Error::success(); - }); - EXPECT_FALSE(EC) << "Simple expect over queue failed"; + // Make an async call. + auto Err = Client.callAsync( + [](Error Err) { + EXPECT_FALSE(!!Err) << "Async void(bool) response handler failed"; + return Error::success(); + }, true); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for void(bool)"; } { - // Wait for the result. - auto EC = waitForResult(C1, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; + // Poke the client to process the result of the void(bool) call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; } - // Verify that the function returned ok. - auto Err = ResOrErr->first.get(); - EXPECT_FALSE(!!Err) << "Remote void function failed to execute."; + ServerThread.join(); } -TEST_F(DummyRPC, TestAsyncIntInt) { +TEST(DummyRPC, TestAsyncIntInt) { Queue Q1, Q2; - QueueChannel C1(Q1, Q2); - QueueChannel C2(Q2, Q1); + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); - // Make an async call. - auto ResOrErr = callNBWithSeq(C1, 21); - EXPECT_TRUE(!!ResOrErr) << "Simple call over queue failed"; + std::thread ServerThread([&]() { + Server.addHandler( + [](int X) -> int { + EXPECT_EQ(X, 21) << "Server int(int) receieved unexpected result"; + return 2 * X; + }); - { - // Expect a call to Proc1. - auto EC = expect(C2, [&](int32_t I) -> Expected { - EXPECT_EQ(I, 21) << "Bool serialization broken"; - return 2 * I; + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the int(int) call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to int(int)"; + } }); - EXPECT_FALSE(EC) << "Simple expect over queue failed"; + + { + auto Err = Client.callAsync( + [](Expected Result) { + EXPECT_TRUE(!!Result) << "Async int(int) response handler failed"; + EXPECT_EQ(*Result, 42) + << "Async int(int) response handler received incorrect result"; + return Error::success(); + }, 21); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for int(int)"; } { - // Wait for the result. - auto EC = waitForResult(C1, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; + // Poke the client to process the result. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)"; } - // Verify that the function returned ok. - auto Val = ResOrErr->first.get(); - EXPECT_TRUE(!!Val) << "Remote int function failed to execute."; - EXPECT_EQ(*Val, 42) << "Remote int function return wrong value."; + ServerThread.join(); } -TEST_F(DummyRPC, TestSerialization) { +TEST(DummyRPC, TestSerialization) { Queue Q1, Q2; - QueueChannel C1(Q1, Q2); - QueueChannel C2(Q2, Q1); + DummyRPCEndpoint Client(Q1, Q2); + DummyRPCEndpoint Server(Q2, Q1); - // Make a call to Proc1. - std::vector v({42, 7}); - auto ResOrErr = callNBWithSeq( - C1, -101, 250, -10000, 10000, -1000000000, 1000000000, -10000000000, - 10000000000, true, "foo", v); - EXPECT_TRUE(!!ResOrErr) << "Big (serialization test) call over queue failed"; + std::thread ServerThread([&]() { + Server.addHandler( + [&](int8_t S8, uint8_t U8, int16_t S16, uint16_t U16, + int32_t S32, uint32_t U32, int64_t S64, uint64_t U64, + bool B, std::string S, std::vector V) { - { - // Expect a call to Proc1. - auto EC = expect( - C2, [&](int8_t &s8, uint8_t &u8, int16_t &s16, uint16_t &u16, - int32_t &s32, uint32_t &u32, int64_t &s64, uint64_t &u64, - bool &b, std::string &s, std::vector &v) { - - EXPECT_EQ(s8, -101) << "int8_t serialization broken"; - EXPECT_EQ(u8, 250) << "uint8_t serialization broken"; - EXPECT_EQ(s16, -10000) << "int16_t serialization broken"; - EXPECT_EQ(u16, 10000) << "uint16_t serialization broken"; - EXPECT_EQ(s32, -1000000000) << "int32_t serialization broken"; - EXPECT_EQ(u32, 1000000000ULL) << "uint32_t serialization broken"; - EXPECT_EQ(s64, -10000000000) << "int64_t serialization broken"; - EXPECT_EQ(u64, 10000000000ULL) << "uint64_t serialization broken"; - EXPECT_EQ(b, true) << "bool serialization broken"; - EXPECT_EQ(s, "foo") << "std::string serialization broken"; - EXPECT_EQ(v, std::vector({42, 7})) + EXPECT_EQ(S8, -101) << "int8_t serialization broken"; + EXPECT_EQ(U8, 250) << "uint8_t serialization broken"; + EXPECT_EQ(S16, -10000) << "int16_t serialization broken"; + EXPECT_EQ(U16, 10000) << "uint16_t serialization broken"; + EXPECT_EQ(S32, -1000000000) << "int32_t serialization broken"; + EXPECT_EQ(U32, 1000000000ULL) << "uint32_t serialization broken"; + EXPECT_EQ(S64, -10000000000) << "int64_t serialization broken"; + EXPECT_EQ(U64, 10000000000ULL) << "uint64_t serialization broken"; + EXPECT_EQ(B, true) << "bool serialization broken"; + EXPECT_EQ(S, "foo") << "std::string serialization broken"; + EXPECT_EQ(V, std::vector({42, 7})) << "std::vector serialization broken"; + return Error::success(); + }); + + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the AllTheTypes call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); + + + { + // Make an async call. + std::vector v({42, 7}); + auto Err = Client.callAsync( + [](Error Err) { + EXPECT_FALSE(!!Err) << "Async AllTheTypes response handler failed"; return Error::success(); - }); - EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; + }, + static_cast(-101), static_cast(250), + static_cast(-10000), static_cast(10000), + static_cast(-1000000000), static_cast(1000000000), + static_cast(-10000000000), static_cast(10000000000), + true, std::string("foo"), v); + EXPECT_FALSE(!!Err) << "Client.callAsync failed for AllTheTypes"; } { - // Wait for the result. - auto EC = waitForResult(C1, ResOrErr->second, handleNone); - EXPECT_FALSE(EC) << "Could not read result."; + // Poke the client to process the result of the AllTheTypes call. + auto Err = Client.handleOne(); + EXPECT_FALSE(!!Err) << "Client failed to handle response from AllTheTypes"; } - // Verify that the function returned ok. - auto Err = ResOrErr->first.get(); - EXPECT_FALSE(!!Err) << "Remote void function failed to execute."; + ServerThread.join(); } // Test the synchronous call API.