From: Lang Hames Date: Thu, 13 Apr 2017 03:51:35 +0000 (+0000) Subject: [ORC] Add RPC and serialization support for Errors and Expecteds. X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=ed576a186fb8b5d33b1fba26f2e825182603f9f2;p=llvm [ORC] Add RPC and serialization support for Errors and Expecteds. This patch allows Error and Expected types to be passed to and returned from RPC functions. Serializers and deserializers for custom error types (types deriving from the ErrorInfo class template) can be registered with the SerializationTraits for a given channel type (see registerStringError in RPCSerialization.h for an example), allowing a given custom type to be sent/received. Unregistered types will be serialized/deserialized as StringErrors using the custom type's log message as the error string. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@300167 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/ExecutionEngine/Orc/OrcError.h b/include/llvm/ExecutionEngine/Orc/OrcError.h index e85cbe30970..cbb40fad022 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -32,6 +32,7 @@ enum class OrcErrorCode : int { RPCResponseAbandoned, UnexpectedRPCCall, UnexpectedRPCResponse, + UnknownErrorCodeFromRemote }; std::error_code orcError(OrcErrorCode ErrCode); diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 359a9d81b22..baad2014762 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -12,6 +12,7 @@ #include "OrcError.h" #include "llvm/Support/thread.h" +#include #include #include @@ -114,6 +115,35 @@ public: static const char* getName() { return "std::string"; } }; +template <> +class RPCTypeName { +public: + static const char* getName() { return "Error"; } +}; + +template +class RPCTypeName> { +public: + static const char* getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) << "Expected<" + << RPCTypeNameSequence() + << ">"; + return Name.data(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex RPCTypeName>::NameMutex; + +template +std::string RPCTypeName>::Name; + template class RPCTypeName> { public: @@ -243,8 +273,10 @@ class SequenceSerialization { public: template - static Error serialize(ChannelT &C, const CArgT &CArg) { - return SerializationTraits::serialize(C, CArg); + static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits::type>:: + serialize(C, std::forward(CArg)); } template @@ -258,19 +290,21 @@ class SequenceSerialization { public: template - static Error serialize(ChannelT &C, const CArgT &CArg, - const CArgTs&... CArgs) { + static Error serialize(ChannelT &C, CArgT &&CArg, + CArgTs &&... CArgs) { if (auto Err = - SerializationTraits::serialize(C, CArg)) + SerializationTraits::type>:: + serialize(C, std::forward(CArg))) return Err; if (auto Err = SequenceTraits::emitSeparator(C)) return Err; - return SequenceSerialization::serialize(C, CArgs...); + return SequenceSerialization:: + serialize(C, std::forward(CArgs)...); } template static Error deserialize(ChannelT &C, CArgT &CArg, - CArgTs&... CArgs) { + CArgTs &... CArgs) { if (auto Err = SerializationTraits::deserialize(C, CArg)) return Err; @@ -281,8 +315,9 @@ public: }; template -Error serializeSeq(ChannelT &C, const ArgTs &... Args) { - return SequenceSerialization::serialize(C, Args...); +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { + return SequenceSerialization::type...>:: + serialize(C, std::forward(Args)...); } template @@ -290,6 +325,186 @@ Error deserializeSeq(ChannelT &C, ArgTs &... Args) { return SequenceSerialization::deserialize(C, Args...); } +template +class SerializationTraits { +public: + + using WrappedErrorSerializer = + std::function; + + using WrappedErrorDeserializer = + std::function; + + template + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + std::lock_guard Lock(SerializersMutex); + + // We're abusing the stability of std::map here: We take a reference to the + // key of the deserializers map to save us from duplicating the string in + // the serializer. This should be changed to use a stringpool if we switch + // to a map type that may move keys in memory. + auto I = + Deserializers.insert(Deserializers.begin(), + std::make_pair(std::move(Name), + std::move(Deserialize))); + + const std::string &KeyName = I->first; + // FIXME: Move capture Serialize once we have C++14. + Serializers[ErrorInfoT::classID()] = + [&KeyName, Serialize](ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, KeyName)) + return Err; + return Serialize(C, static_cast(EIB)); + }; + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard Lock(SerializersMutex); + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), + [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard Lock(SerializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + assert(EIB.dynamicClassID() != StringError::classID() && + "StringError serialization not registered"); + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::mutex SerializersMutex; + static std::map Serializers; + static std::map Deserializers; +}; + +template +std::mutex SerializationTraits::SerializersMutex; + +template +std::map::WrappedErrorSerializer> +SerializationTraits::Serializers; + +template +std::map::WrappedErrorDeserializer> +SerializationTraits::Deserializers; + +template +void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + SerializationTraits:: + template registerErrorType( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = + make_error(std::move(Msg), + orcError( + OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +}; + +/// SerializationTraits for Expected from an Expected. +template +class SerializationTraits, Expected> { +public: + + static Error serialize(ChannelT &C, Expected &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected &ValOrErr) { + ExpectedAsOutParameter EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected from a T2. +template +class SerializationTraits, T2> { +public: + + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected(std::forward(Val))); + } +}; + +/// SerializationTraits for Expected from an Error. +template +class SerializationTraits, Error> { +public: + + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected(std::move(Err))); + } +}; + /// SerializationTraits default specialization for std::pair. template class SerializationTraits> { diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index c8b3704b5fc..6212f64ff31 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -129,7 +129,7 @@ public: CouldNotNegotiate(std::string Signature); std::error_code convertToErrorCode() const override; - void log(raw_ostream &OS) const override; + void log(raw_ostream &OS) const override; const std::string &getSignature() const { return Signature; } private: std::string Signature; @@ -362,30 +362,122 @@ template <> class ResultTraits : public ResultTraits {}; template class ResultTraits> : public ResultTraits {}; +// Determines whether an RPC function's defined error return type supports +// error return value. +template +class SupportsErrorReturn { +public: + static const bool value = false; +}; + +template <> +class SupportsErrorReturn { +public: + static const bool value = true; +}; + +template +class SupportsErrorReturn> { +public: + static const bool value = true; +}; + +// RespondHelper packages return values based on whether or not the declared +// RPC function return type supports error returns. +template +class RespondHelper; + +// RespondHelper specialization for functions that support error returns. +template <> +class RespondHelper { +public: + + // Send Expected. + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected ResultOrErr) { + if (!ResultOrErr && ResultOrErr.template errorIsA()) + return ResultOrErr.takeError(); + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits>::serialize( + C, std::move(ResultOrErr))) + return Err; + + // Close the response message. + return C.endSendMessage(); + } + + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err && Err.isA()) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + if (auto Err2 = serializeSeq(C, std::move(Err))) + return Err2; + return C.endSendMessage(); + } + +}; + +// RespondHelper specialization for functions that do not support error returns. +template <> +class RespondHelper { +public: + + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, + Expected ResultOrErr) { + if (auto Err = ResultOrErr.takeError()) + return Err; + + // Open the response message. + if (auto Err = C.startSendMessage(ResponseId, SeqNo)) + return Err; + + // Serialize the result. + if (auto Err = + SerializationTraits::serialize( + C, *ResultOrErr)) + return Err; + + // Close the response message. + return C.endSendMessage(); + } + + template + static Error sendResult(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Error Err) { + if (Err) + return Err; + if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) + return Err2; + return C.endSendMessage(); + } + +}; + + // Send a response of the given wire return type (WireRetT) over the // channel, with the given sequence number. template -static Error respond(ChannelT &C, const FunctionIdT &ResponseId, - SequenceNumberT SeqNo, Expected ResultOrErr) { - // If this was an error bail out. - // FIXME: Send an "error" message to the client if this is not a channel - // failure? - if (auto Err = ResultOrErr.takeError()) - return Err; - - // Open the response message. - if (auto Err = C.startSendMessage(ResponseId, SeqNo)) - return Err; - - // Serialize the result. - if (auto Err = - SerializationTraits::serialize( - C, *ResultOrErr)) - return Err; - - // Close the response message. - return C.endSendMessage(); +Error respond(ChannelT &C, const FunctionIdT &ResponseId, + SequenceNumberT SeqNo, Expected ResultOrErr) { + return RespondHelper::value>:: + template sendResult(C, ResponseId, SeqNo, std::move(ResultOrErr)); } // Send an empty response message on the given channel to indicate that @@ -394,11 +486,8 @@ template Error respond(ChannelT &C, const FunctionIdT &ResponseId, SequenceNumberT SeqNo, Error Err) { - if (Err) - return Err; - if (auto Err2 = C.startSendMessage(ResponseId, SeqNo)) - return Err2; - return C.endSendMessage(); + return RespondHelper::value>:: + sendResult(C, ResponseId, SeqNo, std::move(Err)); } // Converts a given type to the equivalent error return type. @@ -670,6 +759,72 @@ private: HandlerT Handler; }; +template +class ResponseHandlerImpl, HandlerT> + : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + using HandlerArgType = typename ResponseHandlerArg< + typename HandlerTraits::Type>::ArgType; + HandlerArgType Result((typename HandlerArgType::value_type())); + + if (auto Err = + SerializationTraits, + HandlerArgType>::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + +template +class ResponseHandlerImpl + : public ResponseHandler { +public: + ResponseHandlerImpl(HandlerT Handler) : Handler(std::move(Handler)) {} + + // Handle the result by deserializing it from the channel then passing it + // to the user defined handler. + Error handleResponse(ChannelT &C) override { + Error Result = Error::success(); + if (auto Err = + SerializationTraits::deserialize(C, Result)) + return Err; + if (auto Err = C.endReceiveMessage()) + return Err; + return Handler(std::move(Result)); + } + + // Abandon this response by calling the handler with an 'abandoned response' + // error. + void abandon() override { + if (auto Err = Handler(this->createAbandonedResponseError())) { + // Handlers should not fail when passed an abandoned response error. + report_fatal_error(std::move(Err)); + } + } + +private: + HandlerT Handler; +}; + // Create a ResponseHandler from a given user handler. template std::unique_ptr> createResponseHandler(HandlerT H) { diff --git a/include/llvm/Support/Error.h b/include/llvm/Support/Error.h index d5421b97c5f..a3482f5a58b 100644 --- a/include/llvm/Support/Error.h +++ b/include/llvm/Support/Error.h @@ -236,6 +236,14 @@ public: return getPtr() && getPtr()->isA(ErrT::classID()); } + /// Returns the dynamic class id of this error, or null if this is a success + /// value. + const void* dynamicClassID() const { + if (!getPtr()) + return nullptr; + return getPtr()->dynamicClassID(); + } + private: void assertIsChecked() { #if LLVM_ENABLE_ABI_BREAKING_CHECKS @@ -635,6 +643,7 @@ private: /// takeError(). It also adds an bool errorIsA() method for testing the /// error class type. template class LLVM_NODISCARD Expected { + template friend class ExpectedAsOutParameter; template friend class Expected; static const bool isRef = std::is_reference::value; typedef ReferenceStorage::type> wrap; @@ -743,7 +752,7 @@ public: /// \brief Check that this Expected is an error of type ErrT. template bool errorIsA() const { - return HasError && getErrorStorage()->template isA(); + return HasError && (*getErrorStorage())->template isA(); } /// \brief Take ownership of the stored error. @@ -838,6 +847,18 @@ private: return reinterpret_cast(ErrorStorage.buffer); } + const error_type *getErrorStorage() const { + assert(HasError && "Cannot get error when a value exists!"); + return reinterpret_cast(ErrorStorage.buffer); + } + + // Used by ExpectedAsOutParameter to reset the checked flag. + void setUnchecked() { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + Unchecked = true; +#endif + } + void assertIsChecked() { #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (Unchecked) { @@ -864,6 +885,28 @@ private: #endif }; +/// Helper for Expecteds used as out-parameters. +/// +/// See ErrorAsOutParameter. +template +class ExpectedAsOutParameter { +public: + + ExpectedAsOutParameter(Expected *ValOrErr) + : ValOrErr(ValOrErr) { + if (ValOrErr) + (void)!!*ValOrErr; + } + + ~ExpectedAsOutParameter() { + if (ValOrErr) + ValOrErr->setUnchecked(); + } + +private: + Expected *ValOrErr; +}; + /// This class wraps a std::error_code in a Error. /// /// This is useful if you're writing an interface that returns a Error diff --git a/lib/ExecutionEngine/Orc/OrcError.cpp b/lib/ExecutionEngine/Orc/OrcError.cpp index c1f228c98cb..9e70c4ac1db 100644 --- a/lib/ExecutionEngine/Orc/OrcError.cpp +++ b/lib/ExecutionEngine/Orc/OrcError.cpp @@ -49,6 +49,9 @@ public: return "Unexpected RPC call"; case OrcErrorCode::UnexpectedRPCResponse: return "Unexpected RPC response"; + case OrcErrorCode::UnknownErrorCodeFromRemote: + return "Unknown error returned from remote RPC function " + "(Use StringError to get error message)"; } llvm_unreachable("Unhandled error code"); } diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index a84610f5eb4..095bf25291b 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -47,6 +47,54 @@ namespace rpc { class RPCBar {}; +class DummyError : public ErrorInfo { +public: + + static char ID; + + DummyError(uint32_t Val) : Val(Val) {} + + std::error_code convertToErrorCode() const override { + // Use a nonsense error code - we want to verify that errors + // transmitted over the network are replaced with + // OrcErrorCode::UnknownErrorCodeFromRemote. + return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + } + + void log(raw_ostream &OS) const override { + OS << "Dummy error " << Val; + } + + uint32_t getValue() const { return Val; } + +public: + uint32_t Val; +}; + +char DummyError::ID = 0; + +template +void registerDummyErrorSerialization() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + SerializationTraits:: + template registerErrorType( + "DummyError", + [](ChannelT &C, const DummyError &DE) { + return serializeSeq(C, DE.getValue()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + uint32_t Val; + if (auto Err = deserializeSeq(C, Val)) + return Err; + Err = make_error(Val); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + namespace llvm { namespace orc { namespace rpc { @@ -98,6 +146,16 @@ namespace DummyRPCAPI { static const char* getName() { return "CustomType"; } }; + class ErrorFunc : public Function { + public: + static const char* getName() { return "ErrorFunc"; } + }; + + class ExpectedFunc : public Function()> { + public: + static const char* getName() { return "ExpectedFunc"; } + }; + } class DummyRPCEndpoint : public SingleThreadedRPCEndpoint { @@ -493,6 +551,140 @@ TEST(DummyRPC, TestWithAltCustomType) { ServerThread.join(); } +TEST(DummyRPC, ReturnErrorSuccess) { + registerDummyErrorSerialization(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler( + []() { + return Error::success(); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync( + [&](Error Err) { + EXPECT_FALSE(!!Err) << "Expected success value"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, ReturnErrorFailure) { + registerDummyErrorSerialization(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler( + []() { + return make_error(42); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync( + [&](Error Err) { + EXPECT_TRUE(Err.isA()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 42ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +} + +TEST(DummyRPC, RPCExpectedSuccess) { + registerDummyErrorSerialization(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler( + []() -> uint32_t { + return 42; + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync( + [&](Expected ValOrErr) { + EXPECT_TRUE(!!ValOrErr) + << "Expected success value"; + EXPECT_EQ(*ValOrErr, 42ULL) + << "Incorrect Expected deserialization"; + return Error::success(); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +}; + +TEST(DummyRPC, RPCExpectedFailure) { + registerDummyErrorSerialization(); + + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); + + std::thread ServerThread([&]() { + Server.addHandler( + []() -> Expected { + return make_error(7); + }); + + // Handle the negotiate plus one call. + for (unsigned I = 0; I != 2; ++I) + cantFail(Server.handleOne()); + }); + + cantFail(Client.callAsync( + [&](Expected ValOrErr) { + EXPECT_FALSE(!!ValOrErr) + << "Expected failure value"; + auto Err = ValOrErr.takeError(); + EXPECT_TRUE(Err.isA()) + << "Incorrect error type"; + return handleErrors( + std::move(Err), + [](const DummyError &DE) { + EXPECT_EQ(DE.getValue(), 7ULL) + << "Incorrect DummyError serialization"; + }); + })); + + cantFail(Client.handleOne()); + + ServerThread.join(); +}; + TEST(DummyRPC, TestParallelCallGroup) { auto Channels = createPairedQueueChannels(); DummyRPCEndpoint Client(*Channels.first);