]> granicus.if.org Git - python/commitdiff
bpo-29679: Implement @contextlib.asynccontextmanager (#360)
authorJelle Zijlstra <jelle.zijlstra@gmail.com>
Mon, 1 May 2017 01:25:58 +0000 (18:25 -0700)
committerYury Selivanov <yselivanov@gmail.com>
Mon, 1 May 2017 01:25:58 +0000 (18:25 -0700)
Doc/library/contextlib.rst
Doc/reference/datamodel.rst
Doc/whatsnew/3.7.rst
Lib/contextlib.py
Lib/test/test_contextlib_async.py [new file with mode: 0644]

index dd34c96c8f8d6040289baae4c3bc47b49876a649..19793693b7ba68c9b41c09028ccbd086a2240493 100644 (file)
@@ -80,6 +80,36 @@ Functions and classes provided:
       Use of :class:`ContextDecorator`.
 
 
+.. decorator:: asynccontextmanager
+
+   Similar to :func:`~contextlib.contextmanager`, but creates an
+   :ref:`asynchronous context manager <async-context-managers>`.
+
+   This function is a :term:`decorator` that can be used to define a factory
+   function for :keyword:`async with` statement asynchronous context managers,
+   without needing to create a class or separate :meth:`__aenter__` and
+   :meth:`__aexit__` methods. It must be applied to an :term:`asynchronous
+   generator` function.
+
+   A simple example::
+
+      from contextlib import asynccontextmanager
+
+      @asynccontextmanager
+      async def get_connection():
+          conn = await acquire_db_connection()
+          try:
+              yield
+          finally:
+              await release_db_connection(conn)
+
+      async def get_all_users():
+          async with get_connection() as conn:
+              return conn.query('SELECT ...')
+
+   .. versionadded:: 3.7
+
+
 .. function:: closing(thing)
 
    Return a context manager that closes *thing* upon completion of the block.  This
index 4b49bfd78da051a3eac6d52666a03bed86501ed7..25afc351e8b5774e86c91777517b58295ab492ac 100644 (file)
@@ -2575,6 +2575,8 @@ An example of an asynchronous iterable object::
       result in a :exc:`RuntimeError`.
 
 
+.. _async-context-managers:
+
 Asynchronous Context Managers
 -----------------------------
 
index 875fc556912cae7d0bd66ae2e853013d2e62df7c..cb0086c4f76f45d37eba1f2e09ca44d15084639e 100644 (file)
@@ -95,6 +95,12 @@ New Modules
 Improved Modules
 ================
 
+contextlib
+----------
+
+:func:`contextlib.asynccontextmanager` has been added. (Contributed by
+Jelle Zijlstra in :issue:`29679`.)
+
 distutils
 ---------
 
index 5e47054954ba5a27d08a017d311d5ba1d338ae94..c53b35e8d5adaa00e2aa84c3baa222e4d2ed8726 100644 (file)
@@ -4,9 +4,9 @@ import sys
 from collections import deque
 from functools import wraps
 
-__all__ = ["contextmanager", "closing", "AbstractContextManager",
-           "ContextDecorator", "ExitStack", "redirect_stdout",
-           "redirect_stderr", "suppress"]
+__all__ = ["asynccontextmanager", "contextmanager", "closing",
+           "AbstractContextManager", "ContextDecorator", "ExitStack",
+           "redirect_stdout", "redirect_stderr", "suppress"]
 
 
 class AbstractContextManager(abc.ABC):
@@ -54,8 +54,8 @@ class ContextDecorator(object):
         return inner
 
 
-class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
-    """Helper for @contextmanager decorator."""
+class _GeneratorContextManagerBase:
+    """Shared functionality for @contextmanager and @asynccontextmanager."""
 
     def __init__(self, func, args, kwds):
         self.gen = func(*args, **kwds)
@@ -71,6 +71,12 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
         # for the class instead.
         # See http://bugs.python.org/issue19404 for more details.
 
+
+class _GeneratorContextManager(_GeneratorContextManagerBase,
+                               AbstractContextManager,
+                               ContextDecorator):
+    """Helper for @contextmanager decorator."""
+
     def _recreate_cm(self):
         # _GCM instances are one-shot context managers, so the
         # CM must be recreated each time a decorated function is
@@ -121,12 +127,61 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager):
                 # fixes the impedance mismatch between the throw() protocol
                 # and the __exit__() protocol.
                 #
+                # This cannot use 'except BaseException as exc' (as in the
+                # async implementation) to maintain compatibility with
+                # Python 2, where old-style class exceptions are not caught
+                # by 'except BaseException'.
                 if sys.exc_info()[1] is value:
                     return False
                 raise
             raise RuntimeError("generator didn't stop after throw()")
 
 
+class _AsyncGeneratorContextManager(_GeneratorContextManagerBase):
+    """Helper for @asynccontextmanager."""
+
+    async def __aenter__(self):
+        try:
+            return await self.gen.__anext__()
+        except StopAsyncIteration:
+            raise RuntimeError("generator didn't yield") from None
+
+    async def __aexit__(self, typ, value, traceback):
+        if typ is None:
+            try:
+                await self.gen.__anext__()
+            except StopAsyncIteration:
+                return
+            else:
+                raise RuntimeError("generator didn't stop")
+        else:
+            if value is None:
+                value = typ()
+            # See _GeneratorContextManager.__exit__ for comments on subtleties
+            # in this implementation
+            try:
+                await self.gen.athrow(typ, value, traceback)
+                raise RuntimeError("generator didn't stop after throw()")
+            except StopAsyncIteration as exc:
+                return exc is not value
+            except RuntimeError as exc:
+                if exc is value:
+                    return False
+                # Avoid suppressing if a StopIteration exception
+                # was passed to throw() and later wrapped into a RuntimeError
+                # (see PEP 479 for sync generators; async generators also
+                # have this behavior). But do this only if the exception wrapped
+                # by the RuntimeError is actully Stop(Async)Iteration (see
+                # issue29692).
+                if isinstance(value, (StopIteration, StopAsyncIteration)):
+                    if exc.__cause__ is value:
+                        return False
+                raise
+            except BaseException as exc:
+                if exc is not value:
+                    raise
+
+
 def contextmanager(func):
     """@contextmanager decorator.
 
@@ -153,7 +208,6 @@ def contextmanager(func):
             <body>
         finally:
             <cleanup>
-
     """
     @wraps(func)
     def helper(*args, **kwds):
@@ -161,6 +215,39 @@ def contextmanager(func):
     return helper
 
 
+def asynccontextmanager(func):
+    """@asynccontextmanager decorator.
+
+    Typical usage:
+
+        @asynccontextmanager
+        async def some_async_generator(<arguments>):
+            <setup>
+            try:
+                yield <value>
+            finally:
+                <cleanup>
+
+    This makes this:
+
+        async with some_async_generator(<arguments>) as <variable>:
+            <body>
+
+    equivalent to this:
+
+        <setup>
+        try:
+            <variable> = <value>
+            <body>
+        finally:
+            <cleanup>
+    """
+    @wraps(func)
+    def helper(*args, **kwds):
+        return _AsyncGeneratorContextManager(func, args, kwds)
+    return helper
+
+
 class closing(AbstractContextManager):
     """Context to automatically close something at the end of a block.
 
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
new file mode 100644 (file)
index 0000000..42cc331
--- /dev/null
@@ -0,0 +1,212 @@
+import asyncio
+from contextlib import asynccontextmanager
+import functools
+from test import support
+import unittest
+
+
+def _async_test(func):
+    """Decorator to turn an async function into a test case."""
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        coro = func(*args, **kwargs)
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(coro)
+        finally:
+            loop.close()
+            asyncio.set_event_loop(None)
+    return wrapper
+
+
+class AsyncContextManagerTestCase(unittest.TestCase):
+
+    @_async_test
+    async def test_contextmanager_plain(self):
+        state = []
+        @asynccontextmanager
+        async def woohoo():
+            state.append(1)
+            yield 42
+            state.append(999)
+        async with woohoo() as x:
+            self.assertEqual(state, [1])
+            self.assertEqual(x, 42)
+            state.append(x)
+        self.assertEqual(state, [1, 42, 999])
+
+    @_async_test
+    async def test_contextmanager_finally(self):
+        state = []
+        @asynccontextmanager
+        async def woohoo():
+            state.append(1)
+            try:
+                yield 42
+            finally:
+                state.append(999)
+        with self.assertRaises(ZeroDivisionError):
+            async with woohoo() as x:
+                self.assertEqual(state, [1])
+                self.assertEqual(x, 42)
+                state.append(x)
+                raise ZeroDivisionError()
+        self.assertEqual(state, [1, 42, 999])
+
+    @_async_test
+    async def test_contextmanager_no_reraise(self):
+        @asynccontextmanager
+        async def whee():
+            yield
+        ctx = whee()
+        await ctx.__aenter__()
+        # Calling __aexit__ should not result in an exception
+        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))
+
+    @_async_test
+    async def test_contextmanager_trap_yield_after_throw(self):
+        @asynccontextmanager
+        async def whoo():
+            try:
+                yield
+            except:
+                yield
+        ctx = whoo()
+        await ctx.__aenter__()
+        with self.assertRaises(RuntimeError):
+            await ctx.__aexit__(TypeError, TypeError('foo'), None)
+
+    @_async_test
+    async def test_contextmanager_trap_no_yield(self):
+        @asynccontextmanager
+        async def whoo():
+            if False:
+                yield
+        ctx = whoo()
+        with self.assertRaises(RuntimeError):
+            await ctx.__aenter__()
+
+    @_async_test
+    async def test_contextmanager_trap_second_yield(self):
+        @asynccontextmanager
+        async def whoo():
+            yield
+            yield
+        ctx = whoo()
+        await ctx.__aenter__()
+        with self.assertRaises(RuntimeError):
+            await ctx.__aexit__(None, None, None)
+
+    @_async_test
+    async def test_contextmanager_non_normalised(self):
+        @asynccontextmanager
+        async def whoo():
+            try:
+                yield
+            except RuntimeError:
+                raise SyntaxError
+
+        ctx = whoo()
+        await ctx.__aenter__()
+        with self.assertRaises(SyntaxError):
+            await ctx.__aexit__(RuntimeError, None, None)
+
+    @_async_test
+    async def test_contextmanager_except(self):
+        state = []
+        @asynccontextmanager
+        async def woohoo():
+            state.append(1)
+            try:
+                yield 42
+            except ZeroDivisionError as e:
+                state.append(e.args[0])
+                self.assertEqual(state, [1, 42, 999])
+        async with woohoo() as x:
+            self.assertEqual(state, [1])
+            self.assertEqual(x, 42)
+            state.append(x)
+            raise ZeroDivisionError(999)
+        self.assertEqual(state, [1, 42, 999])
+
+    @_async_test
+    async def test_contextmanager_except_stopiter(self):
+        @asynccontextmanager
+        async def woohoo():
+            yield
+
+        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
+            with self.subTest(type=type(stop_exc)):
+                try:
+                    async with woohoo():
+                        raise stop_exc
+                except Exception as ex:
+                    self.assertIs(ex, stop_exc)
+                else:
+                    self.fail(f'{stop_exc} was suppressed')
+
+    @_async_test
+    async def test_contextmanager_wrap_runtimeerror(self):
+        @asynccontextmanager
+        async def woohoo():
+            try:
+                yield
+            except Exception as exc:
+                raise RuntimeError(f'caught {exc}') from exc
+
+        with self.assertRaises(RuntimeError):
+            async with woohoo():
+                1 / 0
+
+        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
+        # we also unwrap it, because we can't tell whether the wrapping was
+        # done by the generator machinery or by the generator itself.
+        with self.assertRaises(StopAsyncIteration):
+            async with woohoo():
+                raise StopAsyncIteration
+
+    def _create_contextmanager_attribs(self):
+        def attribs(**kw):
+            def decorate(func):
+                for k,v in kw.items():
+                    setattr(func,k,v)
+                return func
+            return decorate
+        @asynccontextmanager
+        @attribs(foo='bar')
+        async def baz(spam):
+            """Whee!"""
+            yield
+        return baz
+
+    def test_contextmanager_attribs(self):
+        baz = self._create_contextmanager_attribs()
+        self.assertEqual(baz.__name__,'baz')
+        self.assertEqual(baz.foo, 'bar')
+
+    @support.requires_docstrings
+    def test_contextmanager_doc_attrib(self):
+        baz = self._create_contextmanager_attribs()
+        self.assertEqual(baz.__doc__, "Whee!")
+
+    @support.requires_docstrings
+    @_async_test
+    async def test_instance_docstring_given_cm_docstring(self):
+        baz = self._create_contextmanager_attribs()(None)
+        self.assertEqual(baz.__doc__, "Whee!")
+        async with baz:
+            pass  # suppress warning
+
+    @_async_test
+    async def test_keywords(self):
+        # Ensure no keyword arguments are inhibited
+        @asynccontextmanager
+        async def woohoo(self, func, args, kwds):
+            yield (self, func, args, kwds)
+        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
+            self.assertEqual(target, (11, 22, 33, 44))
+
+
+if __name__ == '__main__':
+    unittest.main()