From: Lang Hames Date: Thu, 6 Apr 2017 01:49:21 +0000 (+0000) Subject: [Orc] Break QueueChannel out into its own header and add a utility, X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=456e01a064dec93abd54cee23ef0ca034fd313c6;p=llvm [Orc] Break QueueChannel out into its own header and add a utility, createPairedQueueChannels, to simplify channel creation in the RPC unit tests. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@299611 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/unittests/ExecutionEngine/Orc/CMakeLists.txt b/unittests/ExecutionEngine/Orc/CMakeLists.txt index 009210fb442..db40c4213bd 100644 --- a/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -14,11 +14,12 @@ add_llvm_unittest(OrcJITTests IndirectionUtilsTest.cpp GlobalMappingLayerTest.cpp LazyEmittingLayerTest.cpp - RTDyldObjectLinkingLayerTest.cpp ObjectTransformLayerTest.cpp OrcCAPITest.cpp OrcTestCommon.cpp + QueueChannel.cpp RPCUtilsTest.cpp + RTDyldObjectLinkingLayerTest.cpp ) target_link_libraries(OrcJITTests ${LLVM_PTHREAD_LIB}) diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.cpp b/unittests/ExecutionEngine/Orc/QueueChannel.cpp new file mode 100644 index 00000000000..e309a7e428c --- /dev/null +++ b/unittests/ExecutionEngine/Orc/QueueChannel.cpp @@ -0,0 +1,14 @@ +//===-------- QueueChannel.cpp - Unit tests the remote executors ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "QueueChannel.h" + +char llvm::QueueChannelError::ID; +char llvm::QueueChannelClosedError::ID; + diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.h b/unittests/ExecutionEngine/Orc/QueueChannel.h new file mode 100644 index 00000000000..1528ac94a94 --- /dev/null +++ b/unittests/ExecutionEngine/Orc/QueueChannel.h @@ -0,0 +1,145 @@ +//===----------------------- Queue.h - RPC Queue ------------------*-c++-*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H +#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H + +#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" +#include "llvm/Support/Error.h" + +#include + +namespace llvm { + +class QueueChannelError : public ErrorInfo { +public: + static char ID; +}; + +class QueueChannelClosedError + : public ErrorInfo { +public: + static char ID; + std::error_code convertToErrorCode() const override { + return inconvertibleErrorCode(); + } + + void log(raw_ostream &OS) const override { + OS << "Queue closed"; + } +}; + +class Queue : public std::queue { +public: + using ErrorInjector = std::function; + + Queue() + : ReadError([]() { return Error::success(); }), + WriteError([]() { return Error::success(); }) {} + + Queue(const Queue&) = delete; + Queue& operator=(const Queue&) = delete; + Queue(Queue&&) = delete; + Queue& operator=(Queue&&) = delete; + + std::mutex &getMutex() { return M; } + std::condition_variable &getCondVar() { return CV; } + Error checkReadError() { return ReadError(); } + Error checkWriteError() { return WriteError(); } + void setReadError(ErrorInjector NewReadError) { + { + std::lock_guard Lock(M); + ReadError = std::move(NewReadError); + } + CV.notify_one(); + } + void setWriteError(ErrorInjector NewWriteError) { + std::lock_guard Lock(M); + WriteError = std::move(NewWriteError); + } +private: + std::mutex M; + std::condition_variable CV; + std::function ReadError, WriteError; +}; + +class QueueChannel : public orc::rpc::RawByteChannel { +public: + QueueChannel(std::shared_ptr InQueue, + std::shared_ptr OutQueue) + : InQueue(InQueue), OutQueue(OutQueue) {} + + QueueChannel(const QueueChannel&) = delete; + QueueChannel& operator=(const QueueChannel&) = delete; + QueueChannel(QueueChannel&&) = delete; + QueueChannel& operator=(QueueChannel&&) = delete; + + Error readBytes(char *Dst, unsigned Size) override { + std::unique_lock Lock(InQueue->getMutex()); + while (Size) { + { + Error Err = InQueue->checkReadError(); + while (!Err && InQueue->empty()) { + InQueue->getCondVar().wait(Lock); + Err = InQueue->checkReadError(); + } + if (Err) + return Err; + } + *Dst++ = InQueue->front(); + --Size; + ++NumRead; + InQueue->pop(); + } + return Error::success(); + } + + Error appendBytes(const char *Src, unsigned Size) override { + std::unique_lock Lock(OutQueue->getMutex()); + while (Size--) { + if (Error Err = OutQueue->checkWriteError()) + return Err; + OutQueue->push(*Src++); + ++NumWritten; + } + OutQueue->getCondVar().notify_one(); + return Error::success(); + } + + Error send() override { return Error::success(); } + + void close() { + auto ChannelClosed = []() { return make_error(); }; + InQueue->setReadError(ChannelClosed); + InQueue->setWriteError(ChannelClosed); + OutQueue->setReadError(ChannelClosed); + OutQueue->setWriteError(ChannelClosed); + } + + uint64_t NumWritten = 0; + uint64_t NumRead = 0; + +private: + + std::shared_ptr InQueue; + std::shared_ptr OutQueue; +}; + +inline std::pair, std::unique_ptr> +createPairedQueueChannels() { + auto Q1 = std::make_shared(); + auto Q2 = std::make_shared(); + auto C1 = llvm::make_unique(Q1, Q2); + auto C2 = llvm::make_unique(Q2, Q1); + return std::make_pair(std::move(C1), std::move(C2)); +} + +} + +#endif diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp index 355d20b4f78..3d46ef88f7c 100644 --- a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ExecutionEngine/Orc/RawByteChannel.h" #include "llvm/ExecutionEngine/Orc/RPCUtils.h" +#include "QueueChannel.h" #include "gtest/gtest.h" #include @@ -17,47 +17,6 @@ using namespace llvm; using namespace llvm::orc; using namespace llvm::orc::rpc; -class Queue : public std::queue { -public: - std::mutex &getMutex() { return M; } - std::condition_variable &getCondVar() { return CV; } -private: - std::mutex M; - std::condition_variable CV; -}; - -class QueueChannel : public RawByteChannel { -public: - QueueChannel(Queue &InQueue, Queue &OutQueue) - : InQueue(InQueue), OutQueue(OutQueue) {} - - Error readBytes(char *Dst, unsigned Size) override { - std::unique_lock Lock(InQueue.getMutex()); - while (Size) { - while (InQueue.empty()) - InQueue.getCondVar().wait(Lock); - *Dst++ = InQueue.front(); - --Size; - InQueue.pop(); - } - return Error::success(); - } - - Error appendBytes(const char *Src, unsigned Size) override { - std::unique_lock Lock(OutQueue.getMutex()); - while (Size--) - OutQueue.push(*Src++); - OutQueue.getCondVar().notify_one(); - return Error::success(); - } - - Error send() override { return Error::success(); } - -private: - Queue &InQueue; - Queue &OutQueue; -}; - class RPCFoo {}; namespace llvm { @@ -143,10 +102,8 @@ namespace DummyRPCAPI { class DummyRPCEndpoint : public SingleThreadedRPCEndpoint { public: - DummyRPCEndpoint(Queue &Q1, Queue &Q2) - : SingleThreadedRPCEndpoint(C, true), C(Q1, Q2) {} -private: - QueueChannel C; + DummyRPCEndpoint(QueueChannel &C) + : SingleThreadedRPCEndpoint(C, true) {} }; @@ -154,15 +111,15 @@ void freeVoidBool(bool B) { } TEST(DummyRPC, TestFreeFunctionHandler) { - Queue Q1, Q2; - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.first); Server.addHandler(freeVoidBool); } TEST(DummyRPC, TestCallAsyncVoidBool) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler( @@ -204,9 +161,9 @@ TEST(DummyRPC, TestCallAsyncVoidBool) { } TEST(DummyRPC, TestCallAsyncIntInt) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler( @@ -249,9 +206,9 @@ TEST(DummyRPC, TestCallAsyncIntInt) { } TEST(DummyRPC, TestAsyncIntIntHandler) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addAsyncHandler( @@ -295,9 +252,9 @@ TEST(DummyRPC, TestAsyncIntIntHandler) { } TEST(DummyRPC, TestAsyncIntIntHandlerMethod) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); class Dummy { public: @@ -346,9 +303,9 @@ TEST(DummyRPC, TestAsyncIntIntHandlerMethod) { } TEST(DummyRPC, TestCallAsyncVoidString) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler( @@ -386,9 +343,9 @@ TEST(DummyRPC, TestCallAsyncVoidString) { } TEST(DummyRPC, TestSerialization) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler( @@ -451,9 +408,9 @@ TEST(DummyRPC, TestSerialization) { } TEST(DummyRPC, TestCustomType) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler( @@ -494,9 +451,9 @@ TEST(DummyRPC, TestCustomType) { } TEST(DummyRPC, TestWithAltCustomType) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler( @@ -537,9 +494,9 @@ TEST(DummyRPC, TestWithAltCustomType) { } TEST(DummyRPC, TestParallelCallGroup) { - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread([&]() { Server.addHandler( @@ -619,9 +576,9 @@ TEST(DummyRPC, TestAPICalls) { static_assert(!DummyCalls1::Contains::value, "Contains template should return false here"); - Queue Q1, Q2; - DummyRPCEndpoint Client(Q1, Q2); - DummyRPCEndpoint Server(Q2, Q1); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Client(*Channels.first); + DummyRPCEndpoint Server(*Channels.second); std::thread ServerThread( [&]() { @@ -657,8 +614,8 @@ TEST(DummyRPC, TestAPICalls) { } TEST(DummyRPC, TestRemoveHandler) { - Queue Q1, Q2; - DummyRPCEndpoint Server(Q1, Q2); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.second); Server.addHandler( [](bool B) { @@ -670,8 +627,8 @@ TEST(DummyRPC, TestRemoveHandler) { } TEST(DummyRPC, TestClearHandlers) { - Queue Q1, Q2; - DummyRPCEndpoint Server(Q1, Q2); + auto Channels = createPairedQueueChannels(); + DummyRPCEndpoint Server(*Channels.second); Server.addHandler( [](bool B) {