]> granicus.if.org Git - python/commitdiff
asyncio: Fix @coroutine to recognize CoroWrapper (issue #25647)
authorYury Selivanov <yselivanov@sprymix.com>
Wed, 2 Mar 2016 15:49:16 +0000 (10:49 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Wed, 2 Mar 2016 15:49:16 +0000 (10:49 -0500)
Patch by Vladimir Rutsky.

Lib/asyncio/coroutines.py
Lib/test/test_asyncio/test_tasks.py

index 27ab42a5bfaf5dde6740bb5cc6a424e6d1dc7639..71bc6fb2eab34fdaadcc6c140473540756b23aad 100644 (file)
@@ -204,7 +204,8 @@ def coroutine(func):
         @functools.wraps(func)
         def coro(*args, **kw):
             res = func(*args, **kw)
-            if isinstance(res, futures.Future) or inspect.isgenerator(res):
+            if isinstance(res, futures.Future) or inspect.isgenerator(res) or \
+                    isinstance(res, CoroWrapper):
                 res = yield from res
             elif _AwaitableABC is not None:
                 # If 'func' returns an Awaitable (new in 3.5) we
index c9d49f047c43696000bf3edb5460342cfd1fb2cc..acceb9b12aaeb13bd5a4557d1bd63398b3453c2e 100644 (file)
@@ -1794,6 +1794,30 @@ class TaskTests(test_utils.TestCase):
 
         self.assertRegex(message, re.compile(regex, re.DOTALL))
 
+    def test_return_coroutine_from_coroutine(self):
+        """Return of @asyncio.coroutine()-wrapped function generator object
+        from @asyncio.coroutine()-wrapped function should have same effect as
+        returning generator object or Future."""
+        def check():
+            @asyncio.coroutine
+            def outer_coro():
+                @asyncio.coroutine
+                def inner_coro():
+                    return 1
+
+                return inner_coro()
+
+            result = self.loop.run_until_complete(outer_coro())
+            self.assertEqual(result, 1)
+
+        # Test with debug flag cleared.
+        with set_coroutine_debug(False):
+            check()
+
+        # Test with debug flag set.
+        with set_coroutine_debug(True):
+            check()
+
     def test_task_source_traceback(self):
         self.loop.set_debug(True)