return data
else:
raise StopAsyncIteration
+
+class ACM:
+ async def __aenter__(self) -> int:
+ return 42
+ async def __aexit__(self, etype, eval, tb):
+ return None
"""
if ASYNCIO:
else:
# fake names for the sake of static analysis
asyncio = None
- AwaitableWrapper = AsyncIteratorWrapper = object
+ AwaitableWrapper = AsyncIteratorWrapper = ACM = object
PY36 = sys.version_info[:2] >= (3, 6)
PY36_TESTS = """
from test import ann_module, ann_module2, ann_module3
+from typing import AsyncContextManager
class A:
y: float
return f'{self.x} -> {self.y}'
def __add__(self, other):
return 0
+
+async def g_with(am: AsyncContextManager[int]):
+ x: int
+ async with am as x:
+ return x
+
+try:
+ g_with(ACM()).send(None)
+except StopIteration as e:
+ assert e.args[0] == 42
"""
if PY36:
class OtherABCTests(BaseTestCase):
- @skipUnless(hasattr(typing, 'ContextManager'),
- 'requires typing.ContextManager')
def test_contextmanager(self):
@contextlib.contextmanager
def manager():
self.assertIsInstance(cm, typing.ContextManager)
self.assertNotIsInstance(42, typing.ContextManager)
+ @skipUnless(ASYNCIO, 'Python 3.5 required')
+ def test_async_contextmanager(self):
+ class NotACM:
+ pass
+ self.assertIsInstance(ACM(), typing.AsyncContextManager)
+ self.assertNotIsInstance(NotACM(), typing.AsyncContextManager)
+ @contextlib.contextmanager
+ def manager():
+ yield 42
+
+ cm = manager()
+ self.assertNotIsInstance(cm, typing.AsyncContextManager)
+ self.assertEqual(typing.AsyncContextManager[int].__args__, (int,))
+ with self.assertRaises(TypeError):
+ isinstance(42, typing.AsyncContextManager[int])
+ with self.assertRaises(TypeError):
+ typing.AsyncContextManager[int, str]
+
class TypeTests(BaseTestCase):
import collections.abc as collections_abc
except ImportError:
import collections as collections_abc # Fallback for PY3.2.
+if sys.version_info[:2] >= (3, 6):
+ import _collections_abc # Needed for private function _check_methods # noqa
try:
from types import WrapperDescriptorType, MethodWrapperType, MethodDescriptorType
except ImportError:
# for 'Generic' and ABCs below.
'ByteString',
'Container',
+ 'ContextManager',
'Hashable',
'ItemsView',
'Iterable',
# AsyncIterable,
# Coroutine,
# Collection,
- # ContextManager,
# AsyncGenerator,
+ # AsyncContextManager
# Structural checks, a.k.a. protocols.
'Reversible',
if hasattr(contextlib, 'AbstractContextManager'):
class ContextManager(Generic[T_co], extra=contextlib.AbstractContextManager):
__slots__ = ()
- __all__.append('ContextManager')
+else:
+ class ContextManager(Generic[T_co]):
+ __slots__ = ()
+
+ def __enter__(self):
+ return self
+
+ @abc.abstractmethod
+ def __exit__(self, exc_type, exc_value, traceback):
+ return None
+
+ @classmethod
+ def __subclasshook__(cls, C):
+ if cls is ContextManager:
+ # In Python 3.6+, it is possible to set a method to None to
+ # explicitly indicate that the class does not implement an ABC
+ # (https://bugs.python.org/issue25958), but we do not support
+ # that pattern here because this fallback class is only used
+ # in Python 3.5 and earlier.
+ if (any("__enter__" in B.__dict__ for B in C.__mro__) and
+ any("__exit__" in B.__dict__ for B in C.__mro__)):
+ return True
+ return NotImplemented
+
+
+if hasattr(contextlib, 'AbstractAsyncContextManager'):
+ class AsyncContextManager(Generic[T_co],
+ extra=contextlib.AbstractAsyncContextManager):
+ __slots__ = ()
+
+ __all__.append('AsyncContextManager')
+elif sys.version_info[:2] >= (3, 5):
+ exec("""
+class AsyncContextManager(Generic[T_co]):
+ __slots__ = ()
+
+ async def __aenter__(self):
+ return self
+
+ @abc.abstractmethod
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ return None
+
+ @classmethod
+ def __subclasshook__(cls, C):
+ if cls is AsyncContextManager:
+ if sys.version_info[:2] >= (3, 6):
+ return _collections_abc._check_methods(C, "__aenter__", "__aexit__")
+ if (any("__aenter__" in B.__dict__ for B in C.__mro__) and
+ any("__aexit__" in B.__dict__ for B in C.__mro__)):
+ return True
+ return NotImplemented
+
+__all__.append('AsyncContextManager')
+""")
class Dict(dict, MutableMapping[KT, VT], extra=dict):