]> granicus.if.org Git - llvm/commitdiff
[Orc][RPC] Refactor ParallelCallGroup to decouple it from RPCEndpoint.
authorLang Hames <lhames@gmail.com>
Tue, 24 Jan 2017 06:13:47 +0000 (06:13 +0000)
committerLang Hames <lhames@gmail.com>
Tue, 24 Jan 2017 06:13:47 +0000 (06:13 +0000)
This refactor allows parallel calls to be made via an arbitrary async call
dispatcher. In particular, this allows ParallelCallGroup to be used with
derived RPC classes that expose custom async RPC call operations.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@292891 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/ExecutionEngine/Orc/RPCUtils.h
unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp

index fcebc418f4caae6f8b99df1f89c4147d1ac5c9db..e739c72629854637d89ecec53a5842da05dc6abc 100644 (file)
@@ -1276,24 +1276,40 @@ public:
   }
 };
 
+/// Asynchronous dispatch for a function on an RPC endpoint.
+template <typename RPCClass, typename Func>
+class RPCAsyncDispatch {
+public:
+  RPCAsyncDispatch(RPCClass &Endpoint) : Endpoint(Endpoint) {}
+
+  template <typename HandlerT, typename... ArgTs>
+  Error operator()(HandlerT Handler, const ArgTs &... Args) const {
+    return Endpoint.template appendCallAsync<Func>(std::move(Handler), Args...);
+  }
+
+private:
+  RPCClass &Endpoint;
+};
+
+/// Construct an asynchronous dispatcher from an RPC endpoint and a Func.
+template <typename Func, typename RPCEndpointT>
+RPCAsyncDispatch<RPCEndpointT, Func> rpcAsyncDispatch(RPCEndpointT &Endpoint) {
+  return RPCAsyncDispatch<RPCEndpointT, Func>(Endpoint);
+}
+
 /// \brief Allows a set of asynchrounous calls to be dispatched, and then
 ///        waited on as a group.
-template <typename RPCClass> class ParallelCallGroup {
+class ParallelCallGroup {
 public:
 
-  /// \brief Construct a parallel call group for the given RPC.
-  ParallelCallGroup(RPCClass &RPC) : RPC(RPC), NumOutstandingCalls(0) {}
-
+  ParallelCallGroup() = default;
   ParallelCallGroup(const ParallelCallGroup &) = delete;
   ParallelCallGroup &operator=(const ParallelCallGroup &) = delete;
 
   /// \brief Make as asynchronous call.
-  ///
-  /// Does not issue a send call to the RPC's channel. The channel may use this
-  /// to batch up subsequent calls. A send will automatically be sent when wait
-  /// is called.
-  template <typename Func, typename HandlerT, typename... ArgTs>
-  Error appendCall(HandlerT Handler, const ArgTs &... Args) {
+  template <typename AsyncDispatcher, typename HandlerT, typename... ArgTs>
+  Error call(const AsyncDispatcher &AsyncDispatch, HandlerT Handler,
+             const ArgTs &... Args) {
     // Increment the count of outstanding calls. This has to happen before
     // we invoke the call, as the handler may (depending on scheduling)
     // be run immediately on another thread, and we don't want the decrement
@@ -1316,38 +1332,21 @@ public:
       return Err;
     };
 
-    return RPC.template appendCallAsync<Func>(std::move(WrappedHandler),
-                                              Args...);
-  }
-
-  /// \brief Make an asynchronous call.
-  ///
-  /// The same as appendCall, but also calls send on the channel immediately.
-  /// Prefer appendCall if you are about to issue a "wait" call shortly, as
-  /// this may allow the channel to better batch the calls.
-  template <typename Func, typename HandlerT, typename... ArgTs>
-  Error call(HandlerT Handler, const ArgTs &... Args) {
-    if (auto Err = appendCall(std::move(Handler), Args...))
-      return Err;
-    return RPC.sendAppendedCalls();
+    return AsyncDispatch(std::move(WrappedHandler), Args...);
   }
 
   /// \brief Blocks until all calls have been completed and their return value
   ///        handlers run.
-  Error wait() {
-    if (auto Err = RPC.sendAppendedCalls())
-      return Err;
+  void wait() {
     std::unique_lock<std::mutex> Lock(M);
     while (NumOutstandingCalls > 0)
       CV.wait(Lock);
-    return Error::success();
   }
 
 private:
-  RPCClass &RPC;
   std::mutex M;
   std::condition_variable CV;
-  uint32_t NumOutstandingCalls;
+  uint32_t NumOutstandingCalls = 0;
 };
 
 /// @brief Convenience class for grouping RPC Functions into APIs that can be
index 9abf401af4164e1c370a43ff181453de82494455..d21a4acc08dc5289de2a2716b4347352d0933e89 100644 (file)
@@ -405,10 +405,11 @@ TEST(DummyRPC, TestParallelCallGroup) {
 
   {
     int A, B, C;
-    ParallelCallGroup<DummyRPCEndpoint> PCG(Client);
+    ParallelCallGroup PCG;
 
     {
-      auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+      auto Err = PCG.call(
+        rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client),
         [&A](Expected<int> Result) {
           EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
           A = *Result;
@@ -418,7 +419,8 @@ TEST(DummyRPC, TestParallelCallGroup) {
     }
 
     {
-      auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+      auto Err = PCG.call(
+        rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client),
         [&B](Expected<int> Result) {
           EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
           B = *Result;
@@ -428,7 +430,8 @@ TEST(DummyRPC, TestParallelCallGroup) {
     }
 
     {
-      auto Err = PCG.appendCall<DummyRPCAPI::IntInt>(
+      auto Err = PCG.call(
+        rpcAsyncDispatch<DummyRPCAPI::IntInt>(Client),
         [&C](Expected<int> Result) {
           EXPECT_TRUE(!!Result) << "Async int(int) response handler failed";
           C = *Result;
@@ -443,10 +446,7 @@ TEST(DummyRPC, TestParallelCallGroup) {
       EXPECT_FALSE(!!Err) << "Client failed to handle response from void(bool)";
     }
 
-    {
-      auto Err = PCG.wait();
-      EXPECT_FALSE(!!Err) << "Third parallel call failed for int(int)";
-    }
+    PCG.wait();
 
     EXPECT_EQ(A, 2) << "First parallel call returned bogus result";
     EXPECT_EQ(B, 4) << "Second parallel call returned bogus result";