From: Lang Hames Date: Mon, 30 Jul 2018 21:08:06 +0000 (+0000) Subject: [ORC] Add SerializationTraits for std::set and std::map. X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=865d24fcb4446db39e61ff1c9bd818cf64e2c938;p=llvm [ORC] Add SerializationTraits for std::set and std::map. Also, make SerializationTraits for pairs forward the actual pair template type arguments to the underlying serializer. This allows, for example, std::pair to be passed as an argument to an RPC call expecting a std::pair, since there is an underlying serializer from StringRef to std::string that can be used. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@338305 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h index 569c50602f3..1e5f6ced597 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCSerialization.h +++ b/include/llvm/ExecutionEngine/Orc/RPCSerialization.h @@ -14,7 +14,10 @@ #include "llvm/Support/thread.h" #include #include +#include #include +#include +#include namespace llvm { namespace orc { @@ -205,6 +208,42 @@ std::mutex RPCTypeName>::NameMutex; template std::string RPCTypeName>::Name; +template class RPCTypeName> { +public: + static const char *getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) + << "std::set<" << RPCTypeName::getName() << ">"; + return Name.data(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template std::mutex RPCTypeName>::NameMutex; +template std::string RPCTypeName>::Name; + +template class RPCTypeName> { +public: + static const char *getName() { + std::lock_guard Lock(NameMutex); + if (Name.empty()) + raw_string_ostream(Name) + << "std::map<" << RPCTypeNameSequence() << ">"; + return Name.data(); + } + +private: + static std::mutex NameMutex; + static std::string Name; +}; + +template +std::mutex RPCTypeName>::NameMutex; +template std::string RPCTypeName>::Name; /// The SerializationTraits class describes how to serialize and /// deserialize an instance of type T to/from an abstract channel of type @@ -527,15 +566,20 @@ public: }; /// SerializationTraits default specialization for std::pair. -template -class SerializationTraits> { +template +class SerializationTraits, std::pair> { public: - static Error serialize(ChannelT &C, const std::pair &V) { - return serializeSeq(C, V.first, V.second); + static Error serialize(ChannelT &C, const std::pair &V) { + if (auto Err = SerializationTraits::serialize(C, V.first)) + return Err; + return SerializationTraits::serialize(C, V.second); } - static Error deserialize(ChannelT &C, std::pair &V) { - return deserializeSeq(C, V.first, V.second); + static Error deserialize(ChannelT &C, std::pair &V) { + if (auto Err = + SerializationTraits::deserialize(C, V.first)) + return Err; + return SerializationTraits::deserialize(C, V.second); } }; @@ -589,6 +633,9 @@ public: /// Deserialize a std::vector to a std::vector. static Error deserialize(ChannelT &C, std::vector &V) { + assert(V.empty() && + "Expected default-constructed vector to deserialize into"); + uint64_t Count = 0; if (auto Err = deserializeSeq(C, Count)) return Err; @@ -602,6 +649,92 @@ public: } }; +template +class SerializationTraits, std::set> { +public: + /// Serialize a std::set from std::set. + static Error serialize(ChannelT &C, const std::set &S) { + if (auto Err = serializeSeq(C, static_cast(S.size()))) + return Err; + + for (const auto &E : S) + if (auto Err = SerializationTraits::serialize(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::set to a std::set. + static Error deserialize(ChannelT &C, std::set &S) { + assert(S.empty() && "Expected default-constructed set to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + T2 Val; + if (auto Err = SerializationTraits::deserialize(C, Val)) + return Err; + + auto Added = S.insert(Val).second; + if (!Added) + return make_error("Duplicate element in deserialized set", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template +class SerializationTraits, std::map> { +public: + /// Serialize a std::map from std::map. + static Error serialize(ChannelT &C, const std::map &M) { + if (auto Err = serializeSeq(C, static_cast(M.size()))) + return Err; + + for (const auto &E : M) { + if (auto Err = + SerializationTraits::serialize(C, E.first)) + return Err; + if (auto Err = + SerializationTraits::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Deserialize a std::map to a std::map. + static Error deserialize(ChannelT &C, std::map &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair Val; + if (auto Err = + SerializationTraits::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + } // end namespace rpc } // end namespace orc } // end namespace llvm diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 7fe449b7016..c884aaa718a 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -133,10 +133,10 @@ namespace DummyRPCAPI { }; class AllTheTypes - : public Function)> { + : public Function, + std::set, std::map)> { public: static const char* getName() { return "AllTheTypes"; } }; @@ -451,43 +451,50 @@ TEST(DummyRPC, TestSerialization) { DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { - Server.addHandler( - [&](int8_t S8, uint8_t U8, int16_t S16, uint16_t U16, - int32_t S32, uint32_t U32, int64_t S64, uint64_t U64, - bool B, std::string S, std::vector V) { - - EXPECT_EQ(S8, -101) << "int8_t serialization broken"; - EXPECT_EQ(U8, 250) << "uint8_t serialization broken"; - EXPECT_EQ(S16, -10000) << "int16_t serialization broken"; - EXPECT_EQ(U16, 10000) << "uint16_t serialization broken"; - EXPECT_EQ(S32, -1000000000) << "int32_t serialization broken"; - EXPECT_EQ(U32, 1000000000ULL) << "uint32_t serialization broken"; - EXPECT_EQ(S64, -10000000000) << "int64_t serialization broken"; - EXPECT_EQ(U64, 10000000000ULL) << "uint64_t serialization broken"; - EXPECT_EQ(B, true) << "bool serialization broken"; - EXPECT_EQ(S, "foo") << "std::string serialization broken"; - EXPECT_EQ(V, std::vector({42, 7})) - << "std::vector serialization broken"; - return Error::success(); - }); - - { - // Poke the server to handle the negotiate call. - auto Err = Server.handleOne(); - EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; - } - - { - // Poke the server to handle the AllTheTypes call. - auto Err = Server.handleOne(); - EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; - } + Server.addHandler([&](int8_t S8, uint8_t U8, + int16_t S16, uint16_t U16, + int32_t S32, uint32_t U32, + int64_t S64, uint64_t U64, + bool B, std::string S, + std::vector V, + std::set S2, + std::map M) { + EXPECT_EQ(S8, -101) << "int8_t serialization broken"; + EXPECT_EQ(U8, 250) << "uint8_t serialization broken"; + EXPECT_EQ(S16, -10000) << "int16_t serialization broken"; + EXPECT_EQ(U16, 10000) << "uint16_t serialization broken"; + EXPECT_EQ(S32, -1000000000) << "int32_t serialization broken"; + EXPECT_EQ(U32, 1000000000ULL) << "uint32_t serialization broken"; + EXPECT_EQ(S64, -10000000000) << "int64_t serialization broken"; + EXPECT_EQ(U64, 10000000000ULL) << "uint64_t serialization broken"; + EXPECT_EQ(B, true) << "bool serialization broken"; + EXPECT_EQ(S, "foo") << "std::string serialization broken"; + EXPECT_EQ(V, std::vector({42, 7})) + << "std::vector serialization broken"; + EXPECT_EQ(S2, std::set({7, 42})) << "std::set serialization broken"; + EXPECT_EQ(M, (std::map({{7, false}, {42, true}}))) + << "std::map serialization broken"; + return Error::success(); }); + { + // Poke the server to handle the negotiate call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to negotiate"; + } + + { + // Poke the server to handle the AllTheTypes call. + auto Err = Server.handleOne(); + EXPECT_FALSE(!!Err) << "Server failed to handle call to void(bool)"; + } + }); { // Make an async call. - std::vector v({42, 7}); + std::vector V({42, 7}); + std::set S({7, 42}); + std::map M({{7, false}, {42, true}}); auto Err = Client.callAsync( [](Error Err) { EXPECT_FALSE(!!Err) << "Async AllTheTypes response handler failed"; @@ -497,7 +504,7 @@ TEST(DummyRPC, TestSerialization) { static_cast(-10000), static_cast(10000), static_cast(-1000000000), static_cast(1000000000), static_cast(-10000000000), static_cast(10000000000), - true, std::string("foo"), v); + true, std::string("foo"), V, S, M); EXPECT_FALSE(!!Err) << "Client.callAsync failed for AllTheTypes"; }