]> granicus.if.org Git - llvm/commitdiff
[Orc] Break QueueChannel out into its own header and add a utility,
authorLang Hames <lhames@gmail.com>
Thu, 6 Apr 2017 01:49:21 +0000 (01:49 +0000)
committerLang Hames <lhames@gmail.com>
Thu, 6 Apr 2017 01:49:21 +0000 (01:49 +0000)
createPairedQueueChannels, to simplify channel creation in the RPC unit tests.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@299611 91177308-0d34-0410-b5e6-96231b3b80d8

unittests/ExecutionEngine/Orc/CMakeLists.txt
unittests/ExecutionEngine/Orc/QueueChannel.cpp [new file with mode: 0644]
unittests/ExecutionEngine/Orc/QueueChannel.h [new file with mode: 0644]
unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

index 009210fb442ad31834a6d1216fd93cc76d0c5808..db40c4213bd70c51e3a3e7f70161c5557fd8934c 100644 (file)
@@ -14,11 +14,12 @@ add_llvm_unittest(OrcJITTests
   IndirectionUtilsTest.cpp
   GlobalMappingLayerTest.cpp
   LazyEmittingLayerTest.cpp
-  RTDyldObjectLinkingLayerTest.cpp
   ObjectTransformLayerTest.cpp
   OrcCAPITest.cpp
   OrcTestCommon.cpp
+  QueueChannel.cpp
   RPCUtilsTest.cpp
+  RTDyldObjectLinkingLayerTest.cpp
   )
 
 target_link_libraries(OrcJITTests ${LLVM_PTHREAD_LIB})
diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.cpp b/unittests/ExecutionEngine/Orc/QueueChannel.cpp
new file mode 100644 (file)
index 0000000..e309a7e
--- /dev/null
@@ -0,0 +1,14 @@
+//===-------- QueueChannel.cpp - Unit tests the remote executors ----------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "QueueChannel.h"
+
+char llvm::QueueChannelError::ID;
+char llvm::QueueChannelClosedError::ID;
+
diff --git a/unittests/ExecutionEngine/Orc/QueueChannel.h b/unittests/ExecutionEngine/Orc/QueueChannel.h
new file mode 100644 (file)
index 0000000..1528ac9
--- /dev/null
@@ -0,0 +1,145 @@
+//===----------------------- Queue.h - RPC Queue ------------------*-c++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
+#define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
+
+#include "llvm/ExecutionEngine/Orc/RawByteChannel.h"
+#include "llvm/Support/Error.h"
+
+#include <queue>
+
+namespace llvm {
+
+class QueueChannelError : public ErrorInfo<QueueChannelError> {
+public:
+  static char ID;
+};
+
+class QueueChannelClosedError
+    : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
+public:
+  static char ID;
+  std::error_code convertToErrorCode() const override {
+    return inconvertibleErrorCode();
+  }
+
+  void log(raw_ostream &OS) const override {
+    OS << "Queue closed";
+  }
+};
+
+class Queue : public std::queue<char> {
+public:
+  using ErrorInjector = std::function<Error()>;
+
+  Queue()
+    : ReadError([]() { return Error::success(); }),
+      WriteError([]() { return Error::success(); }) {}
+
+  Queue(const Queue&) = delete;
+  Queue& operator=(const Queue&) = delete;
+  Queue(Queue&&) = delete;
+  Queue& operator=(Queue&&) = delete;
+
+  std::mutex &getMutex() { return M; }
+  std::condition_variable &getCondVar() { return CV; }
+  Error checkReadError() { return ReadError(); }
+  Error checkWriteError() { return WriteError(); }
+  void setReadError(ErrorInjector NewReadError) {
+    {
+      std::lock_guard<std::mutex> Lock(M);
+      ReadError = std::move(NewReadError);
+    }
+    CV.notify_one();
+  }
+  void setWriteError(ErrorInjector NewWriteError) {
+    std::lock_guard<std::mutex> Lock(M);
+    WriteError = std::move(NewWriteError);
+  }
+private:
+  std::mutex M;
+  std::condition_variable CV;
+  std::function<Error()> ReadError, WriteError;
+};
+
+class QueueChannel : public orc::rpc::RawByteChannel {
+public:
+  QueueChannel(std::shared_ptr<Queue> InQueue,
+               std::shared_ptr<Queue> OutQueue)
+      : InQueue(InQueue), OutQueue(OutQueue) {}
+
+  QueueChannel(const QueueChannel&) = delete;
+  QueueChannel& operator=(const QueueChannel&) = delete;
+  QueueChannel(QueueChannel&&) = delete;
+  QueueChannel& operator=(QueueChannel&&) = delete;
+
+  Error readBytes(char *Dst, unsigned Size) override {
+    std::unique_lock<std::mutex> Lock(InQueue->getMutex());
+    while (Size) {
+      {
+        Error Err = InQueue->checkReadError();
+        while (!Err && InQueue->empty()) {
+          InQueue->getCondVar().wait(Lock);
+          Err = InQueue->checkReadError();
+        }
+        if (Err)
+          return Err;
+      }
+      *Dst++ = InQueue->front();
+      --Size;
+      ++NumRead;
+      InQueue->pop();
+    }
+    return Error::success();
+  }
+
+  Error appendBytes(const char *Src, unsigned Size) override {
+    std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
+    while (Size--) {
+      if (Error Err = OutQueue->checkWriteError())
+        return Err;
+      OutQueue->push(*Src++);
+      ++NumWritten;
+    }
+    OutQueue->getCondVar().notify_one();
+    return Error::success();
+  }
+
+  Error send() override { return Error::success(); }
+
+  void close() {
+    auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
+    InQueue->setReadError(ChannelClosed);
+    InQueue->setWriteError(ChannelClosed);
+    OutQueue->setReadError(ChannelClosed);
+    OutQueue->setWriteError(ChannelClosed);
+  }
+
+  uint64_t NumWritten = 0;
+  uint64_t NumRead = 0;
+
+private:
+
+  std::shared_ptr<Queue> InQueue;
+  std::shared_ptr<Queue> OutQueue;
+};
+
+inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
+createPairedQueueChannels() {
+  auto Q1 = std::make_shared<Queue>();
+  auto Q2 = std::make_shared<Queue>();
+  auto C1 = llvm::make_unique<QueueChannel>(Q1, Q2);
+  auto C2 = llvm::make_unique<QueueChannel>(Q2, Q1);
+  return std::make_pair(std::move(C1), std::move(C2));
+}
+
+}
+
+#endif
index 355d20b4f784e85268618dd0649b215f8790a86d..3d46ef88f7c5173bf8bbb21e7d1201259b00e464 100644 (file)
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/ExecutionEngine/Orc/RawByteChannel.h"
 #include "llvm/ExecutionEngine/Orc/RPCUtils.h"
+#include "QueueChannel.h"
 #include "gtest/gtest.h"
 
 #include <queue>
@@ -17,47 +17,6 @@ using namespace llvm;
 using namespace llvm::orc;
 using namespace llvm::orc::rpc;
 
-class Queue : public std::queue<char> {
-public:
-  std::mutex &getMutex() { return M; }
-  std::condition_variable &getCondVar() { return CV; }
-private:
-  std::mutex M;
-  std::condition_variable CV;
-};
-
-class QueueChannel : public RawByteChannel {
-public:
-  QueueChannel(Queue &InQueue, Queue &OutQueue)
-      : InQueue(InQueue), OutQueue(OutQueue) {}
-
-  Error readBytes(char *Dst, unsigned Size) override {
-    std::unique_lock<std::mutex> Lock(InQueue.getMutex());
-    while (Size) {
-      while (InQueue.empty())
-        InQueue.getCondVar().wait(Lock);
-      *Dst++ = InQueue.front();
-      --Size;
-      InQueue.pop();
-    }
-    return Error::success();
-  }
-
-  Error appendBytes(const char *Src, unsigned Size) override {
-    std::unique_lock<std::mutex> Lock(OutQueue.getMutex());
-    while (Size--)
-      OutQueue.push(*Src++);
-    OutQueue.getCondVar().notify_one();
-    return Error::success();
-  }
-
-  Error send() override { return Error::success(); }
-
-private:
-  Queue &InQueue;
-  Queue &OutQueue;
-};
-
 class RPCFoo {};
 
 namespace llvm {
@@ -143,10 +102,8 @@ namespace DummyRPCAPI {
 
 class DummyRPCEndpoint : public SingleThreadedRPCEndpoint<QueueChannel> {
 public:
-  DummyRPCEndpoint(Queue &Q1, Queue &Q2)
-      : SingleThreadedRPCEndpoint(C, true), C(Q1, Q2) {}
-private:
-  QueueChannel C;
+  DummyRPCEndpoint(QueueChannel &C)
+      : SingleThreadedRPCEndpoint(C, true) {}
 };
 
 
@@ -154,15 +111,15 @@ void freeVoidBool(bool B) {
 }
 
 TEST(DummyRPC, TestFreeFunctionHandler) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Server(*Channels.first);
   Server.addHandler<DummyRPCAPI::VoidBool>(freeVoidBool);
 }
 
 TEST(DummyRPC, TestCallAsyncVoidBool) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addHandler<DummyRPCAPI::VoidBool>(
@@ -204,9 +161,9 @@ TEST(DummyRPC, TestCallAsyncVoidBool) {
 }
 
 TEST(DummyRPC, TestCallAsyncIntInt) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addHandler<DummyRPCAPI::IntInt>(
@@ -249,9 +206,9 @@ TEST(DummyRPC, TestCallAsyncIntInt) {
 }
 
 TEST(DummyRPC, TestAsyncIntIntHandler) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addAsyncHandler<DummyRPCAPI::IntInt>(
@@ -295,9 +252,9 @@ TEST(DummyRPC, TestAsyncIntIntHandler) {
 }
 
 TEST(DummyRPC, TestAsyncIntIntHandlerMethod) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   class Dummy {
   public:
@@ -346,9 +303,9 @@ TEST(DummyRPC, TestAsyncIntIntHandlerMethod) {
 }
 
 TEST(DummyRPC, TestCallAsyncVoidString) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addHandler<DummyRPCAPI::VoidString>(
@@ -386,9 +343,9 @@ TEST(DummyRPC, TestCallAsyncVoidString) {
 }
 
 TEST(DummyRPC, TestSerialization) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addHandler<DummyRPCAPI::AllTheTypes>(
@@ -451,9 +408,9 @@ TEST(DummyRPC, TestSerialization) {
 }
 
 TEST(DummyRPC, TestCustomType) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addHandler<DummyRPCAPI::CustomType>(
@@ -494,9 +451,9 @@ TEST(DummyRPC, TestCustomType) {
 }
 
 TEST(DummyRPC, TestWithAltCustomType) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addHandler<DummyRPCAPI::CustomType>(
@@ -537,9 +494,9 @@ TEST(DummyRPC, TestWithAltCustomType) {
 }
 
 TEST(DummyRPC, TestParallelCallGroup) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread([&]() {
       Server.addHandler<DummyRPCAPI::IntInt>(
@@ -619,9 +576,9 @@ TEST(DummyRPC, TestAPICalls) {
   static_assert(!DummyCalls1::Contains<DummyRPCAPI::CustomType>::value,
                 "Contains<Func> template should return false here");
 
-  Queue Q1, Q2;
-  DummyRPCEndpoint Client(Q1, Q2);
-  DummyRPCEndpoint Server(Q2, Q1);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Client(*Channels.first);
+  DummyRPCEndpoint Server(*Channels.second);
 
   std::thread ServerThread(
     [&]() {
@@ -657,8 +614,8 @@ TEST(DummyRPC, TestAPICalls) {
 }
 
 TEST(DummyRPC, TestRemoveHandler) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Server(Q1, Q2);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Server(*Channels.second);
 
   Server.addHandler<DummyRPCAPI::VoidBool>(
     [](bool B) {
@@ -670,8 +627,8 @@ TEST(DummyRPC, TestRemoveHandler) {
 }
 
 TEST(DummyRPC, TestClearHandlers) {
-  Queue Q1, Q2;
-  DummyRPCEndpoint Server(Q1, Q2);
+  auto Channels = createPairedQueueChannels();
+  DummyRPCEndpoint Server(*Channels.second);
 
   Server.addHandler<DummyRPCAPI::VoidBool>(
     [](bool B) {