From: Lang Hames Date: Tue, 5 Sep 2017 03:34:09 +0000 (+0000) Subject: [ORC] Add a pair of ORC layers that forward object-layer operations via RPC. X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=e20d26a3b35eadab66b359ca13881a6fa5a92952;p=llvm [ORC] Add a pair of ORC layers that forward object-layer operations via RPC. This patch introduces RemoteObjectClientLayer and RemoteObjectServerLayer, which can be used to forward ORC object-layer operations from a JIT stack in the client to a JIT stack (consisting only of object-layers) in the server. This is a new way to support remote-JITing in LLVM. The previous approach (supported by OrcRemoteTargetClient and OrcRemoteTargetServer) used a remote-mapping memory manager that sat "beneath" the JIT stack and sent fully-relocated binary blobs to the server. The main advantage of the new approach is that relocatable objects can be cached on the server and re-used (if the code that they represent hasn't changed), whereas fully-relocated blobs can not (since the addresses they have been permanently bound to will change from run to run). git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@312511 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/ExecutionEngine/JITSymbol.h b/include/llvm/ExecutionEngine/JITSymbol.h index 5e661401eb6..aeebfd8b1c0 100644 --- a/include/llvm/ExecutionEngine/JITSymbol.h +++ b/include/llvm/ExecutionEngine/JITSymbol.h @@ -89,9 +89,15 @@ public: /// @brief Implicitly convert to the underlying flags type. operator UnderlyingType&() { return Flags; } + /// @brief Implicitly convert to the underlying flags type. + operator const UnderlyingType&() const { return Flags; } + /// @brief Return a reference to the target-specific flags. TargetFlagsType& getTargetFlags() { return TargetFlags; } + /// @brief Return a reference to the target-specific flags. + const TargetFlagsType& getTargetFlags() const { return TargetFlags; } + /// Construct a JITSymbolFlags value based on the flags of the given global /// value. static JITSymbolFlags fromGlobalValue(const GlobalValue &GV); diff --git a/include/llvm/ExecutionEngine/Orc/OrcError.h b/include/llvm/ExecutionEngine/Orc/OrcError.h index e6374b70967..e1ac87075ac 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -33,7 +33,8 @@ enum class OrcErrorCode : int { RPCResponseAbandoned, UnexpectedRPCCall, UnexpectedRPCResponse, - UnknownErrorCodeFromRemote + UnknownErrorCodeFromRemote, + UnknownResourceHandle }; std::error_code orcError(OrcErrorCode ErrCode); diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index 5040cc0c550..bc0da0f9a73 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -25,6 +25,37 @@ namespace orc { namespace remote { +/// Template error for missing resources. +template +class ResourceNotFound + : public ErrorInfo> { +public: + static char ID; + + ResourceNotFound(ResourceIdT ResourceId, + std::string ResourceDescription = "") + : ResourceId(std::move(ResourceId)), + ResourceDescription(std::move(ResourceDescription)) {} + + std::error_code convertToErrorCode() const override { + return orcError(OrcErrorCode::UnknownResourceHandle); + } + + void log(raw_ostream &OS) const override { + OS << (ResourceDescription.empty() + ? "Remote resource with id " + : ResourceDescription) + << " " << ResourceId << " not found"; + } + +private: + ResourceIdT ResourceId; + std::string ResourceDescription; +}; + +template +char ResourceNotFound::ID = 0; + class DirectBufferWriter { public: DirectBufferWriter() = default; @@ -45,6 +76,32 @@ private: namespace rpc { +template <> +class RPCTypeName { +public: + static const char *getName() { return "JITSymbolFlags"; } +}; + +template +class SerializationTraits { +public: + + static Error serialize(ChannelT &C, const JITSymbolFlags &Flags) { + return serializeSeq(C, static_cast(Flags), + Flags.getTargetFlags()); + } + + static Error deserialize(ChannelT &C, JITSymbolFlags &Flags) { + JITSymbolFlags::UnderlyingType JITFlags; + JITSymbolFlags::TargetFlagsType TargetFlags; + if (auto Err = deserializeSeq(C, JITFlags, TargetFlags)) + return Err; + Flags = JITSymbolFlags(static_cast(JITFlags), + TargetFlags); + return Error::success(); + } +}; + template <> class RPCTypeName { public: static const char *getName() { return "DirectBufferWriter"; } diff --git a/include/llvm/ExecutionEngine/Orc/RemoteObjectLayer.h b/include/llvm/ExecutionEngine/Orc/RemoteObjectLayer.h new file mode 100644 index 00000000000..2f117a5e20d --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/RemoteObjectLayer.h @@ -0,0 +1,498 @@ +//===------ RemoteObjectLayer.h - Forwards objs to a remote -----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Forwards objects to a remote object layer via RPC. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_REMOTEOBJECTLAYER_H +#define LLVM_EXECUTIONENGINE_ORC_REMOTEOBJECTLAYER_H + +#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" +#include "llvm/Object/ObjectFile.h" +#include "llvm/ExecutionEngine/Orc/LambdaResolver.h" +#include + +namespace llvm { +namespace orc { + +/// RPC API needed by RemoteObjectClientLayer and RemoteObjectServerLayer. +class RemoteObjectLayerAPI { +public: + + using ObjHandleT = remote::ResourceIdMgr::ResourceId; + +protected: + + using RemoteSymbolId = remote::ResourceIdMgr::ResourceId; + using RemoteSymbol = std::pair; + +public: + + using BadSymbolHandleError = remote::ResourceNotFound; + using BadObjectHandleError = remote::ResourceNotFound; + +protected: + + static const ObjHandleT InvalidObjectHandleId = 0; + static const RemoteSymbolId NullSymbolId = 0; + + class AddObject + : public rpc::Function(std::string)> { + public: + static const char *getName() { return "AddObject"; } + }; + + class RemoveObject + : public rpc::Function { + public: + static const char *getName() { return "RemoveObject"; } + }; + + class FindSymbol + : public rpc::Function(std::string, + bool)> { + public: + static const char *getName() { return "FindSymbol"; } + }; + + class FindSymbolIn + : public rpc::Function(ObjHandleT, std::string, + bool)> { + public: + static const char *getName() { return "FindSymbolIn"; } + }; + + class EmitAndFinalize + : public rpc::Function { + public: + static const char *getName() { return "EmitAndFinalize"; } + }; + + class Lookup + : public rpc::Function(ObjHandleT, std::string)> { + public: + static const char *getName() { return "Lookup"; } + }; + + class LookupInLogicalDylib + : public rpc::Function(ObjHandleT, std::string)> { + public: + static const char *getName() { return "LookupInLogicalDylib"; } + }; + + class ReleaseRemoteSymbol + : public rpc::Function { + public: + static const char *getName() { return "ReleaseRemoteSymbol"; } + }; + + class MaterializeRemoteSymbol + : public rpc::Function(RemoteSymbolId)> { + public: + static const char *getName() { return "MaterializeRemoteSymbol"; } + }; +}; + +/// Base class containing common utilities for RemoteObjectClientLayer and +/// RemoteObjectServerLayer. +template +class RemoteObjectLayer : public RemoteObjectLayerAPI { +public: + + RemoteObjectLayer(RPCEndpoint &Remote, + std::function ReportError) + : Remote(Remote), ReportError(std::move(ReportError)), + SymbolIdMgr(NullSymbolId + 1) { + using ThisT = RemoteObjectLayer; + Remote.template addHandler( + *this, &ThisT::handleReleaseRemoteSymbol); + Remote.template addHandler( + *this, &ThisT::handleMaterializeRemoteSymbol); + } + +protected: + + class RemoteSymbolMaterializer { + public: + + RemoteSymbolMaterializer(RemoteObjectLayer &C, + RemoteSymbolId Id) + : C(C), Id(Id) {} + + RemoteSymbolMaterializer(const RemoteSymbolMaterializer &Other) + : C(Other.C), Id(Other.Id) { + // FIXME: This is a horrible, auto_ptr-style, copy-as-move operation. + // It should be removed as soon as LLVM has C++14's generalized + // lambda capture (at which point the materializer can be moved + // into the lambda in remoteToJITSymbol below). + const_cast(Other).Id = 0; + } + + RemoteSymbolMaterializer& + operator=(const RemoteSymbolMaterializer&) = delete; + + ~RemoteSymbolMaterializer() { + if (Id) + C.releaseRemoteSymbol(Id); + } + + Expected materialize() { + auto Addr = C.materializeRemoteSymbol(Id); + Id = 0; + return Addr; + } + + private: + RemoteObjectLayer &C; + RemoteSymbolId Id; + }; + + RemoteSymbol nullRemoteSymbol() { + return RemoteSymbol(0, JITSymbolFlags()); + } + + // Creates a StringError that contains a copy of Err's log message, then + // sends that StringError to ReportError. + // + // This allows us to locally log error messages for errors that will actually + // be delivered to the remote. + Error teeLog(Error Err) { + return handleErrors(std::move(Err), + [this](std::unique_ptr EIB) { + ReportError(make_error( + EIB->message(), + EIB->convertToErrorCode())); + return Error(std::move(EIB)); + }); + } + + Error badRemoteSymbolIdError(RemoteSymbolId Id) { + return make_error(Id, "Remote JIT Symbol"); + } + + Error badObjectHandleError(ObjHandleT H) { + return make_error( + H, "Bad object handle"); + } + + Expected jitSymbolToRemote(JITSymbol Sym) { + if (Sym) { + auto Id = SymbolIdMgr.getNext(); + auto Flags = Sym.getFlags(); + assert(!InUseSymbols.count(Id) && "Symbol id already in use"); + InUseSymbols.insert(std::make_pair(Id, std::move(Sym))); + return RemoteSymbol(Id, Flags); + } else if (auto Err = Sym.takeError()) + return teeLog(std::move(Err)); + // else... + return nullRemoteSymbol(); + } + + JITSymbol remoteToJITSymbol(Expected RemoteSymOrErr) { + if (RemoteSymOrErr) { + auto &RemoteSym = *RemoteSymOrErr; + RemoteSymbolMaterializer RSM(*this, RemoteSym.first); + auto Sym = + JITSymbol([RSM]() mutable { return RSM.materialize(); }, + RemoteSym.second); + return Sym; + } else + return RemoteSymOrErr.takeError(); + } + + template + using CallBResult = decltype(foldRemoteERror( + std::declval() + .template callB( + std::declval()...))); + + /// API checked callB function. + template + CallBResult callB(const ArgTs &... Args) { + return foldRemoteError(Remote.template callB(Args...)); + } + + RPCEndpoint &Remote; + std::function ReportError; + +private: + + void releaseRemoteSymbol(RemoteSymbolId Id) { + if (auto Err = Remote.template callB(Id)) + ReportError(std::move(Err)); + } + + Expected materializeRemoteSymbol(RemoteSymbolId Id) { + return Remote.template callB(Id); + } + + Error handleReleaseRemoteSymbol(RemoteSymbolId Id) { + auto SI = InUseSymbols.find(Id); + if (SI != InUseSymbols.end()) { + InUseSymbols.erase(SI); + return Error::success(); + } else + return teeLog(badRemoteSymbolIdError(Id)); + } + + Expected handleMaterializeRemoteSymbol(RemoteSymbolId Id) { + auto SI = InUseSymbols.find(Id); + if (SI != InUseSymbols.end()) { + auto AddrOrErr = SI->second.getAddress(); + InUseSymbols.erase(SI); + SymbolIdMgr.release(Id); + if (AddrOrErr) + return *AddrOrErr; + else + return teeLog(AddrOrErr.takeError()); + } else { + return teeLog(badRemoteSymbolIdError(Id)); + } + } + + remote::ResourceIdMgr SymbolIdMgr; + std::map InUseSymbols; +}; + +template +class RemoteObjectClientLayer : public RemoteObjectLayer { +private: + + using AddObject = RemoteObjectLayerAPI::AddObject; + using RemoveObject = RemoteObjectLayerAPI::RemoveObject; + using FindSymbol = RemoteObjectLayerAPI::FindSymbol; + using FindSymbolIn = RemoteObjectLayerAPI::FindSymbolIn; + using EmitAndFinalize = RemoteObjectLayerAPI::EmitAndFinalize; + using Lookup = RemoteObjectLayerAPI::Lookup; + using LookupInLogicalDylib = RemoteObjectLayerAPI::LookupInLogicalDylib; + + using RemoteObjectLayer::teeLog; + using RemoteObjectLayer::badObjectHandleError; + using RemoteObjectLayer::remoteToJITSymbol; + +public: + + using ObjHandleT = RemoteObjectLayerAPI::ObjHandleT; + using RemoteSymbol = RemoteObjectLayerAPI::RemoteSymbol; + + using ObjectPtr = + std::shared_ptr>; + + RemoteObjectClientLayer(RPCEndpoint &Remote, + std::function ReportError) + : RemoteObjectLayer(Remote, std::move(ReportError)) { + using ThisT = RemoteObjectClientLayer; + Remote.template addHandler(*this, &ThisT::lookup); + Remote.template addHandler( + *this, &ThisT::lookupInLogicalDylib); + } + + Expected + addObject(ObjectPtr Object, std::shared_ptr Resolver) { + StringRef ObjBuffer = Object->getBinary()->getData(); + if (auto HandleOrErr = + this->Remote.template callB(ObjBuffer)) { + auto &Handle = *HandleOrErr; + // FIXME: Return an error for this: + assert(!Resolvers.count(Handle) && "Handle already in use?"); + Resolvers[Handle] = std::move(Resolver); + return Handle; + } else + return HandleOrErr.takeError(); + } + + Error removeObject(ObjHandleT H) { + return this->Remote.template callB(H); + } + + JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) { + return remoteToJITSymbol( + this->Remote.template callB(Name, + ExportedSymbolsOnly)); + } + + JITSymbol findSymbolIn(ObjHandleT H, StringRef Name, bool ExportedSymbolsOnly) { + return remoteToJITSymbol( + this->Remote.template callB(H, Name, + ExportedSymbolsOnly)); + } + + Error emitAndFinalize(ObjHandleT H) { + return this->Remote.template callB(H); + } + +private: + + Expected lookup(ObjHandleT H, const std::string &Name) { + auto RI = Resolvers.find(H); + if (RI != Resolvers.end()) { + return this->jitSymbolToRemote(RI->second->findSymbol(Name)); + } else + return teeLog(badObjectHandleError(H)); + } + + Expected lookupInLogicalDylib(ObjHandleT H, + const std::string &Name) { + auto RI = Resolvers.find(H); + if (RI != Resolvers.end()) + return this->jitSymbolToRemote( + RI->second->findSymbolInLogicalDylib(Name)); + else + return teeLog(badObjectHandleError(H)); + } + + std::map> Resolvers; +}; + +template +class RemoteObjectServerLayer : public RemoteObjectLayer { +private: + + using ObjHandleT = RemoteObjectLayerAPI::ObjHandleT; + using RemoteSymbol = RemoteObjectLayerAPI::RemoteSymbol; + + using AddObject = RemoteObjectLayerAPI::AddObject; + using RemoveObject = RemoteObjectLayerAPI::RemoveObject; + using FindSymbol = RemoteObjectLayerAPI::FindSymbol; + using FindSymbolIn = RemoteObjectLayerAPI::FindSymbolIn; + using EmitAndFinalize = RemoteObjectLayerAPI::EmitAndFinalize; + using Lookup = RemoteObjectLayerAPI::Lookup; + using LookupInLogicalDylib = RemoteObjectLayerAPI::LookupInLogicalDylib; + + using RemoteObjectLayer::teeLog; + using RemoteObjectLayer::badObjectHandleError; + using RemoteObjectLayer::remoteToJITSymbol; + +public: + + RemoteObjectServerLayer(BaseLayerT &BaseLayer, + RPCEndpoint &Remote, + std::function ReportError) + : RemoteObjectLayer(Remote, std::move(ReportError)), + BaseLayer(BaseLayer), HandleIdMgr(1) { + using ThisT = RemoteObjectServerLayer; + + Remote.template addHandler(*this, &ThisT::addObject); + Remote.template addHandler(*this, &ThisT::removeObject); + Remote.template addHandler(*this, &ThisT::findSymbol); + Remote.template addHandler(*this, &ThisT::findSymbolIn); + Remote.template addHandler(*this, &ThisT::emitAndFinalize); + } + +private: + + class StringMemoryBuffer : public MemoryBuffer { + public: + StringMemoryBuffer(std::string Buffer) + : Buffer(std::move(Buffer)) { + init(this->Buffer.data(), this->Buffer.data() + this->Buffer.size(), + false); + } + + BufferKind getBufferKind() const override { return MemoryBuffer_Malloc; } + private: + std::string Buffer; + }; + + JITSymbol lookup(ObjHandleT Id, const std::string &Name) { + return remoteToJITSymbol( + this->Remote.template callB(Id, Name)); + } + + JITSymbol lookupInLogicalDylib(ObjHandleT Id, const std::string &Name) { + return remoteToJITSymbol( + this->Remote.template callB(Id, Name)); + } + + Expected addObject(std::string ObjBuffer) { + auto Buffer = llvm::make_unique(std::move(ObjBuffer)); + if (auto ObjectOrErr = + object::ObjectFile::createObjectFile(Buffer->getMemBufferRef())) { + auto Object = + std::make_shared>( + std::move(*ObjectOrErr), std::move(Buffer)); + + auto Id = HandleIdMgr.getNext(); + assert(!BaseLayerHandles.count(Id) && "Id already in use?"); + + auto Resolver = + createLambdaResolver( + [this, Id](const std::string &Name) { return lookup(Id, Name); }, + [this, Id](const std::string &Name) { + return lookupInLogicalDylib(Id, Name); + }); + + if (auto HandleOrErr = + BaseLayer.addObject(std::move(Object), std::move(Resolver))) { + BaseLayerHandles[Id] = std::move(*HandleOrErr); + return Id; + } else + return teeLog(HandleOrErr.takeError()); + } else + return teeLog(ObjectOrErr.takeError()); + } + + Error removeObject(ObjHandleT H) { + auto HI = BaseLayerHandles.find(H); + if (HI != BaseLayerHandles.end()) { + if (auto Err = BaseLayer.removeObject(HI->second)) + return teeLog(std::move(Err)); + return Error::success(); + } else + return teeLog(badObjectHandleError(H)); + } + + Expected findSymbol(const std::string &Name, + bool ExportedSymbolsOnly) { + if (auto Sym = BaseLayer.findSymbol(Name, ExportedSymbolsOnly)) + return this->jitSymbolToRemote(std::move(Sym)); + else if (auto Err = Sym.takeError()) + return teeLog(std::move(Err)); + return this->nullRemoteSymbol(); + } + + Expected findSymbolIn(ObjHandleT H, const std::string &Name, + bool ExportedSymbolsOnly) { + auto HI = BaseLayerHandles.find(H); + if (HI != BaseLayerHandles.end()) { + if (auto Sym = BaseLayer.findSymbolIn(HI->second, Name, ExportedSymbolsOnly)) + return this->jitSymbolToRemote(std::move(Sym)); + else if (auto Err = Sym.takeError()) + return teeLog(std::move(Err)); + return this->nullRemoteSymbol(); + } else + return teeLog(badObjectHandleError(H)); + } + + Error emitAndFinalize(ObjHandleT H) { + auto HI = BaseLayerHandles.find(H); + if (HI != BaseLayerHandles.end()) { + if (auto Err = BaseLayer.emitAndFinalize(HI->second)) + return teeLog(std::move(Err)); + return Error::success(); + } else + return teeLog(badObjectHandleError(H)); + } + + BaseLayerT &BaseLayer; + remote::ResourceIdMgr HandleIdMgr; + std::map BaseLayerHandles; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_REMOTEOBJECTLAYER_H diff --git a/lib/ExecutionEngine/Orc/OrcError.cpp b/lib/ExecutionEngine/Orc/OrcError.cpp index df2d320e0f7..c218cb9a523 100644 --- a/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/lib/ExecutionEngine/Orc/OrcError.cpp @@ -54,6 +54,8 @@ public: case OrcErrorCode::UnknownErrorCodeFromRemote: return "Unknown error returned from remote RPC function " "(Use StringError to get error message)"; + case OrcErrorCode::UnknownResourceHandle: + return "Unknown resource handle"; } llvm_unreachable("Unhandled error code"); } diff --git a/unittests/ExecutionEngine/Orc/CMakeLists.txt b/unittests/ExecutionEngine/Orc/CMakeLists.txt index db40c4213bd..e7e3034905e 100644 --- a/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -18,6 +18,7 @@ add_llvm_unittest(OrcJITTests OrcCAPITest.cpp OrcTestCommon.cpp QueueChannel.cpp + RemoteObjectLayerTest.cpp RPCUtilsTest.cpp RTDyldObjectLinkingLayerTest.cpp ) diff --git a/unittests/ExecutionEngine/Orc/RemoteObjectLayerTest.cpp b/unittests/ExecutionEngine/Orc/RemoteObjectLayerTest.cpp new file mode 100644 index 00000000000..da76890d73d --- /dev/null +++ b/unittests/ExecutionEngine/Orc/RemoteObjectLayerTest.cpp @@ -0,0 +1,576 @@ +//===---------------------- RemoteObjectLayerTest.cpp ---------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/NullResolver.h" +#include "llvm/ExecutionEngine/Orc/RemoteObjectLayer.h" +#include "OrcTestCommon.h" +#include "QueueChannel.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::orc; + +namespace { + +class MockObjectLayer { +public: + + using ObjHandleT = uint64_t; + + using ObjectPtr = + std::shared_ptr>; + + using LookupFn = std::function; + using SymbolLookupTable = std::map; + + using AddObjectFtor = + std::function(ObjectPtr, SymbolLookupTable&)>; + + class ObjectNotFound : public remote::ResourceNotFound { + public: + ObjectNotFound(ObjHandleT H) : ResourceNotFound(H, "Object handle") {} + }; + + MockObjectLayer(AddObjectFtor AddObject) + : AddObject(std::move(AddObject)) {} + + Expected addObject(ObjectPtr Obj, + std::shared_ptr Resolver) { + return AddObject(Obj, SymTab); + } + + Error removeObject(ObjHandleT H) { + if (SymTab.count(H)) + return Error::success(); + else + return make_error(H); + } + + JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) { + for (auto KV : SymTab) { + if (auto Sym = KV.second(Name, ExportedSymbolsOnly)) + return Sym; + else if (auto Err = Sym.takeError()) + return std::move(Err); + } + return JITSymbol(nullptr); + } + + JITSymbol findSymbolIn(ObjHandleT H, StringRef Name, + bool ExportedSymbolsOnly) { + auto LI = SymTab.find(H); + if (LI != SymTab.end()) + return LI->second(Name, ExportedSymbolsOnly); + else + return make_error(H); + } + + Error emitAndFinalize(ObjHandleT H) { + if (SymTab.count(H)) + return Error::success(); + else + return make_error(H); + } + +private: + AddObjectFtor AddObject; + SymbolLookupTable SymTab; +}; + +using RPCEndpoint = rpc::SingleThreadedRPCEndpoint; + +MockObjectLayer::ObjectPtr createTestObject() { + OrcNativeTarget::initialize(); + auto TM = std::unique_ptr(EngineBuilder().selectTarget()); + + if (!TM) + return nullptr; + + LLVMContext Ctx; + ModuleBuilder MB(Ctx, TM->getTargetTriple().str(), "TestModule"); + MB.getModule()->setDataLayout(TM->createDataLayout()); + auto *Main = MB.createFunctionDecl("main"); + Main->getBasicBlockList().push_back(BasicBlock::Create(Ctx)); + IRBuilder<> B(&Main->back()); + B.CreateRet(ConstantInt::getSigned(Type::getInt32Ty(Ctx), 42)); + + SimpleCompiler IRCompiler(*TM); + return std::make_shared>( + IRCompiler(*MB.getModule())); +} + +TEST(RemoteObjectLayer, AddObject) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + logAllUnhandledErrors(std::move(Err), llvm::errs(), ""); + }; + + // Copy the bytes out of the test object: the copy will be used to verify + // that the original is correctly transmitted over RPC to the mock layer. + StringRef ObjBytes = TestObject->getBinary()->getData(); + std::vector ObjContents(ObjBytes.size()); + std::copy(ObjBytes.begin(), ObjBytes.end(), ObjContents.begin()); + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + MockObjectLayer BaseLayer( + [&ObjContents](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) { + + // Check that the received object file content matches the original. + StringRef RPCObjContents = Obj->getBinary()->getData(); + EXPECT_EQ(RPCObjContents.size(), ObjContents.size()) + << "RPC'd object file has incorrect size"; + EXPECT_TRUE(std::equal(RPCObjContents.begin(), RPCObjContents.end(), + ObjContents.begin())) + << "RPC'd object file content does not match original content"; + + return 1; + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + cantFail(Client.addObject(std::move(TestObject), + std::make_shared())); + cantFail(ClientEP.callB()); + ServerThread.join(); +} + +TEST(RemoteObjectLayer, AddObjectFailure) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "AddObjectFailure - Test Message") + << "Expected error string to be \"AddObjectFailure - Test Message\""; + }; + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + MockObjectLayer BaseLayer( + [](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) + -> Expected { + return make_error("AddObjectFailure - Test Message", + inconvertibleErrorCode()); + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + auto HandleOrErr = + Client.addObject(std::move(TestObject), std::make_shared()); + + EXPECT_FALSE(HandleOrErr) << "Expected error from addObject"; + + auto ErrMsg = toString(HandleOrErr.takeError()); + EXPECT_EQ(ErrMsg, "AddObjectFailure - Test Message") + << "Expected error string to be \"AddObjectFailure - Test Message\""; + + cantFail(ClientEP.callB()); + ServerThread.join(); +} + + +TEST(RemoteObjectLayer, RemoveObject) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + logAllUnhandledErrors(std::move(Err), llvm::errs(), ""); + }; + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + + MockObjectLayer BaseLayer( + [](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) { + SymTab[1] = MockObjectLayer::LookupFn(); + return 1; + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + auto H = cantFail(Client.addObject(std::move(TestObject), + std::make_shared())); + + cantFail(Client.removeObject(H)); + + cantFail(ClientEP.callB()); + ServerThread.join(); +} + +TEST(RemoteObjectLayer, RemoveObjectFailure) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Object handle 42 not found") + << "Expected error string to be \"Object handle 42 not found\""; + }; + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + + // AddObject lambda does not update symbol table, so removeObject will treat + // this as a bad object handle. + MockObjectLayer BaseLayer( + [](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) { + return 42; + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + auto H = cantFail(Client.addObject(std::move(TestObject), + std::make_shared())); + + auto Err = Client.removeObject(H); + EXPECT_TRUE(!!Err) << "Expected error from removeObject"; + + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Object handle 42 not found") + << "Expected error string to be \"Object handle 42 not found\""; + + cantFail(ClientEP.callB()); + ServerThread.join(); +} + +TEST(RemoteObjectLayer, FindSymbol) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'") + << "Expected error string to be \"Object handle 42 not found\""; + }; + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + + // AddObject lambda does not update symbol table, so removeObject will treat + // this as a bad object handle. + MockObjectLayer BaseLayer( + [](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) { + SymTab[42] = + [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol { + if (Name == "foobar") + return JITSymbol(0x12348765, JITSymbolFlags::Exported); + return make_error(Name); + }; + return 42; + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + cantFail(Client.addObject(std::move(TestObject), + std::make_shared())); + + auto Sym1 = Client.findSymbol("foobar", true); + + EXPECT_TRUE(!!Sym1) << "Symbol 'foobar' should be findable"; + EXPECT_EQ(cantFail(Sym1.getAddress()), 0x12348765ULL) + << "Symbol 'foobar' does not return the correct address"; + + auto Sym2 = Client.findSymbol("barbaz", true); + EXPECT_FALSE(!!Sym2) << "Symbol 'barbaz' should not be findable"; + auto Err = Sym2.takeError(); + EXPECT_TRUE(!!Err) << "Sym2 should contain an error value"; + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'") + << "Expected symbol-not-found error for Sym2"; + + cantFail(ClientEP.callB()); + ServerThread.join(); +} + +TEST(RemoteObjectLayer, FindSymbolIn) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'") + << "Expected error string to be \"Object handle 42 not found\""; + }; + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + + // AddObject lambda does not update symbol table, so removeObject will treat + // this as a bad object handle. + MockObjectLayer BaseLayer( + [](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) { + SymTab[42] = + [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol { + if (Name == "foobar") + return JITSymbol(0x12348765, JITSymbolFlags::Exported); + return make_error(Name); + }; + // Dummy symbol table entry - this should not be visible to + // findSymbolIn. + SymTab[43] = + [](StringRef Name, bool ExportedSymbolsOnly) -> JITSymbol { + if (Name == "barbaz") + return JITSymbol(0xdeadbeef, JITSymbolFlags::Exported); + return make_error(Name); + }; + + return 42; + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + auto H = cantFail(Client.addObject(std::move(TestObject), + std::make_shared())); + + auto Sym1 = Client.findSymbolIn(H, "foobar", true); + + EXPECT_TRUE(!!Sym1) << "Symbol 'foobar' should be findable"; + EXPECT_EQ(cantFail(Sym1.getAddress()), 0x12348765ULL) + << "Symbol 'foobar' does not return the correct address"; + + auto Sym2 = Client.findSymbolIn(H, "barbaz", true); + EXPECT_FALSE(!!Sym2) << "Symbol 'barbaz' should not be findable"; + auto Err = Sym2.takeError(); + EXPECT_TRUE(!!Err) << "Sym2 should contain an error value"; + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Could not find symbol 'barbaz'") + << "Expected symbol-not-found error for Sym2"; + + cantFail(ClientEP.callB()); + ServerThread.join(); +} + +TEST(RemoteObjectLayer, EmitAndFinalize) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + logAllUnhandledErrors(std::move(Err), llvm::errs(), ""); + }; + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + + MockObjectLayer BaseLayer( + [](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) { + SymTab[1] = MockObjectLayer::LookupFn(); + return 1; + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + auto H = cantFail(Client.addObject(std::move(TestObject), + std::make_shared())); + + auto Err = Client.emitAndFinalize(H); + EXPECT_FALSE(!!Err) << "emitAndFinalize should work"; + + cantFail(ClientEP.callB()); + ServerThread.join(); +} + +TEST(RemoteObjectLayer, EmitAndFinalizeFailure) { + llvm::orc::rpc::registerStringError(); + auto TestObject = createTestObject(); + if (!TestObject) + return; + + auto Channels = createPairedQueueChannels(); + + auto ReportError = + [](Error Err) { + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Object handle 1 not found") + << "Expected bad handle error"; + }; + + RPCEndpoint ClientEP(*Channels.first, true); + RemoteObjectClientLayer Client(ClientEP, ReportError); + + RPCEndpoint ServerEP(*Channels.second, true); + + MockObjectLayer BaseLayer( + [](MockObjectLayer::ObjectPtr Obj, + MockObjectLayer::SymbolLookupTable &SymTab) { + return 1; + }); + RemoteObjectServerLayer Server(BaseLayer, + ServerEP, + ReportError); + + bool Finished = false; + ServerEP.addHandler( + [&]() { Finished = true; } + ); + + auto ServerThread = + std::thread([&]() { + while (!Finished) + cantFail(ServerEP.handleOne()); + }); + + auto H = cantFail(Client.addObject(std::move(TestObject), + std::make_shared())); + + auto Err = Client.emitAndFinalize(H); + EXPECT_TRUE(!!Err) << "emitAndFinalize should work"; + + auto ErrMsg = toString(std::move(Err)); + EXPECT_EQ(ErrMsg, "Object handle 1 not found") + << "emitAndFinalize returned incorrect error"; + + cantFail(ClientEP.callB()); + ServerThread.join(); +} + +}