]> granicus.if.org Git - clang/commitdiff
[coroutines] Support coroutine-handle returning await-suspend (i.e symmetric control...
authorGor Nishanov <GorNishanov@gmail.com>
Fri, 25 Aug 2017 04:46:54 +0000 (04:46 +0000)
committerGor Nishanov <GorNishanov@gmail.com>
Fri, 25 Aug 2017 04:46:54 +0000 (04:46 +0000)
Summary:
If await_suspend returns a coroutine_handle, as in the example below:
```
  coroutine_handle<> await_suspend(coroutine_handle<> h) {
    coro.promise().waiter = h;
    return coro;
  }
```
suspensionExpression processing will resume the coroutine pointed at by that handle.
Related LLVM change rL311751 makes resume calls of this kind `musttail` at any optimization level.

This enables unlimited symmetric control transfer from coroutine to coroutine without blowing up the stack.

Reviewers: GorNishanov

Reviewed By: GorNishanov

Subscribers: rsmith, EricWF, cfe-commits

Differential Revision: https://reviews.llvm.org/D37131

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@311762 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/CGCoroutine.cpp
lib/Sema/SemaCoroutine.cpp
test/CodeGenCoroutines/coro-await.cpp

index d23697a5309b3c4cbc48d12f25cab891a24a355d..5842e7b3ff93c225e6492f1ffe36342adf438f8d 100644 (file)
@@ -181,10 +181,8 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
   auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
 
   auto *SuspendRet = CGF.EmitScalarExpr(S.getSuspendExpr());
-  if (SuspendRet != nullptr) {
+  if (SuspendRet != nullptr && SuspendRet->getType()->isIntegerTy(1)) {
     // Veto suspension if requested by bool returning await_suspend.
-    assert(SuspendRet->getType()->isIntegerTy(1) &&
-           "Sema should have already checked that it is void or bool");
     BasicBlock *RealSuspendBlock =
         CGF.createBasicBlock(Prefix + Twine(".suspend.bool"));
     CGF.Builder.CreateCondBr(SuspendRet, RealSuspendBlock, ReadyBlock);
index dc7d8e4e9cec3dc6b0e0adf6648254adb7725570..ae8744c07ba1e4cd36916eb1d719141eb98d9494 100644 (file)
@@ -363,6 +363,32 @@ static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
   return S.ActOnCallExpr(nullptr, Result.get(), Loc, Args, Loc, nullptr);
 }
 
+// See if return type is coroutine-handle and if so, invoke builtin coro-resume
+// on its address. This is to enable experimental support for coroutine-handle
+// returning await_suspend that results in a guranteed tail call to the target
+// coroutine.
+static Expr *maybeTailCall(Sema &S, QualType RetType, Expr *E,
+                           SourceLocation Loc) {
+  if (RetType->isReferenceType())
+    return nullptr;
+  Type const *T = RetType.getTypePtr();
+  if (!T->isClassType() && !T->isStructureType())
+    return nullptr;
+
+  // FIXME: Add convertability check to coroutine_handle<>. Possibly via
+  // EvaluateBinaryTypeTrait(BTT_IsConvertible, ...) which is at the moment
+  // a private function in SemaExprCXX.cpp
+
+  ExprResult AddressExpr = buildMemberCall(S, E, Loc, "address", None);
+  if (AddressExpr.isInvalid())
+    return nullptr;
+
+  Expr *JustAddress = AddressExpr.get();
+  // FIXME: Check that the type of AddressExpr is void*
+  return buildBuiltinCall(S, Loc, Builtin::BI__builtin_coro_resume,
+                          JustAddress);
+}
+
 /// Build calls to await_ready, await_suspend, and await_resume for a co_await
 /// expression.
 static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
@@ -412,16 +438,21 @@ static ReadySuspendResumeResult buildCoawaitCalls(Sema &S, VarDecl *CoroPromise,
     //   - await-suspend is the expression e.await_suspend(h), which shall be
     //     a prvalue of type void or bool.
     QualType RetType = AwaitSuspend->getCallReturnType(S.Context);
-    // non-class prvalues always have cv-unqualified types
-    QualType AdjRetType = RetType.getUnqualifiedType();
-    if (RetType->isReferenceType() ||
-        (AdjRetType != S.Context.BoolTy && AdjRetType != S.Context.VoidTy)) {
-      S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
-             diag::err_await_suspend_invalid_return_type)
-          << RetType;
-      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
-          << AwaitSuspend->getDirectCallee();
-      Calls.IsInvalid = true;
+    // Experimental support for coroutine_handle returning await_suspend.
+    if (Expr *TailCallSuspend = maybeTailCall(S, RetType, AwaitSuspend, Loc))
+      Calls.Results[ACT::ACT_Suspend] = TailCallSuspend;
+    else {
+      // non-class prvalues always have cv-unqualified types
+      QualType AdjRetType = RetType.getUnqualifiedType();
+      if (RetType->isReferenceType() ||
+          (AdjRetType != S.Context.BoolTy && AdjRetType != S.Context.VoidTy)) {
+        S.Diag(AwaitSuspend->getCalleeDecl()->getLocation(),
+               diag::err_await_suspend_invalid_return_type)
+            << RetType;
+        S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+            << AwaitSuspend->getDirectCallee();
+        Calls.IsInvalid = true;
+      }
     }
   }
 
index fc6559f1e0ad4113c14462a48855b05f753985a2..41881d7123379123b86106017b1f88666b703dac 100644 (file)
@@ -12,6 +12,7 @@ template <>
 struct coroutine_handle<void> {
   void *ptr;
   static coroutine_handle from_address(void *);
+  void *address();
 };
 
 template <typename Promise>
@@ -326,3 +327,20 @@ void AwaitReturnsLValue(double) {
   // CHECK-NEXT: store %struct.RefTag* %[[RES3]], %struct.RefTag** %[[ZVAR]],
   RefTag& z = co_yield 42;
 }
+
+struct TailCallAwait {
+  bool await_ready();
+  std::experimental::coroutine_handle<> await_suspend(std::experimental::coroutine_handle<>);
+  void await_resume();
+};
+
+// CHECK-LABEL: @TestTailcall(
+extern "C" void TestTailcall() {
+  co_await TailCallAwait{};
+
+  // CHECK: %[[RESULT:.+]] = call i8* @_ZN13TailCallAwait13await_suspendENSt12experimental16coroutine_handleIvEE(%struct.TailCallAwait*
+  // CHECK: %[[COERCE:.+]] = getelementptr inbounds %"struct.std::experimental::coroutine_handle", %"struct.std::experimental::coroutine_handle"* %[[TMP:.+]], i32 0, i32 0
+  // CHECK: store i8* %[[RESULT]], i8** %[[COERCE]]
+  // CHECK: %[[ADDR:.+]] = call i8* @_ZNSt12experimental16coroutine_handleIvE7addressEv(%"struct.std::experimental::coroutine_handle"* %[[TMP]])
+  // CHECK: call void @llvm.coro.resume(i8* %[[ADDR]])
+}