]> granicus.if.org Git - python/commitdiff
Issue 24316: Wrap gen objects returned from callables in types.coroutine
authorYury Selivanov <yselivanov@sprymix.com>
Fri, 29 May 2015 20:19:18 +0000 (16:19 -0400)
committerYury Selivanov <yselivanov@sprymix.com>
Fri, 29 May 2015 20:19:18 +0000 (16:19 -0400)
Lib/test/test_types.py
Lib/types.py

index 956214d080762bd62b3bff85006aa9ef14946df9..17ec6458b336db0e7d3b85f13d2845f781170a00 100644 (file)
@@ -1206,28 +1206,51 @@ class CoroutineTests(unittest.TestCase):
         @types.coroutine
         def foo():
             pass
-        @types.coroutine
-        def gen():
-            def _gen(): yield
-            return _gen()
-
-        for sample in (foo, gen):
-            with self.assertRaisesRegex(TypeError,
-                                        'callable wrapped .* non-coroutine'):
-                sample()
+        with self.assertRaisesRegex(TypeError,
+                                    'callable wrapped .* non-coroutine'):
+            foo()
 
     def test_duck_coro(self):
         class CoroLike:
             def send(self): pass
             def throw(self): pass
             def close(self): pass
-            def __await__(self): pass
+            def __await__(self): return self
 
         coro = CoroLike()
         @types.coroutine
         def foo():
             return coro
-        self.assertIs(coro, foo())
+        self.assertIs(foo().__await__(), coro)
+
+    def test_duck_gen(self):
+        class GenLike:
+            def send(self): pass
+            def throw(self): pass
+            def close(self): pass
+            def __iter__(self): return self
+            def __next__(self): pass
+
+        gen = GenLike()
+        @types.coroutine
+        def foo():
+            return gen
+        self.assertIs(foo().__await__(), gen)
+
+        with self.assertRaises(AttributeError):
+            foo().gi_code
+
+    def test_gen(self):
+        def gen(): yield
+        gen = gen()
+        @types.coroutine
+        def foo(): return gen
+        self.assertIs(foo().__await__(), gen)
+
+        for name in ('__name__', '__qualname__', 'gi_code',
+                     'gi_running', 'gi_frame'):
+            self.assertIs(getattr(foo(), name),
+                          getattr(gen, name))
 
     def test_genfunc(self):
         def gen():
index e9cc7948050a8aa059083061c8c055d7eb6e5c43..0a87c2f08e702906f84ac2fb0827026e6fc12157 100644 (file)
@@ -166,32 +166,64 @@ def coroutine(func):
 
     # We don't want to import 'dis' or 'inspect' just for
     # these constants.
-    _CO_GENERATOR = 0x20
-    _CO_ITERABLE_COROUTINE = 0x100
+    CO_GENERATOR = 0x20
+    CO_ITERABLE_COROUTINE = 0x100
 
     if not callable(func):
         raise TypeError('types.coroutine() expects a callable')
 
     if (isinstance(func, FunctionType) and
         isinstance(getattr(func, '__code__', None), CodeType) and
-        (func.__code__.co_flags & _CO_GENERATOR)):
+        (func.__code__.co_flags & CO_GENERATOR)):
 
         # TODO: Implement this in C.
         co = func.__code__
         func.__code__ = CodeType(
             co.co_argcount, co.co_kwonlyargcount, co.co_nlocals,
             co.co_stacksize,
-            co.co_flags | _CO_ITERABLE_COROUTINE,
+            co.co_flags | CO_ITERABLE_COROUTINE,
             co.co_code,
             co.co_consts, co.co_names, co.co_varnames, co.co_filename,
             co.co_name, co.co_firstlineno, co.co_lnotab, co.co_freevars,
             co.co_cellvars)
         return func
 
+    # The following code is primarily to support functions that
+    # return generator-like objects (for instance generators
+    # compiled with Cython).
+
+    class GeneratorWrapper:
+        def __init__(self, gen):
+            self.__wrapped__ = gen
+            self.send = gen.send
+            self.throw = gen.throw
+            self.close = gen.close
+            self.__name__ = getattr(gen, '__name__', None)
+            self.__qualname__ = getattr(gen, '__qualname__', None)
+        @property
+        def gi_code(self):
+            return self.__wrapped__.gi_code
+        @property
+        def gi_frame(self):
+            return self.__wrapped__.gi_frame
+        @property
+        def gi_running(self):
+            return self.__wrapped__.gi_running
+        def __next__(self):
+            return next(self.__wrapped__)
+        def __iter__(self):
+            return self.__wrapped__
+        __await__ = __iter__
+
     @_functools.wraps(func)
     def wrapped(*args, **kwargs):
         coro = func(*args, **kwargs)
+        if coro.__class__ is GeneratorType:
+            return GeneratorWrapper(coro)
+        # slow checks
         if not isinstance(coro, _collections_abc.Coroutine):
+            if isinstance(coro, _collections_abc.Generator):
+                return GeneratorWrapper(coro)
             raise TypeError(
                 'callable wrapped with types.coroutine() returned '
                 'non-coroutine: {!r}'.format(coro))