Add Awaitable, AsyncIterable, AsyncIterator to typing.py.
authorGuido van Rossum <guido@python.org>
Fri, 4 Dec 2015 01:31:24 +0000 (17:31 -0800)
committerGuido van Rossum <guido@python.org>
Fri, 4 Dec 2015 01:31:24 +0000 (17:31 -0800)
Lib/test/test_typing.py
Lib/typing.py

index 060119a134ebe9692fb1c8474e4e44a3775369de..b9ca64259f37b989a3176c21ef1cfced37c9e617 100644 (file)
@@ -1,3 +1,4 @@
+import asyncio
 import pickle
 import re
 import sys
@@ -960,6 +961,36 @@ class OverloadTests(TestCase):
                 pass
 
 
+T_a = TypeVar('T')
+
+
+class AwaitableWrapper(typing.Awaitable[T_a]):
+
+    def __init__(self, value):
+        self.value = value
+
+    def __await__(self) -> typing.Iterator[T_a]:
+        yield
+        return self.value
+
+
+class AsyncIteratorWrapper(typing.AsyncIterator[T_a]):
+
+    def __init__(self, value: typing.Iterable[T_a]):
+        self.value = value
+
+    def __aiter__(self) -> typing.AsyncIterator[T_a]:
+        return self
+
+    @asyncio.coroutine
+    def __anext__(self) -> T_a:
+        data = yield from self.value
+        if data:
+            return data
+        else:
+            raise StopAsyncIteration
+
+
 class CollectionsAbcTests(TestCase):
 
     def test_hashable(self):
@@ -984,6 +1015,36 @@ class CollectionsAbcTests(TestCase):
         assert isinstance(it, typing.Iterator[int])
         assert not isinstance(42, typing.Iterator)
 
+    def test_awaitable(self):
+        async def foo() -> typing.Awaitable[int]:
+            return await AwaitableWrapper(42)
+        g = foo()
+        assert issubclass(type(g), typing.Awaitable[int])
+        assert isinstance(g, typing.Awaitable)
+        assert not isinstance(foo, typing.Awaitable)
+        assert issubclass(typing.Awaitable[Manager],
+                          typing.Awaitable[Employee])
+        assert not issubclass(typing.Awaitable[Employee],
+                              typing.Awaitable[Manager])
+        g.send(None)  # Run foo() till completion, to avoid warning.
+
+    def test_async_iterable(self):
+        base_it = range(10)  # type: Iterator[int]
+        it = AsyncIteratorWrapper(base_it)
+        assert isinstance(it, typing.AsyncIterable)
+        assert isinstance(it, typing.AsyncIterable)
+        assert issubclass(typing.AsyncIterable[Manager],
+                          typing.AsyncIterable[Employee])
+        assert not isinstance(42, typing.AsyncIterable)
+
+    def test_async_iterator(self):
+        base_it = range(10)  # type: Iterator[int]
+        it = AsyncIteratorWrapper(base_it)
+        assert isinstance(it, typing.AsyncIterator)
+        assert issubclass(typing.AsyncIterator[Manager],
+                          typing.AsyncIterator[Employee])
+        assert not isinstance(42, typing.AsyncIterator)
+
     def test_sized(self):
         assert isinstance([], typing.Sized)
         assert not isinstance(42, typing.Sized)
index 1757f138220e513749646cd35ff5d58ed579b031..823f9be5d8b3908d01fa0c13c0bca370d59c413a 100644 (file)
@@ -28,6 +28,9 @@ __all__ = [
 
     # ABCs (from collections.abc).
     'AbstractSet',  # collections.abc.Set.
+    'Awaitable',
+    'AsyncIterator',
+    'AsyncIterable',
     'ByteString',
     'Container',
     'Hashable',
@@ -1261,6 +1264,18 @@ class _Protocol(metaclass=_ProtocolMeta):
 Hashable = collections_abc.Hashable  # Not generic.
 
 
+class Awaitable(Generic[T_co], extra=collections_abc.Awaitable):
+    __slots__ = ()
+
+
+class AsyncIterable(Generic[T_co], extra=collections_abc.AsyncIterable):
+    __slots__ = ()
+
+
+class AsyncIterator(AsyncIterable[T_co], extra=collections_abc.AsyncIterator):
+    __slots__ = ()
+
+
 class Iterable(Generic[T_co], extra=collections_abc.Iterable):
     __slots__ = ()