From 2e624690bd74071358566300b7ef0bc45f444a30 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 30 Apr 2017 18:25:58 -0700 Subject: [PATCH] bpo-29679: Implement @contextlib.asynccontextmanager (#360) --- Doc/library/contextlib.rst | 30 +++++ Doc/reference/datamodel.rst | 2 + Doc/whatsnew/3.7.rst | 6 + Lib/contextlib.py | 99 +++++++++++++- Lib/test/test_contextlib_async.py | 212 ++++++++++++++++++++++++++++++ 5 files changed, 343 insertions(+), 6 deletions(-) create mode 100644 Lib/test/test_contextlib_async.py diff --git a/Doc/library/contextlib.rst b/Doc/library/contextlib.rst index dd34c96c8f..19793693b7 100644 --- a/Doc/library/contextlib.rst +++ b/Doc/library/contextlib.rst @@ -80,6 +80,36 @@ Functions and classes provided: Use of :class:`ContextDecorator`. +.. decorator:: asynccontextmanager + + Similar to :func:`~contextlib.contextmanager`, but creates an + :ref:`asynchronous context manager `. + + This function is a :term:`decorator` that can be used to define a factory + function for :keyword:`async with` statement asynchronous context managers, + without needing to create a class or separate :meth:`__aenter__` and + :meth:`__aexit__` methods. It must be applied to an :term:`asynchronous + generator` function. + + A simple example:: + + from contextlib import asynccontextmanager + + @asynccontextmanager + async def get_connection(): + conn = await acquire_db_connection() + try: + yield + finally: + await release_db_connection(conn) + + async def get_all_users(): + async with get_connection() as conn: + return conn.query('SELECT ...') + + .. versionadded:: 3.7 + + .. function:: closing(thing) Return a context manager that closes *thing* upon completion of the block. This diff --git a/Doc/reference/datamodel.rst b/Doc/reference/datamodel.rst index 4b49bfd78d..25afc351e8 100644 --- a/Doc/reference/datamodel.rst +++ b/Doc/reference/datamodel.rst @@ -2575,6 +2575,8 @@ An example of an asynchronous iterable object:: result in a :exc:`RuntimeError`. +.. _async-context-managers: + Asynchronous Context Managers ----------------------------- diff --git a/Doc/whatsnew/3.7.rst b/Doc/whatsnew/3.7.rst index 875fc55691..cb0086c4f7 100644 --- a/Doc/whatsnew/3.7.rst +++ b/Doc/whatsnew/3.7.rst @@ -95,6 +95,12 @@ New Modules Improved Modules ================ +contextlib +---------- + +:func:`contextlib.asynccontextmanager` has been added. (Contributed by +Jelle Zijlstra in :issue:`29679`.) + distutils --------- diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 5e47054954..c53b35e8d5 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -4,9 +4,9 @@ import sys from collections import deque from functools import wraps -__all__ = ["contextmanager", "closing", "AbstractContextManager", - "ContextDecorator", "ExitStack", "redirect_stdout", - "redirect_stderr", "suppress"] +__all__ = ["asynccontextmanager", "contextmanager", "closing", + "AbstractContextManager", "ContextDecorator", "ExitStack", + "redirect_stdout", "redirect_stderr", "suppress"] class AbstractContextManager(abc.ABC): @@ -54,8 +54,8 @@ class ContextDecorator(object): return inner -class _GeneratorContextManager(ContextDecorator, AbstractContextManager): - """Helper for @contextmanager decorator.""" +class _GeneratorContextManagerBase: + """Shared functionality for @contextmanager and @asynccontextmanager.""" def __init__(self, func, args, kwds): self.gen = func(*args, **kwds) @@ -71,6 +71,12 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager): # for the class instead. # See http://bugs.python.org/issue19404 for more details. + +class _GeneratorContextManager(_GeneratorContextManagerBase, + AbstractContextManager, + ContextDecorator): + """Helper for @contextmanager decorator.""" + def _recreate_cm(self): # _GCM instances are one-shot context managers, so the # CM must be recreated each time a decorated function is @@ -121,12 +127,61 @@ class _GeneratorContextManager(ContextDecorator, AbstractContextManager): # fixes the impedance mismatch between the throw() protocol # and the __exit__() protocol. # + # This cannot use 'except BaseException as exc' (as in the + # async implementation) to maintain compatibility with + # Python 2, where old-style class exceptions are not caught + # by 'except BaseException'. if sys.exc_info()[1] is value: return False raise raise RuntimeError("generator didn't stop after throw()") +class _AsyncGeneratorContextManager(_GeneratorContextManagerBase): + """Helper for @asynccontextmanager.""" + + async def __aenter__(self): + try: + return await self.gen.__anext__() + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + async def __aexit__(self, typ, value, traceback): + if typ is None: + try: + await self.gen.__anext__() + except StopAsyncIteration: + return + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + value = typ() + # See _GeneratorContextManager.__exit__ for comments on subtleties + # in this implementation + try: + await self.gen.athrow(typ, value, traceback) + raise RuntimeError("generator didn't stop after throw()") + except StopAsyncIteration as exc: + return exc is not value + except RuntimeError as exc: + if exc is value: + return False + # Avoid suppressing if a StopIteration exception + # was passed to throw() and later wrapped into a RuntimeError + # (see PEP 479 for sync generators; async generators also + # have this behavior). But do this only if the exception wrapped + # by the RuntimeError is actully Stop(Async)Iteration (see + # issue29692). + if isinstance(value, (StopIteration, StopAsyncIteration)): + if exc.__cause__ is value: + return False + raise + except BaseException as exc: + if exc is not value: + raise + + def contextmanager(func): """@contextmanager decorator. @@ -153,7 +208,6 @@ def contextmanager(func): finally: - """ @wraps(func) def helper(*args, **kwds): @@ -161,6 +215,39 @@ def contextmanager(func): return helper +def asynccontextmanager(func): + """@asynccontextmanager decorator. + + Typical usage: + + @asynccontextmanager + async def some_async_generator(): + + try: + yield + finally: + + + This makes this: + + async with some_async_generator() as : + + + equivalent to this: + + + try: + = + + finally: + + """ + @wraps(func) + def helper(*args, **kwds): + return _AsyncGeneratorContextManager(func, args, kwds) + return helper + + class closing(AbstractContextManager): """Context to automatically close something at the end of a block. diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py new file mode 100644 index 0000000000..42cc331c0a --- /dev/null +++ b/Lib/test/test_contextlib_async.py @@ -0,0 +1,212 @@ +import asyncio +from contextlib import asynccontextmanager +import functools +from test import support +import unittest + + +def _async_test(func): + """Decorator to turn an async function into a test case.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + coro = func(*args, **kwargs) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(coro) + finally: + loop.close() + asyncio.set_event_loop(None) + return wrapper + + +class AsyncContextManagerTestCase(unittest.TestCase): + + @_async_test + async def test_contextmanager_plain(self): + state = [] + @asynccontextmanager + async def woohoo(): + state.append(1) + yield 42 + state.append(999) + async with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + self.assertEqual(state, [1, 42, 999]) + + @_async_test + async def test_contextmanager_finally(self): + state = [] + @asynccontextmanager + async def woohoo(): + state.append(1) + try: + yield 42 + finally: + state.append(999) + with self.assertRaises(ZeroDivisionError): + async with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError() + self.assertEqual(state, [1, 42, 999]) + + @_async_test + async def test_contextmanager_no_reraise(self): + @asynccontextmanager + async def whee(): + yield + ctx = whee() + await ctx.__aenter__() + # Calling __aexit__ should not result in an exception + self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None)) + + @_async_test + async def test_contextmanager_trap_yield_after_throw(self): + @asynccontextmanager + async def whoo(): + try: + yield + except: + yield + ctx = whoo() + await ctx.__aenter__() + with self.assertRaises(RuntimeError): + await ctx.__aexit__(TypeError, TypeError('foo'), None) + + @_async_test + async def test_contextmanager_trap_no_yield(self): + @asynccontextmanager + async def whoo(): + if False: + yield + ctx = whoo() + with self.assertRaises(RuntimeError): + await ctx.__aenter__() + + @_async_test + async def test_contextmanager_trap_second_yield(self): + @asynccontextmanager + async def whoo(): + yield + yield + ctx = whoo() + await ctx.__aenter__() + with self.assertRaises(RuntimeError): + await ctx.__aexit__(None, None, None) + + @_async_test + async def test_contextmanager_non_normalised(self): + @asynccontextmanager + async def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError + + ctx = whoo() + await ctx.__aenter__() + with self.assertRaises(SyntaxError): + await ctx.__aexit__(RuntimeError, None, None) + + @_async_test + async def test_contextmanager_except(self): + state = [] + @asynccontextmanager + async def woohoo(): + state.append(1) + try: + yield 42 + except ZeroDivisionError as e: + state.append(e.args[0]) + self.assertEqual(state, [1, 42, 999]) + async with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) + self.assertEqual(state, [1, 42, 999]) + + @_async_test + async def test_contextmanager_except_stopiter(self): + @asynccontextmanager + async def woohoo(): + yield + + for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')): + with self.subTest(type=type(stop_exc)): + try: + async with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail(f'{stop_exc} was suppressed') + + @_async_test + async def test_contextmanager_wrap_runtimeerror(self): + @asynccontextmanager + async def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f'caught {exc}') from exc + + with self.assertRaises(RuntimeError): + async with woohoo(): + 1 / 0 + + # If the context manager wrapped StopAsyncIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopAsyncIteration): + async with woohoo(): + raise StopAsyncIteration + + def _create_contextmanager_attribs(self): + def attribs(**kw): + def decorate(func): + for k,v in kw.items(): + setattr(func,k,v) + return func + return decorate + @asynccontextmanager + @attribs(foo='bar') + async def baz(spam): + """Whee!""" + yield + return baz + + def test_contextmanager_attribs(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__name__,'baz') + self.assertEqual(baz.foo, 'bar') + + @support.requires_docstrings + def test_contextmanager_doc_attrib(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__doc__, "Whee!") + + @support.requires_docstrings + @_async_test + async def test_instance_docstring_given_cm_docstring(self): + baz = self._create_contextmanager_attribs()(None) + self.assertEqual(baz.__doc__, "Whee!") + async with baz: + pass # suppress warning + + @_async_test + async def test_keywords(self): + # Ensure no keyword arguments are inhibited + @asynccontextmanager + async def woohoo(self, func, args, kwds): + yield (self, func, args, kwds) + async with woohoo(self=11, func=22, args=33, kwds=44) as target: + self.assertEqual(target, (11, 22, 33, 44)) + + +if __name__ == '__main__': + unittest.main() -- 2.40.0