bpo-36889: Merge asyncio streams (GH-13251)
author    Andrew Svetlov <andrew.svetlov@gmail.com>
          Mon, 27 May 2019 19:56:22 +0000 (22:56 +0300)
committer Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
          Mon, 27 May 2019 19:56:22 +0000 (12:56 -0700)
https://bugs.python.org/issue36889
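
A minimal usage sketch of the merged stream API added by this change (not part of the patch; the address and payload are placeholders): connect() returns a single bidirectional Stream instead of the old (StreamReader, StreamWriter) pair, and the result can be used as an async context manager.

    import asyncio

    async def main():
        # connect() supersedes open_connection(); both "await" and
        # "async with" styles work (see _ContextManagerHelper below).
        async with asyncio.connect('127.0.0.1', 8888) as stream:
            await stream.write(b'ping\n')    # awaiting write() applies flow control
            print(await stream.readline())   # read methods live on the same object

    asyncio.run(main())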

13 files changed:
Lib/asyncio/__init__.py
Lib/asyncio/streams.py
Lib/asyncio/subprocess.py
Lib/asyncio/windows_events.py
Lib/test/test___all__.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_buffered_proto.py
Lib/test/test_asyncio/test_pep492.py
Lib/test/test_asyncio/test_server.py
Lib/test/test_asyncio/test_sslproto.py
Lib/test/test_asyncio/test_streams.py
Lib/test/test_asyncio/test_windows_events.py
Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst [new file with mode: 0644]

diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py
index 28c2e2c429f34a8d57d02a3c6caf6218f4d9a642..a6a29dbfecd507d93c89a48c3ab513f79d545bd4 100644
@@ -3,6 +3,7 @@
 # flake8: noqa
 
 import sys
+import warnings
 
 # This relies on each of the submodules having an __all__ variable.
 from .base_events import *
@@ -43,3 +44,40 @@ if sys.platform == 'win32':  # pragma: no cover
 else:
     from .unix_events import *  # pragma: no cover
     __all__ += unix_events.__all__
+
+
+__all__ += ('StreamReader', 'StreamWriter', 'StreamReaderProtocol')  # deprecated
+
+
+def __getattr__(name):
+    global StreamReader, StreamWriter, StreamReaderProtocol
+    if name == 'StreamReader':
+        warnings.warn("StreamReader is deprecated since Python 3.8 "
+                      "in favor of Stream, and scheduled for removal "
+                      "in Python 3.10",
+                      DeprecationWarning,
+                      stacklevel=2)
+        from .streams import StreamReader as sr
+        StreamReader = sr
+        return StreamReader
+    if name == 'StreamWriter':
+        warnings.warn("StreamWriter is deprecated since Python 3.8 "
+                      "in favor of Stream, and scheduled for removal "
+                      "in Python 3.10",
+                      DeprecationWarning,
+                      stacklevel=2)
+        from .streams import StreamWriter as sw
+        StreamWriter = sw
+        return StreamWriter
+    if name == 'StreamReaderProtocol':
+        warnings.warn("Using asyncio internal class StreamReaderProtocol "
+                      "is deprecated since Python 3.8 "
+                      "and scheduled for removal "
+                      "in Python 3.10",
+                      DeprecationWarning,
+                      stacklevel=2)
+        from .streams import StreamReaderProtocol as srp
+        StreamReaderProtocol = srp
+        return StreamReaderProtocol
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
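
The module-level __getattr__ above relies on PEP 562 to keep the removed names importable while warning on first use; a small sketch of the expected behavior (illustrative, not part of the patch):

    import warnings
    import asyncio

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        cls = asyncio.StreamReader          # resolved lazily by asyncio.__getattr__()
    print(caught[0].category)               # <class 'DeprecationWarning'>
    print(cls is asyncio.StreamReader)      # True: the name is cached as a module global
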
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 2f0cbfdbe852d180e98559654fb955a8b9f26ca9..480f1a3fdd74edce0a518daf19e5f5e2e5027595 100644
@@ -1,14 +1,19 @@
 __all__ = (
-    'StreamReader', 'StreamWriter', 'StreamReaderProtocol',
-    'open_connection', 'start_server')
+    'Stream', 'StreamMode',
+    'open_connection', 'start_server',
+    'connect', 'connect_read_pipe', 'connect_write_pipe',
+    'StreamServer')
 
+import enum
 import socket
 import sys
 import warnings
 import weakref
 
 if hasattr(socket, 'AF_UNIX'):
-    __all__ += ('open_unix_connection', 'start_unix_server')
+    __all__ += ('open_unix_connection', 'start_unix_server',
+                'connect_unix',
+                'UnixStreamServer')
 
 from . import coroutines
 from . import events
@@ -16,12 +21,134 @@ from . import exceptions
 from . import format_helpers
 from . import protocols
 from .log import logger
-from .tasks import sleep
+from . import tasks
 
 
 _DEFAULT_LIMIT = 2 ** 16  # 64 KiB
 
 
+class StreamMode(enum.Flag):
+    READ = enum.auto()
+    WRITE = enum.auto()
+    READWRITE = READ | WRITE
+
+
+def _ensure_can_read(mode):
+    if not mode & StreamMode.READ:
+        raise RuntimeError("The stream is write-only")
+
+
+def _ensure_can_write(mode):
+    if not mode & StreamMode.WRITE:
+        raise RuntimeError("The stream is read-only")
+
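
StreamMode is a flag enum, so the direction checks above are plain bitwise tests; a minimal illustration, assuming this patch is applied:

    from asyncio.streams import StreamMode

    print(bool(StreamMode.READWRITE & StreamMode.WRITE))   # True
    print(bool(StreamMode.READ & StreamMode.WRITE))        # False; write() would raise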
+
+class _ContextManagerHelper:
+    __slots__ = ('_awaitable', '_result')
+
+    def __init__(self, awaitable):
+        self._awaitable = awaitable
+        self._result = None
+
+    def __await__(self):
+        return self._awaitable.__await__()
+
+    async def __aenter__(self):
+        ret = await self._awaitable
+        result = await ret.__aenter__()
+        self._result = result
+        return result
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        return await self._result.__aexit__(exc_type, exc_val, exc_tb)
+
+
+def connect(host=None, port=None, *,
+            limit=_DEFAULT_LIMIT,
+            ssl=None, family=0, proto=0,
+            flags=0, sock=None, local_addr=None,
+            server_hostname=None,
+            ssl_handshake_timeout=None,
+            happy_eyeballs_delay=None, interleave=None):
+    # Design note:
+    # Don't use decorator approach but explicit non-async
+    # function to fail fast and explicitly
+    # if passed arguments don't match the function signature
+    return _ContextManagerHelper(_connect(host, port, limit,
+                                          ssl, family, proto,
+                                          flags, sock, local_addr,
+                                          server_hostname,
+                                          ssl_handshake_timeout,
+                                          happy_eyeballs_delay,
+                                          interleave))
+
+
+async def _connect(host, port,
+                  limit,
+                  ssl, family, proto,
+                  flags, sock, local_addr,
+                  server_hostname,
+                  ssl_handshake_timeout,
+                  happy_eyeballs_delay, interleave):
+    loop = events.get_running_loop()
+    stream = Stream(mode=StreamMode.READWRITE,
+                    limit=limit,
+                    loop=loop,
+                    _asyncio_internal=True)
+    await loop.create_connection(
+        lambda: _StreamProtocol(stream, loop=loop,
+                                _asyncio_internal=True),
+        host, port,
+        ssl=ssl, family=family, proto=proto,
+        flags=flags, sock=sock, local_addr=local_addr,
+        server_hostname=server_hostname,
+        ssl_handshake_timeout=ssl_handshake_timeout,
+        happy_eyeballs_delay=happy_eyeballs_delay, interleave=interleave)
+    return stream
+
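
Because connect() is an explicit non-async function returning _ContextManagerHelper, bad arguments fail immediately and the result can either be awaited or used as an async context manager. A hedged sketch of both styles (the HTTP exchange is only an example):

    import asyncio

    async def plain_await():
        stream = await asyncio.connect('example.com', 80)   # caller closes explicitly
        try:
            await stream.write(b'HEAD / HTTP/1.0\r\nHost: example.com\r\n\r\n')
            print(await stream.readline())
        finally:
            await stream.close()

    async def context_manager():
        async with asyncio.connect('example.com', 80) as stream:  # closed on exit
            await stream.write(b'HEAD / HTTP/1.0\r\nHost: example.com\r\n\r\n')
            print(await stream.readline())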
+
+def connect_read_pipe(pipe, *, limit=_DEFAULT_LIMIT):
+    # Design note:
+    # Don't use decorator approach but explicit non-async
+    # function to fail fast and explicitly
+    # if passed arguments don't match the function signature
+    return _ContextManagerHelper(_connect_read_pipe(pipe, limit))
+
+
+async def _connect_read_pipe(pipe, limit):
+    loop = events.get_running_loop()
+    stream = Stream(mode=StreamMode.READ,
+                    limit=limit,
+                    loop=loop,
+                    _asyncio_internal=True)
+    await loop.connect_read_pipe(
+        lambda: _StreamProtocol(stream, loop=loop,
+                                _asyncio_internal=True),
+        pipe)
+    return stream
+
+
+def connect_write_pipe(pipe, *, limit=_DEFAULT_LIMIT):
+    # Design note:
+    # Don't use decorator approach but explicit non-async
+    # function to fail fast and explicitly
+    # if passed arguments don't match the function signature
+    return _ContextManagerHelper(_connect_write_pipe(pipe, limit))
+
+
+async def _connect_write_pipe(pipe, limit):
+    loop = events.get_running_loop()
+    stream = Stream(mode=StreamMode.WRITE,
+                    limit=limit,
+                    loop=loop,
+                    _asyncio_internal=True)
+    await loop.connect_write_pipe(
+        lambda: _StreamProtocol(stream, loop=loop,
+                                _asyncio_internal=True),
+        pipe)
+    return stream
+
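
A sketch of hooking the read end of an OS pipe into the new helper; the os.pipe()/os.fdopen() wiring is illustrative, and the helper takes a file-like pipe object just as loop.connect_read_pipe() does:

    import asyncio
    import os

    async def main():
        rfd, wfd = os.pipe()
        os.write(wfd, b'hello\n')
        os.close(wfd)                          # EOF for the read side
        pipe = os.fdopen(rfd, 'rb', buffering=0)
        async with asyncio.connect_read_pipe(pipe) as stream:
            print(await stream.readline())     # b'hello\n'

    asyncio.run(main())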
+
 async def open_connection(host=None, port=None, *,
                           loop=None, limit=_DEFAULT_LIMIT, **kwds):
     """A wrapper for create_connection() returning a (reader, writer) pair.
@@ -41,16 +168,18 @@ async def open_connection(host=None, port=None, *,
     StreamReaderProtocol classes, just copy the code -- there's
     really nothing special here except some convenience.)
     """
+    warnings.warn("open_connection() is deprecated since Python 3.8 "
+                  "in favor of connect(), and scheduled for removal "
+                  "in Python 3.10",
+                  DeprecationWarning,
+                  stacklevel=2)
     if loop is None:
         loop = events.get_event_loop()
-    reader = StreamReader(limit=limit, loop=loop,
-                          _asyncio_internal=True)
-    protocol = StreamReaderProtocol(reader, loop=loop,
-                                    _asyncio_internal=True)
+    reader = StreamReader(limit=limit, loop=loop)
+    protocol = StreamReaderProtocol(reader, loop=loop, _asyncio_internal=True)
     transport, _ = await loop.create_connection(
         lambda: protocol, host, port, **kwds)
-    writer = StreamWriter(transport, protocol, reader, loop,
-                          _asyncio_internal=True)
+    writer = StreamWriter(transport, protocol, reader, loop)
     return reader, writer
 
 
@@ -77,12 +206,16 @@ async def start_server(client_connected_cb, host=None, port=None, *,
     The return value is the same as loop.create_server(), i.e. a
     Server object which can be used to stop the service.
     """
+    warnings.warn("start_server() is deprecated since Python 3.8 "
+                  "in favor of StreamServer(), and scheduled for removal "
+                  "in Python 3.10",
+                  DeprecationWarning,
+                  stacklevel=2)
     if loop is None:
         loop = events.get_event_loop()
 
     def factory():
-        reader = StreamReader(limit=limit, loop=loop,
-                              _asyncio_internal=True)
+        reader = StreamReader(limit=limit, loop=loop)
         protocol = StreamReaderProtocol(reader, client_connected_cb,
                                         loop=loop,
                                         _asyncio_internal=True)
@@ -91,33 +224,258 @@ async def start_server(client_connected_cb, host=None, port=None, *,
     return await loop.create_server(factory, host, port, **kwds)
 
 
+class _BaseStreamServer:
+    # Design notes.
+    # StreamServer and UnixStreamServer are exposed as FINAL classes,
+    # not function factories, because
+    #     async with serve(host, port) as server:
+    #         server.start_serving()
+    # looks ugly.
+    # The class doesn't provide an API for enumerating connected streams;
+    # that can be a subject for improvement in Python 3.9.
+
+    _server_impl = None
+
+    def __init__(self, client_connected_cb,
+                 /,
+                 limit=_DEFAULT_LIMIT,
+                 shutdown_timeout=60,
+                 _asyncio_internal=False):
+        if not _asyncio_internal:
+            raise RuntimeError("_BaseStreamServer is a private asyncio class")
+        self._client_connected_cb = client_connected_cb
+        self._limit = limit
+        self._loop = events.get_running_loop()
+        self._streams = {}
+        self._shutdown_timeout = shutdown_timeout
+
+    def __init_subclass__(cls):
+        if not cls.__module__.startswith('asyncio.'):
+            raise TypeError(f"asyncio.{cls.__name__} "
+                            "class cannot be inherited from")
+
+    async def bind(self):
+        if self._server_impl is not None:
+            return
+        self._server_impl = await self._bind()
+
+    def is_bound(self):
+        return self._server_impl is not None
+
+    @property
+    def sockets(self):
+        # multiple values are possible for a socket bound to both IPv4 and IPv6 families
+        if self._server_impl is None:
+            return ()
+        return self._server_impl.sockets
+
+    def is_serving(self):
+        if self._server_impl is None:
+            return False
+        return self._server_impl.is_serving()
+
+    async def start_serving(self):
+        await self.bind()
+        await self._server_impl.start_serving()
+
+    async def serve_forever(self):
+        await self.start_serving()
+        await self._server_impl.serve_forever()
+
+    async def close(self):
+        if self._server_impl is None:
+            return
+        self._server_impl.close()
+        streams = list(self._streams.keys())
+        active_tasks = list(self._streams.values())
+        if streams:
+            await tasks.wait([stream.close() for stream in streams])
+        await self._server_impl.wait_closed()
+        self._server_impl = None
+        await self._shutdown_active_tasks(active_tasks)
+
+    async def abort(self):
+        if self._server_impl is None:
+            return
+        self._server_impl.close()
+        streams = list(self._streams.keys())
+        active_tasks = list(self._streams.values())
+        if streams:
+            await tasks.wait([stream.abort() for stream in streams])
+        await self._server_impl.wait_closed()
+        self._server_impl = None
+        await self._shutdown_active_tasks(active_tasks)
+
+    async def __aenter__(self):
+        await self.bind()
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, exc_tb):
+        await self.close()
+
+    def _attach(self, stream, task):
+        self._streams[stream] = task
+
+    def _detach(self, stream, task):
+        del self._streams[stream]
+
+    async def _shutdown_active_tasks(self, active_tasks):
+        if not active_tasks:
+            return
+        # NOTE: tasks finished with exception are reported
+        # by the Task.__del__() method.
+        done, pending = await tasks.wait(active_tasks,
+                                         timeout=self._shutdown_timeout)
+        if not pending:
+            return
+        for task in pending:
+            task.cancel()
+        done, pending = await tasks.wait(pending,
+                                         timeout=self._shutdown_timeout)
+        for task in pending:
+            self._loop.call_exception_handler({
+                "message": (f'{task!r} ignored cancellation request '
+                            f'from a closing {self!r}'),
+                "stream_server": self
+            })
+
+    def __repr__(self):
+        ret = [f'{self.__class__.__name__}']
+        if self.is_serving():
+            ret.append('serving')
+        if self.sockets:
+            ret.append(f'sockets={self.sockets!r}')
+        return '<' + ' '.join(ret) + '>'
+
+    def __del__(self, _warn=warnings.warn):
+        if self._server_impl is not None:
+            _warn(f"unclosed stream server {self!r}",
+                  ResourceWarning, source=self)
+            self._server_impl.close()
+
+
+class StreamServer(_BaseStreamServer):
+
+    def __init__(self, client_connected_cb, /, host=None, port=None, *,
+                 limit=_DEFAULT_LIMIT,
+                 family=socket.AF_UNSPEC,
+                 flags=socket.AI_PASSIVE, sock=None, backlog=100,
+                 ssl=None, reuse_address=None, reuse_port=None,
+                 ssl_handshake_timeout=None,
+                 shutdown_timeout=60):
+        super().__init__(client_connected_cb,
+                         limit=limit,
+                         shutdown_timeout=shutdown_timeout,
+                         _asyncio_internal=True)
+        self._host = host
+        self._port = port
+        self._family = family
+        self._flags = flags
+        self._sock = sock
+        self._backlog = backlog
+        self._ssl = ssl
+        self._reuse_address = reuse_address
+        self._reuse_port = reuse_port
+        self._ssl_handshake_timeout = ssl_handshake_timeout
+
+    async def _bind(self):
+        def factory():
+            protocol = _ServerStreamProtocol(self,
+                                             self._limit,
+                                             self._client_connected_cb,
+                                             loop=self._loop,
+                                             _asyncio_internal=True)
+            return protocol
+        return await self._loop.create_server(
+            factory,
+            self._host,
+            self._port,
+            start_serving=False,
+            family=self._family,
+            flags=self._flags,
+            sock=self._sock,
+            backlog=self._backlog,
+            ssl=self._ssl,
+            reuse_address=self._reuse_address,
+            reuse_port=self._reuse_port,
+            ssl_handshake_timeout=self._ssl_handshake_timeout)
+
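
A sketch of the intended replacement for start_server(), based on the StreamServer class above; the echo handler and the address are placeholders:

    import asyncio

    async def handle(stream):
        # The callback now receives a single bidirectional Stream,
        # not a (reader, writer) pair.
        data = await stream.readline()
        await stream.write(data)
        await stream.close()

    async def main():
        async with asyncio.StreamServer(handle, '127.0.0.1', 8888) as server:
            await server.serve_forever()

    asyncio.run(main())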
+
 if hasattr(socket, 'AF_UNIX'):
     # UNIX Domain Sockets are supported on this platform
 
     async def open_unix_connection(path=None, *,
                                    loop=None, limit=_DEFAULT_LIMIT, **kwds):
         """Similar to `open_connection` but works with UNIX Domain Sockets."""
+        warnings.warn("open_unix_connection() is deprecated since Python 3.8 "
+                      "in favor of connect_unix(), and scheduled for removal "
+                      "in Python 3.10",
+                      DeprecationWarning,
+                      stacklevel=2)
         if loop is None:
             loop = events.get_event_loop()
-        reader = StreamReader(limit=limit, loop=loop,
-                              _asyncio_internal=True)
+        reader = StreamReader(limit=limit, loop=loop)
         protocol = StreamReaderProtocol(reader, loop=loop,
                                         _asyncio_internal=True)
         transport, _ = await loop.create_unix_connection(
             lambda: protocol, path, **kwds)
-        writer = StreamWriter(transport, protocol, reader, loop,
-                              _asyncio_internal=True)
+        writer = StreamWriter(transport, protocol, reader, loop)
         return reader, writer
 
+
+    def connect_unix(path=None, *,
+                     limit=_DEFAULT_LIMIT,
+                     ssl=None, sock=None,
+                     server_hostname=None,
+                     ssl_handshake_timeout=None):
+        """Similar to `connect()` but works with UNIX Domain Sockets."""
+        # Design note:
+        # Don't use decorator approach but explicit non-async
+        # function to fail fast and explicitly
+        # if passed arguments don't match the function signature
+        return _ContextManagerHelper(_connect_unix(path,
+                                                   limit,
+                                                   ssl, sock,
+                                                   server_hostname,
+                                                   ssl_handshake_timeout))
+
+
+    async def _connect_unix(path,
+                           limit,
+                           ssl, sock,
+                           server_hostname,
+                           ssl_handshake_timeout):
+        """Similar to `connect()` but works with UNIX Domain Sockets."""
+        loop = events.get_running_loop()
+        stream = Stream(mode=StreamMode.READWRITE,
+                        limit=limit,
+                        loop=loop,
+                        _asyncio_internal=True)
+        await loop.create_unix_connection(
+            lambda: _StreamProtocol(stream,
+                                    loop=loop,
+                                    _asyncio_internal=True),
+            path,
+            ssl=ssl,
+            sock=sock,
+            server_hostname=server_hostname,
+            ssl_handshake_timeout=ssl_handshake_timeout)
+        return stream
+
+
     async def start_unix_server(client_connected_cb, path=None, *,
                                 loop=None, limit=_DEFAULT_LIMIT, **kwds):
         """Similar to `start_server` but works with UNIX Domain Sockets."""
+        warnings.warn("start_unix_server() is deprecated since Python 3.8 "
+                      "in favor of UnixStreamServer(), and scheduled "
+                      "for removal in Python 3.10",
+                      DeprecationWarning,
+                      stacklevel=2)
         if loop is None:
             loop = events.get_event_loop()
 
         def factory():
-            reader = StreamReader(limit=limit, loop=loop,
-                                  _asyncio_internal=True)
+            reader = StreamReader(limit=limit, loop=loop)
             protocol = StreamReaderProtocol(reader, client_connected_cb,
                                             loop=loop,
                                             _asyncio_internal=True)
@@ -125,6 +483,42 @@ if hasattr(socket, 'AF_UNIX'):
 
         return await loop.create_unix_server(factory, path, **kwds)
 
+    class UnixStreamServer(_BaseStreamServer):
+
+        def __init__(self, client_connected_cb, /, path=None, *,
+                     limit=_DEFAULT_LIMIT,
+                     sock=None,
+                     backlog=100,
+                     ssl=None,
+                     ssl_handshake_timeout=None,
+                     shutdown_timeout=60):
+            super().__init__(client_connected_cb,
+                             limit=limit,
+                             shutdown_timeout=shutdown_timeout,
+                             _asyncio_internal=True)
+            self._path = path
+            self._sock = sock
+            self._backlog = backlog
+            self._ssl = ssl
+            self._ssl_handshake_timeout = ssl_handshake_timeout
+
+        async def _bind(self):
+            def factory():
+                protocol = _ServerStreamProtocol(self,
+                                                 self._limit,
+                                                 self._client_connected_cb,
+                                                 loop=self._loop,
+                                                 _asyncio_internal=True)
+                return protocol
+            return await self._loop.create_unix_server(
+                factory,
+                self._path,
+                start_serving=False,
+                sock=self._sock,
+                backlog=self._backlog,
+                ssl=self._ssl,
+                ssl_handshake_timeout=self._ssl_handshake_timeout)
+
 
 class FlowControlMixin(protocols.Protocol):
     """Reusable flow control logic for StreamWriter.drain().
@@ -203,6 +597,8 @@ class FlowControlMixin(protocols.Protocol):
         raise NotImplementedError
 
 
+# begin legacy stream APIs
+
 class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
     """Helper class to adapt between Protocol and StreamReader.
 
@@ -212,105 +608,47 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
     call inappropriate methods of the protocol.)
     """
 
-    _source_traceback = None
-
     def __init__(self, stream_reader, client_connected_cb=None, loop=None,
                  *, _asyncio_internal=False):
         super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
-        if stream_reader is not None:
-            self._stream_reader_wr = weakref.ref(stream_reader,
-                                                 self._on_reader_gc)
-            self._source_traceback = stream_reader._source_traceback
-        else:
-            self._stream_reader_wr = None
-        if client_connected_cb is not None:
-            # This is a stream created by the `create_server()` function.
-            # Keep a strong reference to the reader until a connection
-            # is established.
-            self._strong_reader = stream_reader
-        self._reject_connection = False
+        self._stream_reader = stream_reader
         self._stream_writer = None
-        self._transport = None
         self._client_connected_cb = client_connected_cb
         self._over_ssl = False
         self._closed = self._loop.create_future()
 
-    def _on_reader_gc(self, wr):
-        transport = self._transport
-        if transport is not None:
-            # connection_made was called
-            context = {
-                'message': ('An open stream object is being garbage '
-                            'collected; call "stream.close()" explicitly.')
-            }
-            if self._source_traceback:
-                context['source_traceback'] = self._source_traceback
-            self._loop.call_exception_handler(context)
-            transport.abort()
-        else:
-            self._reject_connection = True
-        self._stream_reader_wr = None
-
-    @property
-    def _stream_reader(self):
-        if self._stream_reader_wr is None:
-            return None
-        return self._stream_reader_wr()
-
     def connection_made(self, transport):
-        if self._reject_connection:
-            context = {
-                'message': ('An open stream was garbage collected prior to '
-                            'establishing network connection; '
-                            'call "stream.close()" explicitly.')
-            }
-            if self._source_traceback:
-                context['source_traceback'] = self._source_traceback
-            self._loop.call_exception_handler(context)
-            transport.abort()
-            return
-        self._transport = transport
-        reader = self._stream_reader
-        if reader is not None:
-            reader.set_transport(transport)
+        self._stream_reader.set_transport(transport)
         self._over_ssl = transport.get_extra_info('sslcontext') is not None
         if self._client_connected_cb is not None:
             self._stream_writer = StreamWriter(transport, self,
-                                               reader,
-                                               self._loop,
-                                               _asyncio_internal=True)
-            res = self._client_connected_cb(reader,
+                                               self._stream_reader,
+                                               self._loop)
+            res = self._client_connected_cb(self._stream_reader,
                                             self._stream_writer)
             if coroutines.iscoroutine(res):
                 self._loop.create_task(res)
-            self._strong_reader = None
 
     def connection_lost(self, exc):
-        reader = self._stream_reader
-        if reader is not None:
+        if self._stream_reader is not None:
             if exc is None:
-                reader.feed_eof()
+                self._stream_reader.feed_eof()
             else:
-                reader.set_exception(exc)
+                self._stream_reader.set_exception(exc)
         if not self._closed.done():
             if exc is None:
                 self._closed.set_result(None)
             else:
                 self._closed.set_exception(exc)
         super().connection_lost(exc)
-        self._stream_reader_wr = None
+        self._stream_reader = None
         self._stream_writer = None
-        self._transport = None
 
     def data_received(self, data):
-        reader = self._stream_reader
-        if reader is not None:
-            reader.feed_data(data)
+        self._stream_reader.feed_data(data)
 
     def eof_received(self):
-        reader = self._stream_reader
-        if reader is not None:
-            reader.feed_eof()
+        self._stream_reader.feed_eof()
         if self._over_ssl:
             # Prevent a warning in SSLProtocol.eof_received:
             # "returning true from eof_received()
@@ -318,9 +656,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
             return False
         return True
 
-    def _get_close_waiter(self, stream):
-        return self._closed
-
     def __del__(self):
         # Prevent reports about unhandled exceptions.
         # Better than self._closed._log_traceback = False hack
@@ -329,13 +664,6 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
             closed.exception()
 
 
-def _swallow_unhandled_exception(task):
-    # Do a trick to suppress unhandled exception
-    # if stream.write() was used without await and
-    # stream.drain() was paused and resumed with an exception
-    task.exception()
-
-
 class StreamWriter:
     """Wraps a Transport.
 
@@ -346,21 +674,13 @@ class StreamWriter:
     directly.
     """
 
-    def __init__(self, transport, protocol, reader, loop,
-                 *, _asyncio_internal=False):
-        if not _asyncio_internal:
-            warnings.warn(f"{self.__class__} should be instaniated "
-                          "by asyncio internals only, "
-                          "please avoid its creation from user code",
-                          DeprecationWarning)
+    def __init__(self, transport, protocol, reader, loop):
         self._transport = transport
         self._protocol = protocol
         # drain() expects that the reader has an exception() method
         assert reader is None or isinstance(reader, StreamReader)
         self._reader = reader
         self._loop = loop
-        self._complete_fut = self._loop.create_future()
-        self._complete_fut.set_result(None)
 
     def __repr__(self):
         info = [self.__class__.__name__, f'transport={self._transport!r}']
@@ -374,35 +694,9 @@ class StreamWriter:
 
     def write(self, data):
         self._transport.write(data)
-        return self._fast_drain()
 
     def writelines(self, data):
         self._transport.writelines(data)
-        return self._fast_drain()
-
-    def _fast_drain(self):
-        # The helper tries to use fast-path to return already existing complete future
-        # object if underlying transport is not paused and actual waiting for writing
-        # resume is not needed
-        if self._reader is not None:
-            # this branch will be simplified after merging reader with writer
-            exc = self._reader.exception()
-            if exc is not None:
-                fut = self._loop.create_future()
-                fut.set_exception(exc)
-                return fut
-        if not self._transport.is_closing():
-            if self._protocol._connection_lost:
-                fut = self._loop.create_future()
-                fut.set_exception(ConnectionResetError('Connection lost'))
-                return fut
-            if not self._protocol._paused:
-                # fast path, the stream is not paused
-                # no need to wait for resume signal
-                return self._complete_fut
-        ret = self._loop.create_task(self.drain())
-        ret.add_done_callback(_swallow_unhandled_exception)
-        return ret
 
     def write_eof(self):
         return self._transport.write_eof()
@@ -411,14 +705,13 @@ class StreamWriter:
         return self._transport.can_write_eof()
 
     def close(self):
-        self._transport.close()
-        return self._protocol._get_close_waiter(self)
+        return self._transport.close()
 
     def is_closing(self):
         return self._transport.is_closing()
 
     async def wait_closed(self):
-        await self._protocol._get_close_waiter(self)
+        await self._protocol._closed
 
     def get_extra_info(self, name, default=None):
         return self._transport.get_extra_info(name, default)
@@ -436,25 +729,19 @@ class StreamWriter:
             if exc is not None:
                 raise exc
         if self._transport.is_closing():
-            # Wait for protocol.connection_lost() call
-            # Raise connection closing error if any,
-            # ConnectionResetError otherwise
-            await sleep(0)
+            # Yield to the event loop so connection_lost() may be
+            # called.  Without this, _drain_helper() would return
+            # immediately, and code that calls
+            #     write(...); await drain()
+            # in a loop would never call connection_lost(), so it
+            # would not see an error when the socket is closed.
+            await tasks.sleep(0, loop=self._loop)
         await self._protocol._drain_helper()
 
 
 class StreamReader:
 
-    _source_traceback = None
-
-    def __init__(self, limit=_DEFAULT_LIMIT, loop=None,
-                 *, _asyncio_internal=False):
-        if not _asyncio_internal:
-            warnings.warn(f"{self.__class__} should be instaniated "
-                          "by asyncio internals only, "
-                          "please avoid its creation from user code",
-                          DeprecationWarning)
-
+    def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
         # The line length limit is  a security feature;
         # it also doubles as half the buffer limit.
 
@@ -472,9 +759,6 @@ class StreamReader:
         self._exception = None
         self._transport = None
         self._paused = False
-        if self._loop.get_debug():
-            self._source_traceback = format_helpers.extract_stack(
-                sys._getframe(1))
 
     def __repr__(self):
         info = ['StreamReader']
@@ -802,3 +1086,671 @@ class StreamReader:
         if val == b'':
             raise StopAsyncIteration
         return val
+
+
+# end legacy stream APIs
+
+
+class _BaseStreamProtocol(FlowControlMixin, protocols.Protocol):
+    """Helper class to adapt between Protocol and StreamReader.
+
+    (This is a helper class instead of making StreamReader itself a
+    Protocol subclass, because the StreamReader has other potential
+    uses, and to prevent the user of the StreamReader to accidentally
+    call inappropriate methods of the protocol.)
+    """
+
+    _stream = None  # initialized in derived classes
+
+    def __init__(self, loop=None,
+                 *, _asyncio_internal=False):
+        super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+        self._transport = None
+        self._over_ssl = False
+        self._closed = self._loop.create_future()
+
+    def connection_made(self, transport):
+        self._transport = transport
+        self._over_ssl = transport.get_extra_info('sslcontext') is not None
+
+    def connection_lost(self, exc):
+        stream = self._stream
+        if stream is not None:
+            if exc is None:
+                stream.feed_eof()
+            else:
+                stream.set_exception(exc)
+        if not self._closed.done():
+            if exc is None:
+                self._closed.set_result(None)
+            else:
+                self._closed.set_exception(exc)
+        super().connection_lost(exc)
+        self._transport = None
+
+    def data_received(self, data):
+        stream = self._stream
+        if stream is not None:
+            stream.feed_data(data)
+
+    def eof_received(self):
+        stream = self._stream
+        if stream is not None:
+            stream.feed_eof()
+        if self._over_ssl:
+            # Prevent a warning in SSLProtocol.eof_received:
+            # "returning true from eof_received()
+            # has no effect when using ssl"
+            return False
+        return True
+
+    def _get_close_waiter(self, stream):
+        return self._closed
+
+    def __del__(self):
+        # Prevent reports about unhandled exceptions.
+        # Better than self._closed._log_traceback = False hack
+        closed = self._get_close_waiter(self._stream)
+        if closed.done() and not closed.cancelled():
+            closed.exception()
+
+
+class _StreamProtocol(_BaseStreamProtocol):
+    _source_traceback = None
+
+    def __init__(self, stream, loop=None,
+                 *, _asyncio_internal=False):
+        super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+        self._source_traceback = stream._source_traceback
+        self._stream_wr = weakref.ref(stream, self._on_gc)
+        self._reject_connection = False
+
+    def _on_gc(self, wr):
+        transport = self._transport
+        if transport is not None:
+            # connection_made was called
+            context = {
+                'message': ('An open stream object is being garbage '
+                            'collected; call "stream.close()" explicitly.')
+            }
+            if self._source_traceback:
+                context['source_traceback'] = self._source_traceback
+            self._loop.call_exception_handler(context)
+            transport.abort()
+        else:
+            self._reject_connection = True
+        self._stream_wr = None
+
+    @property
+    def _stream(self):
+        if self._stream_wr is None:
+            return None
+        return self._stream_wr()
+
+    def connection_made(self, transport):
+        if self._reject_connection:
+            context = {
+                'message': ('An open stream was garbage collected prior to '
+                            'establishing network connection; '
+                            'call "stream.close()" explicitly.')
+            }
+            if self._source_traceback:
+                context['source_traceback'] = self._source_traceback
+            self._loop.call_exception_handler(context)
+            transport.abort()
+            return
+        super().connection_made(transport)
+        stream = self._stream
+        if stream is None:
+            return
+        stream.set_transport(transport)
+        stream._protocol = self
+
+    def connection_lost(self, exc):
+        super().connection_lost(exc)
+        self._stream_wr = None
+
+
+class _ServerStreamProtocol(_BaseStreamProtocol):
+    def __init__(self, server, limit, client_connected_cb, loop=None,
+                 *, _asyncio_internal=False):
+        super().__init__(loop=loop, _asyncio_internal=_asyncio_internal)
+        assert self._closed
+        self._client_connected_cb = client_connected_cb
+        self._limit = limit
+        self._server = server
+        self._task = None
+
+    def connection_made(self, transport):
+        super().connection_made(transport)
+        stream = Stream(mode=StreamMode.READWRITE,
+                        transport=transport,
+                        protocol=self,
+                        limit=self._limit,
+                        loop=self._loop,
+                        is_server_side=True,
+                        _asyncio_internal=True)
+        self._stream = stream
+        # If self._client_connected_cb(self._stream) fails
+        # the exception is logged by transport
+        self._task = self._loop.create_task(
+            self._client_connected_cb(self._stream))
+        self._server._attach(stream, self._task)
+
+    def connection_lost(self, exc):
+        super().connection_lost(exc)
+        self._server._detach(self._stream, self._task)
+        self._stream = None
+
+
+class _OptionalAwait:
+    # The class doesn't create a coroutine object if it is not awaited;
+    # this prevents the "coroutine is never awaited" warning.
+
+    __slots__ = ('_method',)
+
+    def __init__(self, method):
+        self._method = method
+
+    def __await__(self):
+        return self._method().__await__()
+
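
_OptionalAwait is what lets Stream.write() and Stream.close() be either awaited or called fire-and-forget without the "coroutine ... was never awaited" warning; a brief sketch, assuming the stream came from connect():

    async def send(stream):
        stream.write(b'fire and forget\n')        # not awaited: no warning is emitted
        await stream.write(b'flow controlled\n')  # awaited: behaves like write() + drain()
        await stream.close()                      # close() also returns an optional awaitable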
+
+class Stream:
+    """Wraps a Transport.
+
+    This exposes write(), writelines(), [can_]write_eof(),
+    get_extra_info() and close().  It adds drain() which returns an
+    optional Future on which you can wait for flow control.  It also
+    adds a transport property which references the Transport
+    directly.
+    """
+
+    _source_traceback = None
+
+    def __init__(self, mode, *,
+                 transport=None,
+                 protocol=None,
+                 loop=None,
+                 limit=_DEFAULT_LIMIT,
+                 is_server_side=False,
+                 _asyncio_internal=False):
+        if not _asyncio_internal:
+            warnings.warn(f"{self.__class__} should be instantiated "
+                          "by asyncio internals only, "
+                          "please avoid its creation from user code",
+                          DeprecationWarning)
+        self._mode = mode
+        self._transport = transport
+        self._protocol = protocol
+        self._is_server_side = is_server_side
+
+        # The line length limit is a security feature;
+        # it also doubles as half the buffer limit.
+
+        if limit <= 0:
+            raise ValueError('Limit cannot be <= 0')
+
+        self._limit = limit
+        if loop is None:
+            self._loop = events.get_event_loop()
+        else:
+            self._loop = loop
+        self._buffer = bytearray()
+        self._eof = False    # Whether we're done.
+        self._waiter = None  # A future used by _wait_for_data()
+        self._exception = None
+        self._paused = False
+        self._complete_fut = self._loop.create_future()
+        self._complete_fut.set_result(None)
+
+        if self._loop.get_debug():
+            self._source_traceback = format_helpers.extract_stack(
+                sys._getframe(1))
+
+    def __repr__(self):
+        info = [self.__class__.__name__]
+        info.append(f'mode={self._mode}')
+        if self._buffer:
+            info.append(f'{len(self._buffer)} bytes')
+        if self._eof:
+            info.append('eof')
+        if self._limit != _DEFAULT_LIMIT:
+            info.append(f'limit={self._limit}')
+        if self._waiter:
+            info.append(f'waiter={self._waiter!r}')
+        if self._exception:
+            info.append(f'exception={self._exception!r}')
+        if self._transport:
+            info.append(f'transport={self._transport!r}')
+        if self._paused:
+            info.append('paused')
+        return '<{}>'.format(' '.join(info))
+
+    @property
+    def mode(self):
+        return self._mode
+
+    def is_server_side(self):
+        return self._is_server_side
+
+    @property
+    def transport(self):
+        return self._transport
+
+    def write(self, data):
+        _ensure_can_write(self._mode)
+        self._transport.write(data)
+        return self._fast_drain()
+
+    def writelines(self, data):
+        _ensure_can_write(self._mode)
+        self._transport.writelines(data)
+        return self._fast_drain()
+
+    def _fast_drain(self):
+        # The helper tries to use a fast path: return an already existing
+        # completed future object if the underlying transport is not paused
+        # and actually waiting for the write to resume is not needed.
+        exc = self.exception()
+        if exc is not None:
+            fut = self._loop.create_future()
+            fut.set_exception(exc)
+            return fut
+        if not self._transport.is_closing():
+            if self._protocol._connection_lost:
+                fut = self._loop.create_future()
+                fut.set_exception(ConnectionResetError('Connection lost'))
+                return fut
+            if not self._protocol._paused:
+                # fast path, the stream is not paused
+                # no need to wait for resume signal
+                return self._complete_fut
+        return _OptionalAwait(self.drain)
+
+    def write_eof(self):
+        _ensure_can_write(self._mode)
+        return self._transport.write_eof()
+
+    def can_write_eof(self):
+        if not self._mode & StreamMode.WRITE:
+            return False
+        return self._transport.can_write_eof()
+
+    def close(self):
+        self._transport.close()
+        return _OptionalAwait(self.wait_closed)
+
+    def is_closing(self):
+        return self._transport.is_closing()
+
+    async def abort(self):
+        self._transport.abort()
+        await self.wait_closed()
+
+    async def wait_closed(self):
+        await self._protocol._get_close_waiter(self)
+
+    def get_extra_info(self, name, default=None):
+        return self._transport.get_extra_info(name, default)
+
+    async def drain(self):
+        """Flush the write buffer.
+
+        The intended use is to write
+
+          w.write(data)
+          await w.drain()
+        """
+        _ensure_can_write(self._mode)
+        exc = self.exception()
+        if exc is not None:
+            raise exc
+        if self._transport.is_closing():
+            # Wait for protocol.connection_lost() call
+            # Raise connection closing error if any,
+            # ConnectionResetError otherwise
+            await tasks.sleep(0)
+        await self._protocol._drain_helper()
+
+    async def sendfile(self, file, offset=0, count=None, *, fallback=True):
+        await self.drain()  # check for stream mode and exceptions
+        return await self._loop.sendfile(self._transport, file,
+                                         offset, count, fallback=fallback)
+
+    async def start_tls(self, sslcontext, *,
+                        server_hostname=None,
+                        ssl_handshake_timeout=None):
+        await self.drain()  # check for stream mode and exceptions
+        transport = await self._loop.start_tls(
+            self._transport, self._protocol, sslcontext,
+            server_side=self._is_server_side,
+            server_hostname=server_hostname,
+            ssl_handshake_timeout=ssl_handshake_timeout)
+        self._transport = transport
+        self._protocol._transport = transport
+        self._protocol._over_ssl = True
+
+    def exception(self):
+        return self._exception
+
+    def set_exception(self, exc):
+        self._exception = exc
+
+        waiter = self._waiter
+        if waiter is not None:
+            self._waiter = None
+            if not waiter.cancelled():
+                waiter.set_exception(exc)
+
+    def _wakeup_waiter(self):
+        """Wakeup read*() functions waiting for data or EOF."""
+        waiter = self._waiter
+        if waiter is not None:
+            self._waiter = None
+            if not waiter.cancelled():
+                waiter.set_result(None)
+
+    def set_transport(self, transport):
+        if transport is self._transport:
+            return
+        assert self._transport is None, 'Transport already set'
+        self._transport = transport
+
+    def _maybe_resume_transport(self):
+        if self._paused and len(self._buffer) <= self._limit:
+            self._paused = False
+            self._transport.resume_reading()
+
+    def feed_eof(self):
+        self._eof = True
+        self._wakeup_waiter()
+
+    def at_eof(self):
+        """Return True if the buffer is empty and 'feed_eof' was called."""
+        return self._eof and not self._buffer
+
+    def feed_data(self, data):
+        _ensure_can_read(self._mode)
+        assert not self._eof, 'feed_data after feed_eof'
+
+        if not data:
+            return
+
+        self._buffer.extend(data)
+        self._wakeup_waiter()
+
+        if (self._transport is not None and
+                not self._paused and
+                len(self._buffer) > 2 * self._limit):
+            try:
+                self._transport.pause_reading()
+            except NotImplementedError:
+                # The transport can't be paused.
+                # We'll just have to buffer all data.
+                # Forget the transport so we don't keep trying.
+                self._transport = None
+            else:
+                self._paused = True
+
+    async def _wait_for_data(self, func_name):
+        """Wait until feed_data() or feed_eof() is called.
+
+        If stream was paused, automatically resume it.
+        """
+        # Stream uses a future to link the protocol feed_data() method
+        # to a read coroutine. Running two read coroutines at the same time
+        # would have unexpected behaviour: it would not be possible to know
+        # which coroutine would get the next data.
+        if self._waiter is not None:
+            raise RuntimeError(
+                f'{func_name}() called while another coroutine is '
+                f'already waiting for incoming data')
+
+        assert not self._eof, '_wait_for_data after EOF'
+
+        # Waiting for data while paused will deadlock, so prevent it.
+        # This is essential for readexactly(n) for case when n > self._limit.
+        if self._paused:
+            self._paused = False
+            self._transport.resume_reading()
+
+        self._waiter = self._loop.create_future()
+        try:
+            await self._waiter
+        finally:
+            self._waiter = None
+
+    async def readline(self):
+        """Read a chunk of data from the stream until a newline (b'\n') is found.
+
+        On success, return the chunk that ends with the newline. If only a
+        partial line can be read due to EOF, return the incomplete line
+        without the terminating newline. When EOF is reached and no bytes
+        were read, an empty bytes object is returned.
+
+        If the limit is reached, ValueError will be raised. In that case, if
+        a newline was found, the complete line including the newline will be
+        removed from the internal buffer. Otherwise, the internal buffer will
+        be cleared. The limit is compared against the part of the line
+        without the newline.
+
+        If stream was paused, this function will automatically resume it if
+        needed.
+        """
+        _ensure_can_read(self._mode)
+        sep = b'\n'
+        seplen = len(sep)
+        try:
+            line = await self.readuntil(sep)
+        except exceptions.IncompleteReadError as e:
+            return e.partial
+        except exceptions.LimitOverrunError as e:
+            if self._buffer.startswith(sep, e.consumed):
+                del self._buffer[:e.consumed + seplen]
+            else:
+                self._buffer.clear()
+            self._maybe_resume_transport()
+            raise ValueError(e.args[0])
+        return line
+
+    async def readuntil(self, separator=b'\n'):
+        """Read data from the stream until ``separator`` is found.
+
+        On success, the data and separator will be removed from the
+        internal buffer (consumed). Returned data will include the
+        separator at the end.
+
+        The configured stream limit is used to check the result. The limit
+        sets the maximal length of data that can be returned, not counting
+        the separator.
+
+        If an EOF occurs and the complete separator is still not found,
+        an IncompleteReadError exception will be raised, and the internal
+        buffer will be reset.  The IncompleteReadError.partial attribute
+        may contain a portion of the separator.
+
+        If the data cannot be read because the limit is exceeded, a
+        LimitOverrunError exception will be raised, and the data
+        will be left in the internal buffer, so it can be read again.
+        """
+        _ensure_can_read(self._mode)
+        seplen = len(separator)
+        if seplen == 0:
+            raise ValueError('Separator should be at least one-byte string')
+
+        if self._exception is not None:
+            raise self._exception
+
+        # Consume the whole buffer except the last bytes, whose length is
+        # one less than seplen. Let's check corner cases with
+        # separator='SEPARATOR':
+        # * we have received an almost complete separator (without the last
+        #   byte), i.e. buffer='some textSEPARATO'. In this case we
+        #   can safely consume len(separator) - 1 bytes.
+        # * the last byte of the buffer is the first byte of the separator,
+        #   i.e. buffer='abcdefghijklmnopqrS'. We may safely consume
+        #   everything except that last byte, but this requires analyzing
+        #   the bytes of the buffer that match a partial separator.
+        #   This is slow and/or requires an FSM. For this case our
+        #   implementation is not optimal, since it requires rescanning
+        #   data that is known not to belong to the separator. In the
+        #   real world, the separator will not be long enough for this to
+        #   cause noticeable performance problems, even when reading
+        #   MIME-encoded messages :)
+
+        # `offset` is the number of bytes from the beginning of the buffer
+        # where there is no occurrence of `separator`.
+        offset = 0
+
+        # Loop until we find `separator` in the buffer, exceed the buffer size,
+        # or an EOF has happened.
+        while True:
+            buflen = len(self._buffer)
+
+            # Check if we now have enough data in the buffer for `separator` to
+            # fit.
+            if buflen - offset >= seplen:
+                isep = self._buffer.find(separator, offset)
+
+                if isep != -1:
+                    # `separator` is in the buffer. `isep` will be used later
+                    # to retrieve the data.
+                    break
+
+                # see the comment above for an explanation.
+                offset = buflen + 1 - seplen
+                if offset > self._limit:
+                    raise exceptions.LimitOverrunError(
+                        'Separator is not found, and chunk exceed the limit',
+                        offset)
+
+            # Complete message (with full separator) may be present in buffer
+            # even when EOF flag is set. This may happen when the last chunk
+            # adds data that completes the separator. That's why we check
+            # for EOF *after* inspecting the buffer.
+            if self._eof:
+                chunk = bytes(self._buffer)
+                self._buffer.clear()
+                raise exceptions.IncompleteReadError(chunk, None)
+
+            # _wait_for_data() will resume reading if stream was paused.
+            await self._wait_for_data('readuntil')
+
+        if isep > self._limit:
+            raise exceptions.LimitOverrunError(
+                'Separator is found, but chunk is longer than limit', isep)
+
+        chunk = self._buffer[:isep + seplen]
+        del self._buffer[:isep + seplen]
+        self._maybe_resume_transport()
+        return bytes(chunk)
+
+    async def read(self, n=-1):
+        """Read up to `n` bytes from the stream.
+
+        If n is not provided, or set to -1, read until EOF and return all read
+        bytes. If the EOF was received and the internal buffer is empty, return
+        an empty bytes object.
+
+        If n is zero, return an empty bytes object immediately.
+
+        If n is positive, this function tries to read `n` bytes, and may
+        return fewer bytes than requested, but at least one byte. If EOF
+        is received before any byte is read, this function returns an
+        empty bytes object.
+
+        The returned value is not limited by the limit configured at
+        stream creation.
+
+        If stream was paused, this function will automatically resume it if
+        needed.
+        """
+        _ensure_can_read(self._mode)
+
+        if self._exception is not None:
+            raise self._exception
+
+        if n == 0:
+            return b''
+
+        if n < 0:
+            # This used to just loop creating a new waiter hoping to
+            # collect everything in self._buffer, but that would
+            # deadlock if the subprocess sends more than self.limit
+            # bytes.  So just call self.read(self._limit) until EOF.
+            blocks = []
+            while True:
+                block = await self.read(self._limit)
+                if not block:
+                    break
+                blocks.append(block)
+            return b''.join(blocks)
+
+        if not self._buffer and not self._eof:
+            await self._wait_for_data('read')
+
+        # This will work right even if buffer is less than n bytes
+        data = bytes(self._buffer[:n])
+        del self._buffer[:n]
+
+        self._maybe_resume_transport()
+        return data
+
+    async def readexactly(self, n):
+        """Read exactly `n` bytes.
+
+        Raise an IncompleteReadError if EOF is reached before `n` bytes can be
+        read. The IncompleteReadError.partial attribute of the exception will
+        contain the partially read bytes.
+
+        If n is zero, return an empty bytes object.
+
+        The returned value is not limited by the limit configured at
+        stream creation.
+
+        If stream was paused, this function will automatically resume it if
+        needed.
+        """
+        _ensure_can_read(self._mode)
+        if n < 0:
+            raise ValueError('readexactly size can not be less than zero')
+
+        if self._exception is not None:
+            raise self._exception
+
+        if n == 0:
+            return b''
+
+        while len(self._buffer) < n:
+            if self._eof:
+                incomplete = bytes(self._buffer)
+                self._buffer.clear()
+                raise exceptions.IncompleteReadError(incomplete, n)
+
+            await self._wait_for_data('readexactly')
+
+        if len(self._buffer) == n:
+            data = bytes(self._buffer)
+            self._buffer.clear()
+        else:
+            data = bytes(self._buffer[:n])
+            del self._buffer[:n]
+        self._maybe_resume_transport()
+        return data
+
+    def __aiter__(self):
+        _ensure_can_read(self._mode)
+        return self
+
+    async def __anext__(self):
+        val = await self.readline()
+        if val == b'':
+            raise StopAsyncIteration
+        return val
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
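
The reading side added above (readuntil(), read(), and the __aiter__/__anext__ hooks that yield lines) is easiest to see in use. A minimal sketch, assuming the asyncio.connect() helper introduced elsewhere in this patch; the fetch()/iter_lines() names and the HTTP exchange are illustrative only:

    import asyncio

    async def fetch(host, port):
        # asyncio.connect() yields a single bidirectional Stream instead of a
        # (reader, writer) pair; __aexit__ awaits close() when the block ends.
        async with asyncio.connect(host, port) as stream:
            await stream.write(b'GET / HTTP/1.0\r\n\r\n')
            status = await stream.readline()               # first line, b'\n'-terminated
            headers = await stream.readuntil(b'\r\n\r\n')  # separator included in the result
            body = await stream.read()                     # n=-1: read until EOF
            return status, headers, body

    async def iter_lines(host, port):
        # __aiter__/__anext__ make a read-capable Stream usable with "async for";
        # iteration stops once readline() returns b'' at EOF.
        async with asyncio.connect(host, port) as stream:
            async for line in stream:
                print(line)

    # asyncio.run(fetch('127.0.0.1', 8080))
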
index d34b6118fdcf72b30dd51edbbd7482810156b83a..e6bec71d6c7dac3905e1841e218b7f13dee23cd8 100644 (file)
@@ -27,6 +27,8 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
         self._process_exited = False
         self._pipe_fds = []
         self._stdin_closed = self._loop.create_future()
+        self._stdout_closed = self._loop.create_future()
+        self._stderr_closed = self._loop.create_future()
 
     def __repr__(self):
         info = [self.__class__.__name__]
@@ -40,30 +42,35 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
 
     def connection_made(self, transport):
         self._transport = transport
-
         stdout_transport = transport.get_pipe_transport(1)
         if stdout_transport is not None:
-            self.stdout = streams.StreamReader(limit=self._limit,
-                                               loop=self._loop,
-                                               _asyncio_internal=True)
+            self.stdout = streams.Stream(mode=streams.StreamMode.READ,
+                                         transport=stdout_transport,
+                                         protocol=self,
+                                         limit=self._limit,
+                                         loop=self._loop,
+                                         _asyncio_internal=True)
             self.stdout.set_transport(stdout_transport)
             self._pipe_fds.append(1)
 
         stderr_transport = transport.get_pipe_transport(2)
         if stderr_transport is not None:
-            self.stderr = streams.StreamReader(limit=self._limit,
-                                               loop=self._loop,
-                                               _asyncio_internal=True)
+            self.stderr = streams.Stream(mode=streams.StreamMode.READ,
+                                         transport=stderr_transport,
+                                         protocol=self,
+                                         limit=self._limit,
+                                         loop=self._loop,
+                                         _asyncio_internal=True)
             self.stderr.set_transport(stderr_transport)
             self._pipe_fds.append(2)
 
         stdin_transport = transport.get_pipe_transport(0)
         if stdin_transport is not None:
-            self.stdin = streams.StreamWriter(stdin_transport,
-                                              protocol=self,
-                                              reader=None,
-                                              loop=self._loop,
-                                              _asyncio_internal=True)
+            self.stdin = streams.Stream(mode=streams.StreamMode.WRITE,
+                                        transport=stdin_transport,
+                                        protocol=self,
+                                        loop=self._loop,
+                                        _asyncio_internal=True)
 
     def pipe_data_received(self, fd, data):
         if fd == 1:
@@ -114,6 +121,10 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
     def _get_close_waiter(self, stream):
         if stream is self.stdin:
             return self._stdin_closed
+        elif stream is self.stdout:
+            return self._stdout_closed
+        elif stream is self.stderr:
+            return self._stderr_closed
 
 
 class Process:
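
SubprocessStreamProtocol now hands out a READ-mode Stream for stdout/stderr and a WRITE-mode Stream for stdin instead of StreamReader/StreamWriter objects. A minimal sketch of what that looks like from the caller's side, assuming the high-level create_subprocess_exec() API is otherwise unchanged; the child command is illustrative only:

    import asyncio
    import sys

    async def run_child():
        proc = await asyncio.create_subprocess_exec(
            sys.executable, '-c', 'print(input().upper())',
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE)
        # proc.stdin is a WRITE-mode Stream; proc.stdout a READ-mode Stream.
        proc.stdin.write(b'hello\n')
        await proc.stdin.drain()
        line = await proc.stdout.readline()
        await proc.wait()
        return line  # b'HELLO\n' (platform line endings may differ)

    # asyncio.run(run_child())
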
index b5b2e24c5ba4f3427a14b4067145d260bba73c80..61b40ba52a6486d76aebb3a66d9b9528e0fc06ee 100644 (file)
@@ -607,7 +607,7 @@ class IocpProactor:
 
             # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later
             delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY)
-            await tasks.sleep(delay, loop=self._loop)
+            await tasks.sleep(delay)
 
         return windows_utils.PipeHandle(handle)
 
index f6e82eb64ab02545c10b21c96325868db84d5357..c077881511b8ce3b5128f5d7774fd294d11d1be6 100644 (file)
@@ -30,21 +30,27 @@ class AllTest(unittest.TestCase):
             raise NoAll(modname)
         names = {}
         with self.subTest(module=modname):
-            try:
-                exec("from %s import *" % modname, names)
-            except Exception as e:
-                # Include the module name in the exception string
-                self.fail("__all__ failure in {}: {}: {}".format(
-                          modname, e.__class__.__name__, e))
-            if "__builtins__" in names:
-                del names["__builtins__"]
-            if '__annotations__' in names:
-                del names['__annotations__']
-            keys = set(names)
-            all_list = sys.modules[modname].__all__
-            all_set = set(all_list)
-            self.assertCountEqual(all_set, all_list, "in module {}".format(modname))
-            self.assertEqual(keys, all_set, "in module {}".format(modname))
+            with support.check_warnings(
+                ("", DeprecationWarning),
+                ("", ResourceWarning),
+                quiet=True):
+                try:
+                    exec("from %s import *" % modname, names)
+                except Exception as e:
+                    # Include the module name in the exception string
+                    self.fail("__all__ failure in {}: {}: {}".format(
+                              modname, e.__class__.__name__, e))
+                if "__builtins__" in names:
+                    del names["__builtins__"]
+                if '__annotations__' in names:
+                    del names['__annotations__']
+                if "__warningregistry__" in names:
+                    del names["__warningregistry__"]
+                keys = set(names)
+                all_list = sys.modules[modname].__all__
+                all_set = set(all_list)
+                self.assertCountEqual(all_set, all_list, "in module {}".format(modname))
+                self.assertEqual(keys, all_set, "in module {}".format(modname))
 
     def walk_modules(self, basedir, modpath):
         for fn in sorted(os.listdir(basedir)):
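
The star import above is now wrapped in support.check_warnings() so that DeprecationWarning and ResourceWarning raised while a module's names are resolved do not fail the test. A standalone sketch of the same suppression using only the stdlib warnings module, with asyncio as an arbitrary example target:

    import warnings

    names = {}
    with warnings.catch_warnings():
        # Tolerate warnings raised while the module's names are resolved,
        # mirroring support.check_warnings(..., quiet=True) in the test above.
        warnings.simplefilter('ignore', DeprecationWarning)
        warnings.simplefilter('ignore', ResourceWarning)
        exec('from asyncio import *', names)
    names.pop('__builtins__', None)
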
index 31018c5c563637b29b8e233fe4a16c917d1fa07b..02a97c60ac1a93fc4ca5afd227a9a79c77daf1bc 100644 (file)
@@ -1152,8 +1152,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
     @unittest.skipUnless(hasattr(socket, 'AF_INET6'), 'no IPv6 support')
     def test_create_server_ipv6(self):
         async def main():
-            srv = await asyncio.start_server(
-                lambda: None, '::1', 0, loop=self.loop)
+            with self.assertWarns(DeprecationWarning):
+                srv = await asyncio.start_server(
+                    lambda: None, '::1', 0, loop=self.loop)
             try:
                 self.assertGreater(len(srv.sockets), 0)
             finally:
index f24e363ebfcfa326c6e37344f5e33ba36ba4b309..b1531fb9343f5eca04d0881b5390f6e20b9ce713 100644 (file)
@@ -58,9 +58,10 @@ class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
             writer.close()
             await writer.wait_closed()
 
-        srv = self.loop.run_until_complete(
-            asyncio.start_server(
-                on_server_client, '127.0.0.1', 0))
+        with self.assertWarns(DeprecationWarning):
+            srv = self.loop.run_until_complete(
+                asyncio.start_server(
+                    on_server_client, '127.0.0.1', 0))
 
         addr = srv.sockets[0].getsockname()
         self.loop.run_until_complete(
index 297a3b3901d631462b14bdd6abc5d783d2089fd6..11c0ce495d5261930ca46aefc31cd0ff6f219511 100644 (file)
@@ -94,7 +94,9 @@ class StreamReaderTests(BaseTest):
     def test_readline(self):
         DATA = b'line1\nline2\nline3'
 
-        stream = asyncio.StreamReader(loop=self.loop, _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(DATA)
         stream.feed_eof()
 
index 4e758ad12e600ec22369ecf2e3ba9e0fa9e208b9..0e38e6c8ecd4c24781e073862dc89d5e18f27123 100644 (file)
@@ -46,8 +46,9 @@ class BaseStartServer(func_tests.FunctionalTestCaseMixin):
             async with srv:
                 await srv.serve_forever()
 
-        srv = self.loop.run_until_complete(asyncio.start_server(
-            serve, support.HOSTv4, 0, loop=self.loop, start_serving=False))
+        with self.assertWarns(DeprecationWarning):
+            srv = self.loop.run_until_complete(asyncio.start_server(
+                serve, support.HOSTv4, 0, loop=self.loop, start_serving=False))
 
         self.assertFalse(srv.is_serving())
 
@@ -102,8 +103,9 @@ class SelectorStartServerTests(BaseStartServer, unittest.TestCase):
                 await srv.serve_forever()
 
         with test_utils.unix_socket_path() as addr:
-            srv = self.loop.run_until_complete(asyncio.start_unix_server(
-                serve, addr, loop=self.loop, start_serving=False))
+            with self.assertWarns(DeprecationWarning):
+                srv = self.loop.run_until_complete(asyncio.start_unix_server(
+                    serve, addr, loop=self.loop, start_serving=False))
 
             main_task = self.loop.create_task(main(srv))
 
index 079b25585566b1f3cf0b73b2195f7679bc050f0e..4215abf5d8630bbea71008d2beedc86aea6ebae0 100644 (file)
@@ -649,12 +649,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
                 sock.close()
 
         async def client(addr):
-            reader, writer = await asyncio.open_connection(
-                *addr,
-                ssl=client_sslctx,
-                server_hostname='',
-                loop=self.loop,
-                ssl_handshake_timeout=1.0)
+            with self.assertWarns(DeprecationWarning):
+                reader, writer = await asyncio.open_connection(
+                    *addr,
+                    ssl=client_sslctx,
+                    server_hostname='',
+                    loop=self.loop,
+                    ssl_handshake_timeout=1.0)
 
         with self.tcp_server(server,
                              max_clients=1,
@@ -688,12 +689,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
                 sock.close()
 
         async def client(addr):
-            reader, writer = await asyncio.open_connection(
-                *addr,
-                ssl=client_sslctx,
-                server_hostname='',
-                loop=self.loop,
-                ssl_handshake_timeout=1.0)
+            with self.assertWarns(DeprecationWarning):
+                reader, writer = await asyncio.open_connection(
+                    *addr,
+                    ssl=client_sslctx,
+                    server_hostname='',
+                    loop=self.loop,
+                    ssl_handshake_timeout=1.0)
 
         with self.tcp_server(server,
                              max_clients=1,
@@ -724,11 +726,12 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
                 sock.close()
 
         async def client(addr):
-            reader, writer = await asyncio.open_connection(
-                *addr,
-                ssl=client_sslctx,
-                server_hostname='',
-                loop=self.loop)
+            with self.assertWarns(DeprecationWarning):
+                reader, writer = await asyncio.open_connection(
+                    *addr,
+                    ssl=client_sslctx,
+                    server_hostname='',
+                    loop=self.loop)
 
             self.assertEqual(await reader.readline(), b'A\n')
             writer.write(b'B')
index fed609816daca8f3c15fc8355ad14078725f15e8..df3d7e7dfa455ce280204cb4dab3b25c6ff2b71e 100644 (file)
@@ -1,6 +1,8 @@
 """Tests for streams.py."""
 
+import contextlib
 import gc
+import io
 import os
 import queue
 import pickle
@@ -16,6 +18,7 @@ except ImportError:
     ssl = None
 
 import asyncio
+from asyncio.streams import _StreamProtocol, _ensure_can_read, _ensure_can_write
 from test.test_asyncio import utils as test_utils
 
 
@@ -23,6 +26,24 @@ def tearDownModule():
     asyncio.set_event_loop_policy(None)
 
 
+class StreamModeTests(unittest.TestCase):
+    def test__ensure_can_read_ok(self):
+        self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READ))
+        self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READWRITE))
+
+    def test__ensure_can_read_fail(self):
+        with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+            _ensure_can_read(asyncio.StreamMode.WRITE)
+
+    def test__ensure_can_write_ok(self):
+        self.assertIsNone(_ensure_can_write(asyncio.StreamMode.WRITE))
+        self.assertIsNone(_ensure_can_write(asyncio.StreamMode.READWRITE))
+
+    def test__ensure_can_write_fail(self):
+        with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+            _ensure_can_write(asyncio.StreamMode.READ)
+
+
 class StreamTests(test_utils.TestCase):
 
     DATA = b'line1\nline2\nline3\n'
@@ -42,13 +63,15 @@ class StreamTests(test_utils.TestCase):
 
     @mock.patch('asyncio.streams.events')
     def test_ctor_global_loop(self, m_events):
-        stream = asyncio.StreamReader(_asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                _asyncio_internal=True)
         self.assertIs(stream._loop, m_events.get_event_loop.return_value)
 
     def _basetest_open_connection(self, open_connection_fut):
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-        reader, writer = self.loop.run_until_complete(open_connection_fut)
+        with self.assertWarns(DeprecationWarning):
+            reader, writer = self.loop.run_until_complete(open_connection_fut)
         writer.write(b'GET / HTTP/1.0\r\n\r\n')
         f = reader.readline()
         data = self.loop.run_until_complete(f)
@@ -76,7 +99,9 @@ class StreamTests(test_utils.TestCase):
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
         try:
-            reader, writer = self.loop.run_until_complete(open_connection_fut)
+            with self.assertWarns(DeprecationWarning):
+                reader, writer = self.loop.run_until_complete(
+                    open_connection_fut)
         finally:
             asyncio.set_event_loop(None)
         writer.write(b'GET / HTTP/1.0\r\n\r\n')
@@ -112,7 +137,8 @@ class StreamTests(test_utils.TestCase):
     def _basetest_open_connection_error(self, open_connection_fut):
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
-        reader, writer = self.loop.run_until_complete(open_connection_fut)
+        with self.assertWarns(DeprecationWarning):
+            reader, writer = self.loop.run_until_complete(open_connection_fut)
         writer._protocol.connection_lost(ZeroDivisionError())
         f = reader.read()
         with self.assertRaises(ZeroDivisionError):
@@ -135,23 +161,26 @@ class StreamTests(test_utils.TestCase):
             self._basetest_open_connection_error(conn_fut)
 
     def test_feed_empty_data(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
 
         stream.feed_data(b'')
         self.assertEqual(b'', stream._buffer)
 
     def test_feed_nonempty_data(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
 
         stream.feed_data(self.DATA)
         self.assertEqual(self.DATA, stream._buffer)
 
     def test_read_zero(self):
         # Read zero bytes.
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(self.DATA)
 
         data = self.loop.run_until_complete(stream.read(0))
@@ -160,8 +189,9 @@ class StreamTests(test_utils.TestCase):
 
     def test_read(self):
         # Read bytes.
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         read_task = asyncio.Task(stream.read(30), loop=self.loop)
 
         def cb():
@@ -174,8 +204,9 @@ class StreamTests(test_utils.TestCase):
 
     def test_read_line_breaks(self):
         # Read bytes without line breaks.
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'line1')
         stream.feed_data(b'line2')
 
@@ -186,8 +217,9 @@ class StreamTests(test_utils.TestCase):
 
     def test_read_eof(self):
         # Read bytes, stop at eof.
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         read_task = asyncio.Task(stream.read(1024), loop=self.loop)
 
         def cb():
@@ -200,8 +232,9 @@ class StreamTests(test_utils.TestCase):
 
     def test_read_until_eof(self):
         # Read all bytes until eof.
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         read_task = asyncio.Task(stream.read(-1), loop=self.loop)
 
         def cb():
@@ -216,8 +249,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'', stream._buffer)
 
     def test_read_exception(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'line\n')
 
         data = self.loop.run_until_complete(stream.read(2))
@@ -229,16 +263,19 @@ class StreamTests(test_utils.TestCase):
 
     def test_invalid_limit(self):
         with self.assertRaisesRegex(ValueError, 'imit'):
-            asyncio.StreamReader(limit=0, loop=self.loop,
-                                 _asyncio_internal=True)
+            asyncio.Stream(mode=asyncio.StreamMode.READ,
+                           limit=0, loop=self.loop,
+                           _asyncio_internal=True)
 
         with self.assertRaisesRegex(ValueError, 'imit'):
-            asyncio.StreamReader(limit=-1, loop=self.loop,
-                                 _asyncio_internal=True)
+            asyncio.Stream(mode=asyncio.StreamMode.READ,
+                           limit=-1, loop=self.loop,
+                           _asyncio_internal=True)
 
     def test_read_limit(self):
-        stream = asyncio.StreamReader(limit=3, loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                limit=3, loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'chunk')
         data = self.loop.run_until_complete(stream.read(5))
         self.assertEqual(b'chunk', data)
@@ -247,8 +284,9 @@ class StreamTests(test_utils.TestCase):
     def test_readline(self):
         # Read one line. 'readline' will need to wait for the data
         # to come from 'cb'
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'chunk1 ')
         read_task = asyncio.Task(stream.readline(), loop=self.loop)
 
@@ -263,11 +301,12 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b' chunk4', stream._buffer)
 
     def test_readline_limit_with_existing_data(self):
-        # Read one line. The data is in StreamReader's buffer
+        # Read one line. The data is in Stream's buffer
         # before the event loop is run.
 
-        stream = asyncio.StreamReader(limit=3, loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                limit=3, loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'li')
         stream.feed_data(b'ne1\nline2\n')
 
@@ -276,8 +315,9 @@ class StreamTests(test_utils.TestCase):
         # The buffer should contain the remaining data after exception
         self.assertEqual(b'line2\n', stream._buffer)
 
-        stream = asyncio.StreamReader(limit=3, loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                limit=3, loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'li')
         stream.feed_data(b'ne1')
         stream.feed_data(b'li')
@@ -292,8 +332,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'', stream._buffer)
 
     def test_at_eof(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         self.assertFalse(stream.at_eof())
 
         stream.feed_data(b'some data\n')
@@ -308,11 +349,12 @@ class StreamTests(test_utils.TestCase):
         self.assertTrue(stream.at_eof())
 
     def test_readline_limit(self):
-        # Read one line. StreamReaders are fed with data after
+        # Read one line. Streams are fed with data after
         # their 'readline' methods are called.
 
-        stream = asyncio.StreamReader(limit=7, loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                limit=7, loop=self.loop,
+                                _asyncio_internal=True)
         def cb():
             stream.feed_data(b'chunk1')
             stream.feed_data(b'chunk2')
@@ -326,8 +368,9 @@ class StreamTests(test_utils.TestCase):
         # a ValueError it should be empty.
         self.assertEqual(b'', stream._buffer)
 
-        stream = asyncio.StreamReader(limit=7, loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                limit=7, loop=self.loop,
+                                _asyncio_internal=True)
         def cb():
             stream.feed_data(b'chunk1')
             stream.feed_data(b'chunk2\n')
@@ -340,8 +383,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'chunk3\n', stream._buffer)
 
         # check strictness of the limit
-        stream = asyncio.StreamReader(limit=7, loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                limit=7, loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'1234567\n')
         line = self.loop.run_until_complete(stream.readline())
         self.assertEqual(b'1234567\n', line)
@@ -360,8 +404,9 @@ class StreamTests(test_utils.TestCase):
     def test_readline_nolimit_nowait(self):
         # All needed data for the first 'readline' call will be
         # in the buffer.
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(self.DATA[:6])
         stream.feed_data(self.DATA[6:])
 
@@ -371,8 +416,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'line2\nline3\n', stream._buffer)
 
     def test_readline_eof(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'some data')
         stream.feed_eof()
 
@@ -380,16 +426,18 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'some data', line)
 
     def test_readline_empty_eof(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_eof()
 
         line = self.loop.run_until_complete(stream.readline())
         self.assertEqual(b'', line)
 
     def test_readline_read_byte_count(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(self.DATA)
 
         self.loop.run_until_complete(stream.readline())
@@ -400,8 +448,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'ine3\n', stream._buffer)
 
     def test_readline_exception(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'line\n')
 
         data = self.loop.run_until_complete(stream.readline())
@@ -413,14 +462,16 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'', stream._buffer)
 
     def test_readuntil_separator(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         with self.assertRaisesRegex(ValueError, 'Separator should be'):
             self.loop.run_until_complete(stream.readuntil(separator=b''))
 
     def test_readuntil_multi_chunks(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
 
         stream.feed_data(b'lineAAA')
         data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
@@ -438,8 +489,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'xxx', stream._buffer)
 
     def test_readuntil_multi_chunks_1(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
 
         stream.feed_data(b'QWEaa')
         stream.feed_data(b'XYaa')
@@ -474,8 +526,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'', stream._buffer)
 
     def test_readuntil_eof(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'some dataAA')
         stream.feed_eof()
 
@@ -486,8 +539,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'', stream._buffer)
 
     def test_readuntil_limit_found_sep(self):
-        stream = asyncio.StreamReader(loop=self.loop, limit=3,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop, limit=3,
+                                _asyncio_internal=True)
         stream.feed_data(b'some dataAA')
 
         with self.assertRaisesRegex(asyncio.LimitOverrunError,
@@ -505,8 +559,9 @@ class StreamTests(test_utils.TestCase):
 
     def test_readexactly_zero_or_less(self):
         # Read exact number of bytes (zero or less).
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(self.DATA)
 
         data = self.loop.run_until_complete(stream.readexactly(0))
@@ -519,8 +574,9 @@ class StreamTests(test_utils.TestCase):
 
     def test_readexactly(self):
         # Read exact number of bytes.
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
 
         n = 2 * len(self.DATA)
         read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
@@ -536,8 +592,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(self.DATA, stream._buffer)
 
     def test_readexactly_limit(self):
-        stream = asyncio.StreamReader(limit=3, loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                limit=3, loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'chunk')
         data = self.loop.run_until_complete(stream.readexactly(5))
         self.assertEqual(b'chunk', data)
@@ -545,8 +602,9 @@ class StreamTests(test_utils.TestCase):
 
     def test_readexactly_eof(self):
         # Read exact number of bytes (eof).
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         n = 2 * len(self.DATA)
         read_task = asyncio.Task(stream.readexactly(n), loop=self.loop)
 
@@ -564,8 +622,9 @@ class StreamTests(test_utils.TestCase):
         self.assertEqual(b'', stream._buffer)
 
     def test_readexactly_exception(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'line\n')
 
         data = self.loop.run_until_complete(stream.readexactly(2))
@@ -576,8 +635,9 @@ class StreamTests(test_utils.TestCase):
             ValueError, self.loop.run_until_complete, stream.readexactly(2))
 
     def test_exception(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         self.assertIsNone(stream.exception())
 
         exc = ValueError()
@@ -585,8 +645,9 @@ class StreamTests(test_utils.TestCase):
         self.assertIs(stream.exception(), exc)
 
     def test_exception_waiter(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
 
         async def set_err():
             stream.set_exception(ValueError())
@@ -599,8 +660,9 @@ class StreamTests(test_utils.TestCase):
         self.assertRaises(ValueError, t1.result)
 
     def test_exception_cancel(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
 
         t = asyncio.Task(stream.readline(), loop=self.loop)
         test_utils.run_briefly(self.loop)
@@ -655,8 +717,9 @@ class StreamTests(test_utils.TestCase):
                     self.server = None
 
         async def client(addr):
-            reader, writer = await asyncio.open_connection(
-                *addr, loop=self.loop)
+            with self.assertWarns(DeprecationWarning):
+                reader, writer = await asyncio.open_connection(
+                    *addr, loop=self.loop)
             # send a line
             writer.write(b"hello world!\n")
             # read it back
@@ -670,7 +733,8 @@ class StreamTests(test_utils.TestCase):
 
         # test the server variant with a coroutine as client handler
         server = MyServer(self.loop)
-        addr = server.start()
+        with self.assertWarns(DeprecationWarning):
+            addr = server.start()
         msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                         loop=self.loop))
         server.stop()
@@ -678,7 +742,8 @@ class StreamTests(test_utils.TestCase):
 
         # test the server variant with a callback as client handler
         server = MyServer(self.loop)
-        addr = server.start_callback()
+        with self.assertWarns(DeprecationWarning):
+            addr = server.start_callback()
         msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                         loop=self.loop))
         server.stop()
@@ -726,8 +791,9 @@ class StreamTests(test_utils.TestCase):
                     self.server = None
 
         async def client(path):
-            reader, writer = await asyncio.open_unix_connection(
-                path, loop=self.loop)
+            with self.assertWarns(DeprecationWarning):
+                reader, writer = await asyncio.open_unix_connection(
+                    path, loop=self.loop)
             # send a line
             writer.write(b"hello world!\n")
             # read it back
@@ -742,7 +808,8 @@ class StreamTests(test_utils.TestCase):
         # test the server variant with a coroutine as client handler
         with test_utils.unix_socket_path() as path:
             server = MyServer(self.loop, path)
-            server.start()
+            with self.assertWarns(DeprecationWarning):
+                server.start()
             msg = self.loop.run_until_complete(asyncio.Task(client(path),
                                                             loop=self.loop))
             server.stop()
@@ -751,7 +818,8 @@ class StreamTests(test_utils.TestCase):
         # test the server variant with a callback as client handler
         with test_utils.unix_socket_path() as path:
             server = MyServer(self.loop, path)
-            server.start_callback()
+            with self.assertWarns(DeprecationWarning):
+                server.start_callback()
             msg = self.loop.run_until_complete(asyncio.Task(client(path),
                                                             loop=self.loop))
             server.stop()
@@ -763,7 +831,7 @@ class StreamTests(test_utils.TestCase):
     def test_read_all_from_pipe_reader(self):
         # See asyncio issue 168.  This test is derived from the example
         # subprocess_attach_read_pipe.py, but we configure the
-        # StreamReader's limit so that twice it is less than the size
+        # Stream's limit so that twice it is less than the size
         # of the data writter.  Also we must explicitly attach a child
         # watcher to the event loop.
 
@@ -777,10 +845,11 @@ os.close(fd)
         args = [sys.executable, '-c', code, str(wfd)]
 
         pipe = open(rfd, 'rb', 0)
-        reader = asyncio.StreamReader(loop=self.loop, limit=1,
-                                      _asyncio_internal=True)
-        protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop,
-                                                _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop, limit=1,
+                                _asyncio_internal=True)
+        protocol = _StreamProtocol(stream, loop=self.loop,
+                                   _asyncio_internal=True)
         transport, _ = self.loop.run_until_complete(
             self.loop.connect_read_pipe(lambda: protocol, pipe))
 
@@ -797,29 +866,30 @@ os.close(fd)
             asyncio.set_child_watcher(None)
 
         os.close(wfd)
-        data = self.loop.run_until_complete(reader.read(-1))
+        data = self.loop.run_until_complete(stream.read(-1))
         self.assertEqual(data, b'data')
 
     def test_streamreader_constructor(self):
         self.addCleanup(asyncio.set_event_loop, None)
         asyncio.set_event_loop(self.loop)
 
-        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
+        # asyncio issue #184: Ensure that _StreamProtocol constructor
         # retrieves the current loop if the loop parameter is not set
-        reader = asyncio.StreamReader(_asyncio_internal=True)
+        reader = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                _asyncio_internal=True)
         self.assertIs(reader._loop, self.loop)
 
     def test_streamreaderprotocol_constructor(self):
         self.addCleanup(asyncio.set_event_loop, None)
         asyncio.set_event_loop(self.loop)
 
-        # asyncio issue #184: Ensure that StreamReaderProtocol constructor
+        # asyncio issue #184: Ensure that _StreamProtocol constructor
         # retrieves the current loop if the loop parameter is not set
-        reader = mock.Mock()
-        protocol = asyncio.StreamReaderProtocol(reader, _asyncio_internal=True)
+        stream = mock.Mock()
+        protocol = _StreamProtocol(stream, _asyncio_internal=True)
         self.assertIs(protocol._loop, self.loop)
 
-    def test_drain_raises(self):
+    def test_drain_raises_deprecated(self):
         # See http://bugs.python.org/issue25441
 
         # This test should not use asyncio for the mock server; the
@@ -833,15 +903,16 @@ os.close(fd)
 
         def server():
             # Runs in a separate thread.
-            with socket.create_server(('localhost', 0)) as sock:
+            with socket.create_server(('127.0.0.1', 0)) as sock:
                 addr = sock.getsockname()
                 q.put(addr)
                 clt, _ = sock.accept()
                 clt.close()
 
         async def client(host, port):
-            reader, writer = await asyncio.open_connection(
-                host, port, loop=self.loop)
+            with self.assertWarns(DeprecationWarning):
+                reader, writer = await asyncio.open_connection(
+                    host, port, loop=self.loop)
 
             while True:
                 writer.write(b"foo\n")
@@ -863,55 +934,106 @@ os.close(fd)
         thread.join()
         self.assertEqual([], messages)
 
+    def test_drain_raises(self):
+        # See http://bugs.python.org/issue25441
+
+        # This test should not use asyncio for the mock server; the
+        # whole point of the test is to test for a bug in drain()
+        # where it never gives up the event loop but the socket is
+        # closed on the server side.
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        q = queue.Queue()
+
+        def server():
+            # Runs in a separate thread.
+            with socket.create_server(('localhost', 0)) as sock:
+                addr = sock.getsockname()
+                q.put(addr)
+                clt, _ = sock.accept()
+                clt.close()
+
+        async def client(host, port):
+            stream = await asyncio.connect(host, port)
+
+            while True:
+                stream.write(b"foo\n")
+                await stream.drain()
+
+        # Start the server thread and wait for it to be listening.
+        thread = threading.Thread(target=server)
+        thread.setDaemon(True)
+        thread.start()
+        addr = q.get()
+
+        # Should not be stuck in an infinite loop.
+        with self.assertRaises((ConnectionResetError, ConnectionAbortedError,
+                                BrokenPipeError)):
+            self.loop.run_until_complete(client(*addr))
+
+        # Clean up the thread.  (Only on success; on failure, it may
+        # be stuck in accept().)
+        thread.join()
+        self.assertEqual([], messages)
+
     def test___repr__(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
-        self.assertEqual("<StreamReader>", repr(stream))
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
+        self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream))
 
     def test___repr__nondefault_limit(self):
-        stream = asyncio.StreamReader(loop=self.loop, limit=123,
-                                      _asyncio_internal=True)
-        self.assertEqual("<StreamReader limit=123>", repr(stream))
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop, limit=123,
+                                _asyncio_internal=True)
+        self.assertEqual("<Stream mode=StreamMode.READ limit=123>", repr(stream))
 
     def test___repr__eof(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_eof()
-        self.assertEqual("<StreamReader eof>", repr(stream))
+        self.assertEqual("<Stream mode=StreamMode.READ eof>", repr(stream))
 
     def test___repr__data(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream.feed_data(b'data')
-        self.assertEqual("<StreamReader 4 bytes>", repr(stream))
+        self.assertEqual("<Stream mode=StreamMode.READ 4 bytes>", repr(stream))
 
     def test___repr__exception(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         exc = RuntimeError()
         stream.set_exception(exc)
-        self.assertEqual("<StreamReader exception=RuntimeError()>",
+        self.assertEqual("<Stream mode=StreamMode.READ exception=RuntimeError()>",
                          repr(stream))
 
     def test___repr__waiter(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream._waiter = asyncio.Future(loop=self.loop)
         self.assertRegex(
             repr(stream),
-            r"<StreamReader waiter=<Future pending[\S ]*>>")
+            r"<Stream .+ waiter=<Future pending[\S ]*>>")
         stream._waiter.set_result(None)
         self.loop.run_until_complete(stream._waiter)
         stream._waiter = None
-        self.assertEqual("<StreamReader>", repr(stream))
+        self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream))
 
     def test___repr__transport(self):
-        stream = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
+        stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                loop=self.loop,
+                                _asyncio_internal=True)
         stream._transport = mock.Mock()
         stream._transport.__repr__ = mock.Mock()
         stream._transport.__repr__.return_value = "<Transport>"
-        self.assertEqual("<StreamReader transport=<Transport>>", repr(stream))
+        self.assertEqual("<Stream mode=StreamMode.READ transport=<Transport>>",
+                         repr(stream))
 
     def test_IncompleteReadError_pickleable(self):
         e = asyncio.IncompleteReadError(b'abc', 10)
@@ -930,10 +1052,11 @@ os.close(fd)
                 self.assertEqual(str(e), str(e2))
                 self.assertEqual(e.consumed, e2.consumed)
 
-    def test_wait_closed_on_close(self):
+    def test_wait_closed_on_close_deprecated(self):
         with test_utils.run_test_server() as httpd:
-            rd, wr = self.loop.run_until_complete(
-                asyncio.open_connection(*httpd.address, loop=self.loop))
+            with self.assertWarns(DeprecationWarning):
+                rd, wr = self.loop.run_until_complete(
+                    asyncio.open_connection(*httpd.address, loop=self.loop))
 
             wr.write(b'GET / HTTP/1.0\r\n\r\n')
             f = rd.readline()
@@ -947,10 +1070,28 @@ os.close(fd)
             self.assertTrue(wr.is_closing())
             self.loop.run_until_complete(wr.wait_closed())
 
-    def test_wait_closed_on_close_with_unread_data(self):
+    def test_wait_closed_on_close(self):
         with test_utils.run_test_server() as httpd:
-            rd, wr = self.loop.run_until_complete(
-                asyncio.open_connection(*httpd.address, loop=self.loop))
+            stream = self.loop.run_until_complete(
+                asyncio.connect(*httpd.address))
+
+            stream.write(b'GET / HTTP/1.0\r\n\r\n')
+            f = stream.readline()
+            data = self.loop.run_until_complete(f)
+            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+            f = stream.read()
+            data = self.loop.run_until_complete(f)
+            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+            self.assertFalse(stream.is_closing())
+            stream.close()
+            self.assertTrue(stream.is_closing())
+            self.loop.run_until_complete(stream.wait_closed())
+
+    def test_wait_closed_on_close_with_unread_data_deprecated(self):
+        with test_utils.run_test_server() as httpd:
+            with self.assertWarns(DeprecationWarning):
+                rd, wr = self.loop.run_until_complete(
+                    asyncio.open_connection(*httpd.address, loop=self.loop))
 
             wr.write(b'GET / HTTP/1.0\r\n\r\n')
             f = rd.readline()
@@ -959,32 +1100,44 @@ os.close(fd)
             wr.close()
             self.loop.run_until_complete(wr.wait_closed())
 
+    def test_wait_closed_on_close_with_unread_data(self):
+        with test_utils.run_test_server() as httpd:
+            stream = self.loop.run_until_complete(
+                asyncio.connect(*httpd.address))
+
+            stream.write(b'GET / HTTP/1.0\r\n\r\n')
+            f = stream.readline()
+            data = self.loop.run_until_complete(f)
+            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+            stream.close()
+            self.loop.run_until_complete(stream.wait_closed())
+
     def test_del_stream_before_sock_closing(self):
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
 
-        with test_utils.run_test_server() as httpd:
-            rd, wr = self.loop.run_until_complete(
-                asyncio.open_connection(*httpd.address, loop=self.loop))
-            sock = wr.get_extra_info('socket')
-            self.assertNotEqual(sock.fileno(), -1)
+        async def test():
 
-            wr.write(b'GET / HTTP/1.0\r\n\r\n')
-            f = rd.readline()
-            data = self.loop.run_until_complete(f)
-            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+            with test_utils.run_test_server() as httpd:
+                stream = await asyncio.connect(*httpd.address)
+                sock = stream.get_extra_info('socket')
+                self.assertNotEqual(sock.fileno(), -1)
 
-            # drop refs to reader/writer
-            del rd
-            del wr
-            gc.collect()
-            # make a chance to close the socket
-            test_utils.run_briefly(self.loop)
+                await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+                data = await stream.readline()
+                self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
 
-            self.assertEqual(1, len(messages))
-            self.assertEqual(sock.fileno(), -1)
+                # drop refs to reader/writer
+                del stream
+                gc.collect()
+                # make a chance to close the socket
+                await asyncio.sleep(0)
 
-        self.assertEqual(1, len(messages))
+                self.assertEqual(1, len(messages), messages)
+                self.assertEqual(sock.fileno(), -1)
+
+        self.loop.run_until_complete(test())
+        self.assertEqual(1, len(messages), messages)
         self.assertEqual('An open stream object is being garbage '
                          'collected; call "stream.close()" explicitly.',
                          messages[0]['message'])
@@ -994,11 +1147,12 @@ os.close(fd)
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
 
         with test_utils.run_test_server() as httpd:
-            rd = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
-            pr = asyncio.StreamReaderProtocol(rd, loop=self.loop,
-                                              _asyncio_internal=True)
-            del rd
+            stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                    loop=self.loop,
+                                    _asyncio_internal=True)
+            pr = _StreamProtocol(stream, loop=self.loop,
+                                 _asyncio_internal=True)
+            del stream
             gc.collect()
             tr, _ = self.loop.run_until_complete(
                 self.loop.create_connection(
@@ -1015,14 +1169,14 @@ os.close(fd)
 
     def test_async_writer_api(self):
         async def inner(httpd):
-            rd, wr = await asyncio.open_connection(*httpd.address)
+            stream = await asyncio.connect(*httpd.address)
 
-            await wr.write(b'GET / HTTP/1.0\r\n\r\n')
-            data = await rd.readline()
+            await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+            data = await stream.readline()
             self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
-            data = await rd.read()
+            data = await stream.read()
             self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
-            await wr.close()
+            await stream.close()
 
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
@@ -1032,18 +1186,18 @@ os.close(fd)
 
         self.assertEqual(messages, [])
 
-    def test_async_writer_api(self):
+    def test_async_writer_api_exception_after_close(self):
         async def inner(httpd):
-            rd, wr = await asyncio.open_connection(*httpd.address)
+            stream = await asyncio.connect(*httpd.address)
 
-            await wr.write(b'GET / HTTP/1.0\r\n\r\n')
-            data = await rd.readline()
+            await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+            data = await stream.readline()
             self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
-            data = await rd.read()
+            data = await stream.read()
             self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
-            wr.close()
+            stream.close()
             with self.assertRaises(ConnectionResetError):
-                await wr.write(b'data')
+                await stream.write(b'data')
 
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
@@ -1059,11 +1213,13 @@ os.close(fd)
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
 
         with test_utils.run_test_server() as httpd:
-            rd, wr = self.loop.run_until_complete(
-                asyncio.open_connection(*httpd.address,
-                                        loop=self.loop))
+            with self.assertWarns(DeprecationWarning):
+                rd, wr = self.loop.run_until_complete(
+                    asyncio.open_connection(*httpd.address,
+                                            loop=self.loop))
 
-            f = wr.close()
+            wr.close()
+            f = wr.wait_closed()
             self.loop.run_until_complete(f)
             assert rd.at_eof()
             f = rd.read()
@@ -1074,22 +1230,514 @@ os.close(fd)
 
     def test_stream_reader_create_warning(self):
         with self.assertWarns(DeprecationWarning):
-            asyncio.StreamReader(loop=self.loop)
-
-    def test_stream_reader_protocol_create_warning(self):
-        reader = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
-        with self.assertWarns(DeprecationWarning):
-            asyncio.StreamReaderProtocol(reader, loop=self.loop)
+            asyncio.StreamReader
 
     def test_stream_writer_create_warning(self):
-        reader = asyncio.StreamReader(loop=self.loop,
-                                      _asyncio_internal=True)
-        proto = asyncio.StreamReaderProtocol(reader, loop=self.loop,
-                                             _asyncio_internal=True)
         with self.assertWarns(DeprecationWarning):
-            asyncio.StreamWriter('transport', proto, reader, self.loop)
+            asyncio.StreamWriter
+
+    def test_stream_reader_forbidden_ops(self):
+        async def inner():
+            stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                    _asyncio_internal=True)
+            with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+                await stream.write(b'data')
+            with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+                await stream.writelines([b'data', b'other'])
+            with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+                stream.write_eof()
+            with self.assertRaisesRegex(RuntimeError, "The stream is read-only"):
+                await stream.drain()
+
+        self.loop.run_until_complete(inner())
+
+    def test_stream_writer_forbidden_ops(self):
+        async def inner():
+            stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE,
+                                    _asyncio_internal=True)
+            with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+                stream.feed_data(b'data')
+            with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+                await stream.readline()
+            with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+                await stream.readuntil()
+            with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+                await stream.read()
+            with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+                await stream.readexactly(10)
+            with self.assertRaisesRegex(RuntimeError, "The stream is write-only"):
+                async for chunk in stream:
+                    pass
+
+        self.loop.run_until_complete(inner())
+
+    def _basetest_connect(self, stream):
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+        stream.write(b'GET / HTTP/1.0\r\n\r\n')
+        f = stream.readline()
+        data = self.loop.run_until_complete(f)
+        self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+        f = stream.read()
+        data = self.loop.run_until_complete(f)
+        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+        stream.close()
+        self.loop.run_until_complete(stream.wait_closed())
+
+        self.assertEqual([], messages)
+
+    def test_connect(self):
+        with test_utils.run_test_server() as httpd:
+            stream = self.loop.run_until_complete(
+                asyncio.connect(*httpd.address))
+            self.assertFalse(stream.is_server_side())
+            self._basetest_connect(stream)
+
+    @support.skip_unless_bind_unix_socket
+    def test_connect_unix(self):
+        with test_utils.run_test_unix_server() as httpd:
+            stream = self.loop.run_until_complete(
+                asyncio.connect_unix(httpd.address))
+            self._basetest_connect(stream)
+
+    def test_stream_async_context_manager(self):
+        async def test(httpd):
+            stream = await asyncio.connect(*httpd.address)
+            async with stream:
+                await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+                data = await stream.readline()
+                self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+                data = await stream.read()
+                self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+            self.assertTrue(stream.is_closing())
+
+        with test_utils.run_test_server() as httpd:
+            self.loop.run_until_complete(test(httpd))
+
+    def test_connect_async_context_manager(self):
+        async def test(httpd):
+            async with asyncio.connect(*httpd.address) as stream:
+                await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+                data = await stream.readline()
+                self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+                data = await stream.read()
+                self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+            self.assertTrue(stream.is_closing())
+
+        with test_utils.run_test_server() as httpd:
+            self.loop.run_until_complete(test(httpd))
+
+    @support.skip_unless_bind_unix_socket
+    def test_connect_unix_async_context_manager(self):
+        async def test(httpd):
+            async with asyncio.connect_unix(httpd.address) as stream:
+                await stream.write(b'GET / HTTP/1.0\r\n\r\n')
+                data = await stream.readline()
+                self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+                data = await stream.read()
+                self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+            self.assertTrue(stream.is_closing())
+
+        with test_utils.run_test_unix_server() as httpd:
+            self.loop.run_until_complete(test(httpd))
+
+    def test_stream_server(self):
+
+        async def handle_client(stream):
+            self.assertTrue(stream.is_server_side())
+            data = await stream.readline()
+            await stream.write(data)
+            await stream.close()
+
+        async def client(srv):
+            addr = srv.sockets[0].getsockname()
+            stream = await asyncio.connect(*addr)
+            # send a line
+            await stream.write(b"hello world!\n")
+            # read it back
+            msgback = await stream.readline()
+            await stream.close()
+            self.assertEqual(msgback, b"hello world!\n")
+            await srv.close()
+
+        async def test():
+            async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
+                await server.start_serving()
+                task = asyncio.create_task(client(server))
+                with contextlib.suppress(asyncio.CancelledError):
+                    await server.serve_forever()
+                await task
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+
+    @support.skip_unless_bind_unix_socket
+    def test_unix_stream_server(self):
+
+        async def handle_client(stream):
+            data = await stream.readline()
+            await stream.write(data)
+            await stream.close()
+
+        async def client(srv):
+            addr = srv.sockets[0].getsockname()
+            stream = await asyncio.connect_unix(addr)
+            # send a line
+            await stream.write(b"hello world!\n")
+            # read it back
+            msgback = await stream.readline()
+            await stream.close()
+            self.assertEqual(msgback, b"hello world!\n")
+            await srv.close()
+
+        async def test():
+            with test_utils.unix_socket_path() as path:
+                async with asyncio.UnixStreamServer(handle_client, path) as server:
+                    await server.start_serving()
+                    task = asyncio.create_task(client(server))
+                    with contextlib.suppress(asyncio.CancelledError):
+                        await server.serve_forever()
+                    await task
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+
+    def test_stream_server_inheritance_forbidden(self):
+        with self.assertRaises(TypeError):
+            class MyServer(asyncio.StreamServer):
+                pass
+
+    @support.skip_unless_bind_unix_socket
+    def test_unix_stream_server_inheritance_forbidden(self):
+        with self.assertRaises(TypeError):
+            class MyServer(asyncio.UnixStreamServer):
+                pass
+
+    def test_stream_server_bind(self):
+        async def handle_client(stream):
+            await stream.close()
+
+        async def test():
+            srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0)
+            self.assertFalse(srv.is_bound())
+            self.assertEqual(0, len(srv.sockets))
+            await srv.bind()
+            self.assertTrue(srv.is_bound())
+            self.assertEqual(1, len(srv.sockets))
+            await srv.close()
+            self.assertFalse(srv.is_bound())
+            self.assertEqual(0, len(srv.sockets))
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+
+    def test_stream_server_bind_async_with(self):
+        async def handle_client(stream):
+            await stream.close()
+
+        async def test():
+            async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv:
+                self.assertTrue(srv.is_bound())
+                self.assertEqual(1, len(srv.sockets))
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+
+    def test_stream_server_start_serving(self):
+        async def handle_client(stream):
+            await stream.close()
+
+        async def test():
+            async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv:
+                self.assertFalse(srv.is_serving())
+                await srv.start_serving()
+                self.assertTrue(srv.is_serving())
+                await srv.close()
+                self.assertFalse(srv.is_serving())
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+
+    def test_stream_server_close(self):
+        server_stream_aborted = False
+        fut = self.loop.create_future()
+
+        async def handle_client(stream):
+            await fut
+            self.assertEqual(b'', await stream.readline())
+            nonlocal server_stream_aborted
+            server_stream_aborted = True
+
+        async def client(srv):
+            addr = srv.sockets[0].getsockname()
+            stream = await asyncio.connect(*addr)
+            fut.set_result(None)
+            self.assertEqual(b'', await stream.readline())
+            await stream.close()
+
+        async def test():
+            async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
+                await server.start_serving()
+                task = asyncio.create_task(client(server))
+                await fut
+                await server.close()
+                await task
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+        self.assertTrue(fut.done())
+        self.assertTrue(server_stream_aborted)
+
+    def test_stream_server_abort(self):
+        server_stream_aborted = False
+        fut = self.loop.create_future()
+
+        async def handle_client(stream):
+            await fut
+            self.assertEqual(b'', await stream.readline())
+            nonlocal server_stream_aborted
+            server_stream_aborted = True
+
+        async def client(srv):
+            addr = srv.sockets[0].getsockname()
+            stream = await asyncio.connect(*addr)
+            fut.set_result(None)
+            self.assertEqual(b'', await stream.readline())
+            await stream.close()
+
+        async def test():
+            async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
+                await server.start_serving()
+                task = asyncio.create_task(client(server))
+                await fut
+                await server.abort()
+                await task
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+        self.assertTrue(fut.done())
+        self.assertTrue(server_stream_aborted)
+
+    def test_stream_shutdown_hung_task(self):
+        fut1 = self.loop.create_future()
+        fut2 = self.loop.create_future()
+
+        async def handle_client(stream):
+            while True:
+                await asyncio.sleep(0.01)
+
+        async def client(srv):
+            addr = srv.sockets[0].getsockname()
+            stream = await asyncio.connect(*addr)
+            fut1.set_result(None)
+            await fut2
+            self.assertEqual(b'', await stream.readline())
+            await stream.close()
+
+        async def test():
+            async with asyncio.StreamServer(handle_client,
+                                            '127.0.0.1',
+                                            0,
+                                            shutdown_timeout=0.3) as server:
+                await server.start_serving()
+                task = asyncio.create_task(client(server))
+                await fut1
+                await server.close()
+                fut2.set_result(None)
+                await task
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(messages, [])
+        self.assertTrue(fut1.done())
+        self.assertTrue(fut2.done())
+
+    def test_stream_shutdown_hung_task_prevents_cancellation(self):
+        fut1 = self.loop.create_future()
+        fut2 = self.loop.create_future()
+        do_handle_client = True
+
+        async def handle_client(stream):
+            while do_handle_client:
+                with contextlib.suppress(asyncio.CancelledError):
+                    await asyncio.sleep(0.01)
+
+        async def client(srv):
+            addr = srv.sockets[0].getsockname()
+            stream = await asyncio.connect(*addr)
+            fut1.set_result(None)
+            await fut2
+            self.assertEqual(b'', await stream.readline())
+            await stream.close()
+
+        async def test():
+            async with asyncio.StreamServer(handle_client,
+                                            '127.0.0.1',
+                                            0,
+                                            shutdown_timeout=0.3) as server:
+                await server.start_serving()
+                task = asyncio.create_task(client(server))
+                await fut1
+                await server.close()
+                nonlocal do_handle_client
+                do_handle_client = False
+                fut2.set_result(None)
+                await task
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+        self.loop.run_until_complete(test())
+        self.assertEqual(1, len(messages))
+        self.assertRegex(messages[0]['message'],
+                         "<Task pending .+ ignored cancellation request")
+        self.assertTrue(fut1.done())
+        self.assertTrue(fut2.done())
+
+    def test_sendfile(self):
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+        with open(support.TESTFN, 'wb') as fp:
+            fp.write(b'data\n')
+        self.addCleanup(support.unlink, support.TESTFN)
+
+        async def serve_callback(stream):
+            data = await stream.readline()
+            self.assertEqual(data, b'begin\n')
+            data = await stream.readline()
+            self.assertEqual(data, b'data\n')
+            data = await stream.readline()
+            self.assertEqual(data, b'end\n')
+            await stream.write(b'done\n')
+            await stream.close()
+
+        async def do_connect(host, port):
+            stream = await asyncio.connect(host, port)
+            await stream.write(b'begin\n')
+            with open(support.TESTFN, 'rb') as fp:
+                await stream.sendfile(fp)
+            await stream.write(b'end\n')
+            data = await stream.readline()
+            self.assertEqual(data, b'done\n')
+            await stream.close()
+
+        async def test():
+            async with asyncio.StreamServer(serve_callback, '127.0.0.1', 0) as srv:
+                await srv.start_serving()
+                await do_connect(*srv.sockets[0].getsockname())
+
+        self.loop.run_until_complete(test())
+
+        self.assertEqual([], messages)
+
+
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    def test_connect_start_tls(self):
+        with test_utils.run_test_server(use_ssl=True) as httpd:
+            # connect without SSL but upgrade to TLS just after
+            # connection is established
+            stream = self.loop.run_until_complete(
+                asyncio.connect(*httpd.address))
+
+            self.loop.run_until_complete(
+                stream.start_tls(
+                    sslcontext=test_utils.dummy_ssl_context()))
+            self._basetest_connect(stream)
+
+    def test_repr_unbound(self):
+        async def serve(stream):
+            pass
+
+        async def test():
+            srv = asyncio.StreamServer(serve)
+            self.assertEqual('<StreamServer>', repr(srv))
+            await srv.close()
+
+        self.loop.run_until_complete(test())
+
+    def test_repr_bound(self):
+        async def serve(stream):
+            pass
+
+        async def test():
+            srv = asyncio.StreamServer(serve, '127.0.0.1', 0)
+            await srv.bind()
+            self.assertRegex(repr(srv), r'<StreamServer sockets=\(.+\)>')
+            await srv.close()
+
+        self.loop.run_until_complete(test())
+
+    def test_repr_serving(self):
+        async def serve(stream):
+            pass
+
+        async def test():
+            srv = asyncio.StreamServer(serve, '127.0.0.1', 0)
+            await srv.start_serving()
+            self.assertRegex(repr(srv), r'<StreamServer serving sockets=\(.+\)>')
+            await srv.close()
+
+        self.loop.run_until_complete(test())
+
+
+    @unittest.skipUnless(sys.platform != 'win32',
+                         "Pipes are not supported on Windows")
+    def test_read_pipe(self):
+        async def test():
+            rpipe, wpipe = os.pipe()
+            pipeobj = io.open(rpipe, 'rb', 1024)
+
+            async with asyncio.connect_read_pipe(pipeobj) as stream:
+                self.assertEqual(stream.mode, asyncio.StreamMode.READ)
+
+                os.write(wpipe, b'1')
+                data = await stream.readexactly(1)
+                self.assertEqual(data, b'1')
+
+                os.write(wpipe, b'2345')
+                data = await stream.readexactly(4)
+                self.assertEqual(data, b'2345')
+                os.close(wpipe)
+
+        self.loop.run_until_complete(test())
+
+    @unittest.skipUnless(sys.platform != 'win32',
+                         "Pipes are not supported on Windows")
+    def test_write_pipe(self):
+        async def test():
+            rpipe, wpipe = os.pipe()
+            pipeobj = io.open(wpipe, 'wb', 1024)
+
+            async with asyncio.connect_write_pipe(pipeobj) as stream:
+                self.assertEqual(stream.mode, asyncio.StreamMode.WRITE)
+
+                await stream.write(b'1')
+                data = os.read(rpipe, 1024)
+                self.assertEqual(data, b'1')
+
+                await stream.write(b'2345')
+                data = os.read(rpipe, 1024)
+                self.assertEqual(data, b'2345')
+
+                os.close(rpipe)
 
+        self.loop.run_until_complete(test())
 
 
 if __name__ == '__main__':
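
A minimal standalone sketch of the pipe usage exercised by test_read_pipe above, assuming the asyncio.Stream API added by this commit (asyncio.connect_read_pipe, asyncio.StreamMode) is importable; it is not part of any released asyncio version, and the payload bytes are illustrative only:

import asyncio
import io
import os

async def main():
    # Create an OS-level pipe and wrap its read end in a buffered file
    # object, mirroring the setup used in the test.
    rpipe, wpipe = os.pipe()
    rpipeobj = io.open(rpipe, 'rb', 1024)

    async with asyncio.connect_read_pipe(rpipeobj) as stream:
        # The resulting stream is read-only.
        assert stream.mode == asyncio.StreamMode.READ
        os.write(wpipe, b'ping')
        data = await stream.readexactly(4)
        print(data)  # b'ping'
        os.close(wpipe)

asyncio.run(main())

connect_write_pipe() is the mirror image: it wraps the write end of the pipe and yields a write-only stream, as in test_write_pipe above.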
index e201a0696796d68abeeedccd0edc69731a61e9b5..13aef7cf1f776b0b49597baf3a3bef3180e132c6 100644 (file)
@@ -17,6 +17,7 @@ import _winapi
 
 import asyncio
 from asyncio import windows_events
+from asyncio.streams import _StreamProtocol
 from test.test_asyncio import utils as test_utils
 from test.support.script_helper import spawn_python
 
@@ -100,16 +101,16 @@ class ProactorTests(test_utils.TestCase):
 
         clients = []
         for i in range(5):
-            stream_reader = asyncio.StreamReader(loop=self.loop,
-                                                 _asyncio_internal=True)
-            protocol = asyncio.StreamReaderProtocol(stream_reader,
-                                                    loop=self.loop,
-                                                    _asyncio_internal=True)
+            stream = asyncio.Stream(mode=asyncio.StreamMode.READ,
+                                    loop=self.loop, _asyncio_internal=True)
+            protocol = _StreamProtocol(stream,
+                                       loop=self.loop,
+                                       _asyncio_internal=True)
             trans, proto = await self.loop.create_pipe_connection(
                 lambda: protocol, ADDRESS)
             self.assertIsInstance(trans, asyncio.Transport)
             self.assertEqual(protocol, proto)
-            clients.append((stream_reader, trans))
+            clients.append((stream, trans))
 
         for i, (r, w) in enumerate(clients):
             w.write('lower-{}\n'.format(i).encode())
@@ -118,6 +119,7 @@ class ProactorTests(test_utils.TestCase):
             response = await r.readline()
             self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
             w.close()
+            await r.close()
 
         server.close()
 
diff --git a/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst b/Misc/NEWS.d/next/Library/2019-05-14-12-25-44.bpo-36889.MChPqP.rst
new file mode 100644 (file)
index 0000000..d08c0e2
--- /dev/null
@@ -0,0 +1,6 @@
+Introduce the :class:`asyncio.Stream` class, which merges the functionality of :class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter`.
+:class:`asyncio.Stream` can work in read-only, write-only and read-write modes.
+Provide the :func:`asyncio.connect`, :func:`asyncio.connect_unix`, :func:`asyncio.connect_read_pipe` and :func:`asyncio.connect_write_pipe` factories to open :class:`asyncio.Stream` connections, and the :class:`asyncio.StreamServer` and :class:`asyncio.UnixStreamServer` classes to serve connections with the :class:`asyncio.Stream` API.
+Modify :func:`asyncio.create_subprocess_shell` and :func:`asyncio.create_subprocess_exec` to use :class:`asyncio.Stream` instead of the deprecated :class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter`.
+Deprecate :class:`asyncio.StreamReader` and :class:`asyncio.StreamWriter`.
+Deprecate the use of asyncio's private classes, e.g. :class:`asyncio.FlowControlMixin` and :class:`asyncio.StreamReaderProtocol`, outside of the asyncio package.
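
For context, a minimal sketch of the client/server pattern that the new tests exercise, assuming the API introduced by this commit (asyncio.connect and asyncio.StreamServer); the echo behaviour and argument choices follow test_stream_server above and are illustrative, not a documented recipe:

import asyncio

async def handle_client(stream):
    # Echo a single line back to the connected client, then close.
    line = await stream.readline()
    await stream.write(line)
    await stream.close()

async def main():
    async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
        await server.start_serving()
        host, port = server.sockets[0].getsockname()

        # A single Stream object handles both reading and writing.
        stream = await asyncio.connect(host, port)
        await stream.write(b'hello\n')
        assert await stream.readline() == b'hello\n'
        await stream.close()

        await server.close()

asyncio.run(main())

Note that write() and close() are awaited here, matching the tests in this diff, rather than the write()/drain() pair used with the deprecated StreamWriter.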