]> granicus.if.org Git - python/commitdiff
asyncio: ensure_future() now understands awaitables
authorYury Selivanov <yselivanov@sprymix.com>
Fri, 2 Oct 2015 19:00:19 +0000 (15:00 -0400)
committerYury Selivanov <yselivanov@sprymix.com>
Fri, 2 Oct 2015 19:00:19 +0000 (15:00 -0400)
Lib/asyncio/tasks.py
Lib/test/test_asyncio/test_tasks.py

index a235e742e234d5fcd4e449a33c7e9095499a820e..434f498e47026fa0f0cdd6f5e53386da08b4793f 100644 (file)
@@ -512,7 +512,7 @@ def async(coro_or_future, *, loop=None):
 
 
 def ensure_future(coro_or_future, *, loop=None):
-    """Wrap a coroutine in a future.
+    """Wrap a coroutine or an awaitable in a future.
 
     If the argument is a Future, it is returned directly.
     """
@@ -527,8 +527,20 @@ def ensure_future(coro_or_future, *, loop=None):
         if task._source_traceback:
             del task._source_traceback[-1]
         return task
+    elif compat.PY35 and inspect.isawaitable(coro_or_future):
+        return ensure_future(_wrap_awaitable(coro_or_future), loop=loop)
     else:
-        raise TypeError('A Future or coroutine is required')
+        raise TypeError('A Future, a coroutine or an awaitable is required')
+
+
+@coroutine
+def _wrap_awaitable(awaitable):
+    """Helper for asyncio.ensure_future().
+
+    Wraps awaitable (an object with __await__) into a coroutine
+    that will later be wrapped in a Task by ensure_future().
+    """
+    return (yield from awaitable.__await__())
 
 
 class _GatheringFuture(futures.Future):
index 04267873103a0bdea1be26abdeb64c7a6d0c8c71..16d3d9da129b25bdf4d4e833c850a3f15541efd5 100644 (file)
@@ -153,6 +153,24 @@ class TaskTests(test_utils.TestCase):
         t = asyncio.ensure_future(t_orig, loop=self.loop)
         self.assertIs(t, t_orig)
 
+    @unittest.skipUnless(PY35, 'need python 3.5 or later')
+    def test_ensure_future_awaitable(self):
+        class Aw:
+            def __init__(self, coro):
+                self.coro = coro
+            def __await__(self):
+                return (yield from self.coro)
+
+        @asyncio.coroutine
+        def coro():
+            return 'ok'
+
+        loop = asyncio.new_event_loop()
+        self.set_event_loop(loop)
+        fut = asyncio.ensure_future(Aw(coro()), loop=loop)
+        loop.run_until_complete(fut)
+        assert fut.result() == 'ok'
+
     def test_ensure_future_neither(self):
         with self.assertRaises(TypeError):
             asyncio.ensure_future('ok')