]> granicus.if.org Git - python/commitdiff
bpo-33654: Support BufferedProtocol in set_protocol() and start_tls() (GH-7130)
authorYury Selivanov <yury@magic.io>
Mon, 28 May 2018 18:31:28 +0000 (14:31 -0400)
committerGitHub <noreply@github.com>
Mon, 28 May 2018 18:31:28 +0000 (14:31 -0400)
In this commit:

* Support BufferedProtocol in set_protocol() and start_tls()
* Fix proactor to cancel readers reliably
* Update tests to be compatible with OpenSSL 1.1.1
* Clarify BufferedProtocol docs
* Bump TLS tests timeouts to 60 seconds; eliminate possible race from start_serving
* Rewrite test_start_tls_server_1

13 files changed:
Doc/library/asyncio-protocol.rst
Lib/asyncio/base_events.py
Lib/asyncio/proactor_events.py
Lib/asyncio/protocols.py
Lib/asyncio/selector_events.py
Lib/asyncio/sslproto.py
Lib/asyncio/unix_events.py
Lib/test/test_asyncio/test_buffered_proto.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_proactor_events.py
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_sslproto.py
Misc/NEWS.d/next/Library/2018-05-26-13-09-34.bpo-33654.IbYWxA.rst [new file with mode: 0644]

index ef6441605cd72c2f4aa11b41289f559fad165221..9a08a4a49021ccba8363f06dd833d0f95009bb4b 100644 (file)
@@ -463,16 +463,23 @@ The idea of BufferedProtocol is that it allows to manually allocate
 and control the receive buffer.  Event loops can then use the buffer
 provided by the protocol to avoid unnecessary data copies.  This
 can result in noticeable performance improvement for protocols that
-receive big amounts of data.  Sophisticated protocols can allocate
-the buffer only once at creation time.
+receive big amounts of data.  Sophisticated protocols implementations
+can allocate the buffer only once at creation time.
 
 The following callbacks are called on :class:`BufferedProtocol`
 instances:
 
-.. method:: BufferedProtocol.get_buffer()
+.. method:: BufferedProtocol.get_buffer(sizehint)
 
-   Called to allocate a new receive buffer.  Must return an object
-   that implements the :ref:`buffer protocol <bufferobjects>`.
+   Called to allocate a new receive buffer.
+
+   *sizehint* is a recommended minimal size for the returned
+   buffer.  It is acceptable to return smaller or bigger buffers
+   than what *sizehint* suggests.  When set to -1, the buffer size
+   can be arbitrary. It is an error to return a zero-sized buffer.
+
+   Must return an object that implements the
+   :ref:`buffer protocol <bufferobjects>`.
 
 .. method:: BufferedProtocol.buffer_updated(nbytes)
 
index 09eb440b0ef7af86f50171cf462e5f6cb0f3a0c7..a0243f5bac9a2c37945ef1ab25bf5cefa6f8ebf1 100644 (file)
@@ -157,7 +157,6 @@ def _run_until_complete_cb(fut):
     futures._get_loop(fut).stop()
 
 
-
 class _SendfileFallbackProtocol(protocols.Protocol):
     def __init__(self, transp):
         if not isinstance(transp, transports._FlowControlMixin):
@@ -304,6 +303,9 @@ class Server(events.AbstractServer):
 
     async def start_serving(self):
         self._start_serving()
+        # Skip one loop iteration so that all 'loop.add_reader'
+        # go through.
+        await tasks.sleep(0, loop=self._loop)
 
     async def serve_forever(self):
         if self._serving_forever_fut is not None:
@@ -1363,6 +1365,9 @@ class BaseEventLoop(events.AbstractEventLoop):
                         ssl, backlog, ssl_handshake_timeout)
         if start_serving:
             server._start_serving()
+            # Skip one loop iteration so that all 'loop.add_reader'
+            # go through.
+            await tasks.sleep(0, loop=self)
 
         if self._debug:
             logger.info("%r is serving", server)
index 877dfb0746708ecb9ab51adb76ec0eead3810f80..337ed0fb2047510fd651adb9a58bb72c2408f0eb 100644 (file)
@@ -30,7 +30,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
         super().__init__(extra, loop)
         self._set_extra(sock)
         self._sock = sock
-        self._protocol = protocol
+        self.set_protocol(protocol)
         self._server = server
         self._buffer = None  # None or bytearray.
         self._read_fut = None
@@ -159,16 +159,26 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
 
     def __init__(self, loop, sock, protocol, waiter=None,
                  extra=None, server=None):
+        self._loop_reading_cb = None
+        self._paused = True
         super().__init__(loop, sock, protocol, waiter, extra, server)
-        self._paused = False
+
         self._reschedule_on_resume = False
+        self._loop.call_soon(self._loop_reading)
+        self._paused = False
 
-        if protocols._is_buffered_protocol(protocol):
-            self._loop_reading = self._loop_reading__get_buffer
+    def set_protocol(self, protocol):
+        if isinstance(protocol, protocols.BufferedProtocol):
+            self._loop_reading_cb = self._loop_reading__get_buffer
         else:
-            self._loop_reading = self._loop_reading__data_received
+            self._loop_reading_cb = self._loop_reading__data_received
 
-        self._loop.call_soon(self._loop_reading)
+        super().set_protocol(protocol)
+
+        if self.is_reading():
+            # reset reading callback / buffers / self._read_fut
+            self.pause_reading()
+            self.resume_reading()
 
     def is_reading(self):
         return not self._paused and not self._closing
@@ -179,6 +189,13 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
         self._paused = True
 
         if self._read_fut is not None and not self._read_fut.done():
+            # TODO: This is an ugly hack to cancel the current read future
+            # *and* avoid potential race conditions, as read cancellation
+            # goes through `future.cancel()` and `loop.call_soon()`.
+            # We then use this special attribute in the reader callback to
+            # exit *immediately* without doing any cleanup/rescheduling.
+            self._read_fut.__asyncio_cancelled_on_pause__ = True
+
             self._read_fut.cancel()
             self._read_fut = None
             self._reschedule_on_resume = True
@@ -210,7 +227,14 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
         if not keep_open:
             self.close()
 
-    def _loop_reading__data_received(self, fut=None):
+    def _loop_reading(self, fut=None):
+        self._loop_reading_cb(fut)
+
+    def _loop_reading__data_received(self, fut):
+        if (fut is not None and
+                getattr(fut, '__asyncio_cancelled_on_pause__', False)):
+            return
+
         if self._paused:
             self._reschedule_on_resume = True
             return
@@ -253,14 +277,18 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
             if not self._closing:
                 raise
         else:
-            self._read_fut.add_done_callback(self._loop_reading)
+            self._read_fut.add_done_callback(self._loop_reading__data_received)
         finally:
             if data:
                 self._protocol.data_received(data)
             elif data == b'':
                 self._loop_reading__on_eof()
 
-    def _loop_reading__get_buffer(self, fut=None):
+    def _loop_reading__get_buffer(self, fut):
+        if (fut is not None and
+                getattr(fut, '__asyncio_cancelled_on_pause__', False)):
+            return
+
         if self._paused:
             self._reschedule_on_resume = True
             return
@@ -310,7 +338,9 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
             return
 
         try:
-            buf = self._protocol.get_buffer()
+            buf = self._protocol.get_buffer(-1)
+            if not len(buf):
+                raise RuntimeError('get_buffer() returned an empty buffer')
         except Exception as exc:
             self._fatal_error(
                 exc, 'Fatal error: protocol.get_buffer() call failed.')
@@ -319,7 +349,7 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
         try:
             # schedule a new read
             self._read_fut = self._loop._proactor.recv_into(self._sock, buf)
-            self._read_fut.add_done_callback(self._loop_reading)
+            self._read_fut.add_done_callback(self._loop_reading__get_buffer)
         except ConnectionAbortedError as exc:
             if not self._closing:
                 self._fatal_error(exc, 'Fatal read error on pipe transport')
index dc298a8d5c9510d5bd722a71269926d193c16173..b8d2e6be552e1e338d5a4e5d3921adbcfeb1618f 100644 (file)
@@ -130,11 +130,15 @@ class BufferedProtocol(BaseProtocol):
     * CL: connection_lost()
     """
 
-    def get_buffer(self):
+    def get_buffer(self, sizehint):
         """Called to allocate a new receive buffer.
 
+        *sizehint* is a recommended minimal size for the returned
+        buffer.  When set to -1, the buffer size can be arbitrary.
+
         Must return an object that implements the
         :ref:`buffer protocol <bufferobjects>`.
+        It is an error to return a zero-sized buffer.
         """
 
     def buffer_updated(self, nbytes):
@@ -185,7 +189,3 @@ class SubprocessProtocol(BaseProtocol):
 
     def process_exited(self):
         """Called when subprocess has exited."""
-
-
-def _is_buffered_protocol(proto):
-    return hasattr(proto, 'get_buffer') and not hasattr(proto, 'data_received')
index 5473c7055212a4b6e4a066055fab1ba91f6df978..116c08d6ff7fddaefc59c6fc07f973dda981aa67 100644 (file)
@@ -597,8 +597,10 @@ class _SelectorTransport(transports._FlowControlMixin,
                 self._extra['peername'] = None
         self._sock = sock
         self._sock_fd = sock.fileno()
-        self._protocol = protocol
-        self._protocol_connected = True
+
+        self._protocol_connected = False
+        self.set_protocol(protocol)
+
         self._server = server
         self._buffer = self._buffer_factory()
         self._conn_lost = 0  # Set when call to connection_lost scheduled.
@@ -640,6 +642,7 @@ class _SelectorTransport(transports._FlowControlMixin,
 
     def set_protocol(self, protocol):
         self._protocol = protocol
+        self._protocol_connected = True
 
     def get_protocol(self):
         return self._protocol
@@ -721,11 +724,7 @@ class _SelectorSocketTransport(_SelectorTransport):
     def __init__(self, loop, sock, protocol, waiter=None,
                  extra=None, server=None):
 
-        if protocols._is_buffered_protocol(protocol):
-            self._read_ready = self._read_ready__get_buffer
-        else:
-            self._read_ready = self._read_ready__data_received
-
+        self._read_ready_cb = None
         super().__init__(loop, sock, protocol, extra, server)
         self._eof = False
         self._paused = False
@@ -745,6 +744,14 @@ class _SelectorSocketTransport(_SelectorTransport):
             self._loop.call_soon(futures._set_result_unless_cancelled,
                                  waiter, None)
 
+    def set_protocol(self, protocol):
+        if isinstance(protocol, protocols.BufferedProtocol):
+            self._read_ready_cb = self._read_ready__get_buffer
+        else:
+            self._read_ready_cb = self._read_ready__data_received
+
+        super().set_protocol(protocol)
+
     def is_reading(self):
         return not self._paused and not self._closing
 
@@ -764,12 +771,17 @@ class _SelectorSocketTransport(_SelectorTransport):
         if self._loop.get_debug():
             logger.debug("%r resumes reading", self)
 
+    def _read_ready(self):
+        self._read_ready_cb()
+
     def _read_ready__get_buffer(self):
         if self._conn_lost:
             return
 
         try:
-            buf = self._protocol.get_buffer()
+            buf = self._protocol.get_buffer(-1)
+            if not len(buf):
+                raise RuntimeError('get_buffer() returned an empty buffer')
         except Exception as exc:
             self._fatal_error(
                 exc, 'Fatal error: protocol.get_buffer() call failed.')
index 2bbf134c0f7e6f6c78bb603ec4fae107d1dcbe5d..2bfa45dd1585afa4290579a15f1aa07fb9fb6371 100644 (file)
@@ -441,6 +441,8 @@ class SSLProtocol(protocols.Protocol):
         self._waiter = waiter
         self._loop = loop
         self._app_protocol = app_protocol
+        self._app_protocol_is_buffer = \
+            isinstance(app_protocol, protocols.BufferedProtocol)
         self._app_transport = _SSLProtocolTransport(self._loop, self)
         # _SSLPipe instance (None until the connection is made)
         self._sslpipe = None
@@ -522,7 +524,16 @@ class SSLProtocol(protocols.Protocol):
 
         for chunk in appdata:
             if chunk:
-                self._app_protocol.data_received(chunk)
+                try:
+                    if self._app_protocol_is_buffer:
+                        _feed_data_to_bufferred_proto(
+                            self._app_protocol, chunk)
+                    else:
+                        self._app_protocol.data_received(chunk)
+                except Exception as ex:
+                    self._fatal_error(
+                        ex, 'application protocol failed to receive SSL data')
+                    return
             else:
                 self._start_shutdown()
                 break
@@ -709,3 +720,22 @@ class SSLProtocol(protocols.Protocol):
                 self._transport.abort()
         finally:
             self._finalize()
+
+
+def _feed_data_to_bufferred_proto(proto, data):
+    data_len = len(data)
+    while data_len:
+        buf = proto.get_buffer(data_len)
+        buf_len = len(buf)
+        if not buf_len:
+            raise RuntimeError('get_buffer() returned an empty buffer')
+
+        if buf_len >= data_len:
+            buf[:data_len] = data
+            proto.buffer_updated(data_len)
+            return
+        else:
+            buf[:buf_len] = data[:buf_len]
+            proto.buffer_updated(buf_len)
+            data = data[buf_len:]
+            data_len = len(data)
index f64037a25c67b8c5a97e3cd2a71acab07733266d..7cad7e3637a11f5826d785c06d576bb2aa1d2fd4 100644 (file)
@@ -20,6 +20,7 @@ from . import coroutines
 from . import events
 from . import futures
 from . import selector_events
+from . import tasks
 from . import transports
 from .log import logger
 
@@ -308,6 +309,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
                                     ssl, backlog, ssl_handshake_timeout)
         if start_serving:
             server._start_serving()
+            # Skip one loop iteration so that all 'loop.add_reader'
+            # go through.
+            await tasks.sleep(0, loop=self)
 
         return server
 
index 22f9269e814f99ffcb0677ad44fe5f78ff3a4098..89d3df72d98b6201090e55d8d6569f67f24b05d4 100644 (file)
@@ -9,7 +9,7 @@ class ReceiveStuffProto(asyncio.BufferedProtocol):
         self.cb = cb
         self.con_lost_fut = con_lost_fut
 
-    def get_buffer(self):
+    def get_buffer(self, sizehint):
         self.buffer = bytearray(100)
         return self.buffer
 
index 64d726d16d1cd89f6dae9ab89d7c1ee82b57026d..d7b0a665a0abc10a2ceeb41ef37344356423aa78 100644 (file)
@@ -2095,7 +2095,7 @@ class SubprocessTestsMixin:
 
 class SendfileBase:
 
-    DATA = b"12345abcde" * 16 * 1024  # 160 KiB
+    DATA = b"12345abcde" * 64 * 1024  # 64 KiB (don't use smaller sizes)
 
     @classmethod
     def setUpClass(cls):
@@ -2452,7 +2452,7 @@ class SendfileMixin(SendfileBase):
         self.assertEqual(srv_proto.data, self.DATA)
         self.assertEqual(self.file.tell(), len(self.DATA))
 
-    def test_sendfile_close_peer_in_middle_of_receiving(self):
+    def test_sendfile_close_peer_in_the_middle_of_receiving(self):
         srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
         with self.assertRaises(ConnectionError):
             self.run_loop(
@@ -2465,7 +2465,7 @@ class SendfileMixin(SendfileBase):
                         self.file.tell())
         self.assertTrue(cli_proto.transport.is_closing())
 
-    def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
+    def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
 
         def sendfile_native(transp, file, offset, count):
             # to raise SendfileNotAvailableError
index 6313d594477a74d4984302f268f94965d12b1fc7..6da6b4a34db81e352c53db952aab0fb3fc8266e1 100644 (file)
@@ -465,8 +465,8 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase):
         self.loop._proactor = self.proactor
 
         self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol)
-        self.buf = mock.Mock()
-        self.protocol.get_buffer.side_effect = lambda: self.buf
+        self.buf = bytearray(1)
+        self.protocol.get_buffer.side_effect = lambda hint: self.buf
 
         self.sock = mock.Mock(socket.socket)
 
@@ -505,6 +505,64 @@ class ProactorSocketTransportBufferedProtoTests(test_utils.TestCase):
         self.assertTrue(self.protocol.get_buffer.called)
         self.assertFalse(self.protocol.buffer_updated.called)
 
+    def test_get_buffer_zerosized(self):
+        transport = self.socket_transport()
+        transport._fatal_error = mock.Mock()
+
+        self.loop.call_exception_handler = mock.Mock()
+        self.protocol.get_buffer.side_effect = lambda hint: bytearray(0)
+
+        transport._loop_reading()
+
+        self.assertTrue(transport._fatal_error.called)
+        self.assertTrue(self.protocol.get_buffer.called)
+        self.assertFalse(self.protocol.buffer_updated.called)
+
+    def test_proto_type_switch(self):
+        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+        tr = self.socket_transport()
+
+        res = asyncio.Future(loop=self.loop)
+        res.set_result(b'data')
+
+        tr = self.socket_transport()
+        tr._read_fut = res
+        tr._loop_reading(res)
+        self.loop._proactor.recv.assert_called_with(self.sock, 32768)
+        self.protocol.data_received.assert_called_with(b'data')
+
+        # switch protocol to a BufferedProtocol
+
+        buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
+        buf = bytearray(4)
+        buf_proto.get_buffer.side_effect = lambda hint: buf
+
+        tr.set_protocol(buf_proto)
+        test_utils.run_briefly(self.loop)
+        res = asyncio.Future(loop=self.loop)
+        res.set_result(4)
+
+        tr._read_fut = res
+        tr._loop_reading(res)
+        self.loop._proactor.recv_into.assert_called_with(self.sock, buf)
+        buf_proto.buffer_updated.assert_called_with(4)
+
+    def test_proto_buf_switch(self):
+        tr = self.socket_transport()
+        test_utils.run_briefly(self.loop)
+        self.protocol.get_buffer.assert_called_with(-1)
+
+        # switch protocol to *another* BufferedProtocol
+
+        buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
+        buf = bytearray(4)
+        buf_proto.get_buffer.side_effect = lambda hint: buf
+        tr._read_fut.done.side_effect = lambda: False
+        tr.set_protocol(buf_proto)
+        self.assertFalse(buf_proto.get_buffer.called)
+        test_utils.run_briefly(self.loop)
+        buf_proto.get_buffer.assert_called_with(-1)
+
     def test_buffer_updated_error(self):
         transport = self.socket_transport()
         transport._fatal_error = mock.Mock()
index 5c4ff5745b626a9dbf76093b6fe13b86bb81c55b..68b6ee9abbf11a69f4ff29547ddbb7af1791177f 100644 (file)
@@ -772,7 +772,8 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
             accept2_mock.return_value = None
             with mock_obj(self.loop, 'create_task') as task_mock:
                 task_mock.return_value = None
-                self.loop._accept_connection(mock.Mock(), sock, backlog=backlog)
+                self.loop._accept_connection(
+                    mock.Mock(), sock, backlog=backlog)
         self.assertEqual(sock.accept.call_count, backlog)
 
 
@@ -1285,8 +1286,8 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
         self.loop = self.new_test_loop()
 
         self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol)
-        self.buf = mock.Mock()
-        self.protocol.get_buffer.side_effect = lambda: self.buf
+        self.buf = bytearray(1)
+        self.protocol.get_buffer.side_effect = lambda hint: self.buf
 
         self.sock = mock.Mock(socket.socket)
         self.sock_fd = self.sock.fileno.return_value = 7
@@ -1319,6 +1320,42 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
         self.assertTrue(self.protocol.get_buffer.called)
         self.assertFalse(self.protocol.buffer_updated.called)
 
+    def test_get_buffer_zerosized(self):
+        transport = self.socket_transport()
+        transport._fatal_error = mock.Mock()
+
+        self.loop.call_exception_handler = mock.Mock()
+        self.protocol.get_buffer.side_effect = lambda hint: bytearray(0)
+
+        transport._read_ready()
+
+        self.assertTrue(transport._fatal_error.called)
+        self.assertTrue(self.protocol.get_buffer.called)
+        self.assertFalse(self.protocol.buffer_updated.called)
+
+    def test_proto_type_switch(self):
+        self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+        transport = self.socket_transport()
+
+        self.sock.recv.return_value = b'data'
+        transport._read_ready()
+
+        self.protocol.data_received.assert_called_with(b'data')
+
+        # switch protocol to a BufferedProtocol
+
+        buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol)
+        buf = bytearray(4)
+        buf_proto.get_buffer.side_effect = lambda hint: buf
+
+        transport.set_protocol(buf_proto)
+
+        self.sock.recv_into.return_value = 10
+        transport._read_ready()
+
+        buf_proto.get_buffer.assert_called_with(-1)
+        buf_proto.buffer_updated.assert_called_with(10)
+
     def test_buffer_updated_error(self):
         transport = self.socket_transport()
         transport._fatal_error = mock.Mock()
@@ -1354,7 +1391,7 @@ class SelectorSocketTransportBufferedProtocolTests(test_utils.TestCase):
         self.sock.recv_into.return_value = 10
         transport._read_ready()
 
-        self.protocol.get_buffer.assert_called_with()
+        self.protocol.get_buffer.assert_called_with(-1)
         self.protocol.buffer_updated.assert_called_with(10)
 
     def test_read_ready_eof(self):
index c534a341352b00f86e7412d7fd9f848ebd4a0ed2..932487a9e3c639127a54956ab25324263f824af7 100644 (file)
@@ -1,8 +1,7 @@
 """Tests for asyncio/sslproto.py."""
 
-import os
 import logging
-import time
+import socket
 import unittest
 from unittest import mock
 try:
@@ -185,17 +184,67 @@ class SslProtoHandshakeTests(test_utils.TestCase):
 
 class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
 
+    PAYLOAD_SIZE = 1024 * 100
+    TIMEOUT = 60
+
     def new_loop(self):
         raise NotImplementedError
 
-    def test_start_tls_client_1(self):
-        HELLO_MSG = b'1' * 1024 * 1024
+    def test_buf_feed_data(self):
+
+        class Proto(asyncio.BufferedProtocol):
+
+            def __init__(self, bufsize, usemv):
+                self.buf = bytearray(bufsize)
+                self.mv = memoryview(self.buf)
+                self.data = b''
+                self.usemv = usemv
+
+            def get_buffer(self, sizehint):
+                if self.usemv:
+                    return self.mv
+                else:
+                    return self.buf
+
+            def buffer_updated(self, nsize):
+                if self.usemv:
+                    self.data += self.mv[:nsize]
+                else:
+                    self.data += self.buf[:nsize]
+
+        for usemv in [False, True]:
+            proto = Proto(1, usemv)
+            sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+            self.assertEqual(proto.data, b'12345')
+
+            proto = Proto(2, usemv)
+            sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+            self.assertEqual(proto.data, b'12345')
+
+            proto = Proto(2, usemv)
+            sslproto._feed_data_to_bufferred_proto(proto, b'1234')
+            self.assertEqual(proto.data, b'1234')
+
+            proto = Proto(4, usemv)
+            sslproto._feed_data_to_bufferred_proto(proto, b'1234')
+            self.assertEqual(proto.data, b'1234')
+
+            proto = Proto(100, usemv)
+            sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+            self.assertEqual(proto.data, b'12345')
+
+            proto = Proto(0, usemv)
+            with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
+                sslproto._feed_data_to_bufferred_proto(proto, b'12345')
+
+    def test_start_tls_client_reg_proto_1(self):
+        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
 
         server_context = test_utils.simple_server_sslcontext()
         client_context = test_utils.simple_client_sslcontext()
 
         def serve(sock):
-            sock.settimeout(5)
+            sock.settimeout(self.TIMEOUT)
 
             data = sock.recv_all(len(HELLO_MSG))
             self.assertEqual(len(data), len(HELLO_MSG))
@@ -205,6 +254,8 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
             sock.sendall(b'O')
             data = sock.recv_all(len(HELLO_MSG))
             self.assertEqual(len(data), len(HELLO_MSG))
+
+            sock.shutdown(socket.SHUT_RDWR)
             sock.close()
 
         class ClientProto(asyncio.Protocol):
@@ -246,17 +297,80 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
             self.loop.run_until_complete(
                 asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
 
+    def test_start_tls_client_buf_proto_1(self):
+        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+        server_context = test_utils.simple_server_sslcontext()
+        client_context = test_utils.simple_client_sslcontext()
+
+        def serve(sock):
+            sock.settimeout(self.TIMEOUT)
+
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+
+            sock.start_tls(server_context, server_side=True)
+
+            sock.sendall(b'O')
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+
+            sock.shutdown(socket.SHUT_RDWR)
+            sock.close()
+
+        class ClientProto(asyncio.BufferedProtocol):
+            def __init__(self, on_data, on_eof):
+                self.on_data = on_data
+                self.on_eof = on_eof
+                self.con_made_cnt = 0
+                self.buf = bytearray(1)
+
+            def connection_made(proto, tr):
+                proto.con_made_cnt += 1
+                # Ensure connection_made gets called only once.
+                self.assertEqual(proto.con_made_cnt, 1)
+
+            def get_buffer(self, sizehint):
+                return self.buf
+
+            def buffer_updated(self, nsize):
+                assert nsize == 1
+                self.on_data.set_result(bytes(self.buf[:nsize]))
+
+            def eof_received(self):
+                self.on_eof.set_result(True)
+
+        async def client(addr):
+            await asyncio.sleep(0.5, loop=self.loop)
+
+            on_data = self.loop.create_future()
+            on_eof = self.loop.create_future()
+
+            tr, proto = await self.loop.create_connection(
+                lambda: ClientProto(on_data, on_eof), *addr)
+
+            tr.write(HELLO_MSG)
+            new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+            self.assertEqual(await on_data, b'O')
+            new_tr.write(HELLO_MSG)
+            await on_eof
+
+            new_tr.close()
+
+        with self.tcp_server(serve) as srv:
+            self.loop.run_until_complete(
+                asyncio.wait_for(client(srv.addr),
+                                 loop=self.loop, timeout=self.TIMEOUT))
+
     def test_start_tls_server_1(self):
-        HELLO_MSG = b'1' * 1024 * 1024
+        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
 
         server_context = test_utils.simple_server_sslcontext()
         client_context = test_utils.simple_client_sslcontext()
-        # TODO: fix TLSv1.3 support
-        client_context.options |= ssl.OP_NO_TLSv1_3
 
         def client(sock, addr):
-            time.sleep(0.5)
-            sock.settimeout(5)
+            sock.settimeout(self.TIMEOUT)
 
             sock.connect(addr)
             data = sock.recv_all(len(HELLO_MSG))
@@ -264,12 +378,15 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
 
             sock.start_tls(client_context)
             sock.sendall(HELLO_MSG)
+
+            sock.shutdown(socket.SHUT_RDWR)
             sock.close()
 
         class ServerProto(asyncio.Protocol):
-            def __init__(self, on_con, on_eof):
+            def __init__(self, on_con, on_eof, on_con_lost):
                 self.on_con = on_con
                 self.on_eof = on_eof
+                self.on_con_lost = on_con_lost
                 self.data = b''
 
             def connection_made(self, tr):
@@ -281,7 +398,13 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
             def eof_received(self):
                 self.on_eof.set_result(1)
 
-        async def main():
+            def connection_lost(self, exc):
+                if exc is None:
+                    self.on_con_lost.set_result(None)
+                else:
+                    self.on_con_lost.set_exception(exc)
+
+        async def main(proto, on_con, on_eof, on_con_lost):
             tr = await on_con
             tr.write(HELLO_MSG)
 
@@ -292,24 +415,29 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
                 server_side=True)
 
             await on_eof
+            await on_con_lost
             self.assertEqual(proto.data, HELLO_MSG)
             new_tr.close()
 
-            server.close()
-            await server.wait_closed()
+        async def run_main():
+            on_con = self.loop.create_future()
+            on_eof = self.loop.create_future()
+            on_con_lost = self.loop.create_future()
+            proto = ServerProto(on_con, on_eof, on_con_lost)
 
-        on_con = self.loop.create_future()
-        on_eof = self.loop.create_future()
-        proto = ServerProto(on_con, on_eof)
+            server = await self.loop.create_server(
+                lambda: proto, '127.0.0.1', 0)
+            addr = server.sockets[0].getsockname()
 
-        server = self.loop.run_until_complete(
-            self.loop.create_server(
-                lambda: proto, '127.0.0.1', 0))
-        addr = server.sockets[0].getsockname()
+            with self.tcp_client(lambda sock: client(sock, addr)):
+                await asyncio.wait_for(
+                    main(proto, on_con, on_eof, on_con_lost),
+                    loop=self.loop, timeout=self.TIMEOUT)
 
-        with self.tcp_client(lambda sock: client(sock, addr)):
-            self.loop.run_until_complete(
-                asyncio.wait_for(main(), loop=self.loop, timeout=10))
+            server.close()
+            await server.wait_closed()
+
+        self.loop.run_until_complete(run_main())
 
     def test_start_tls_wrong_args(self):
         async def main():
@@ -332,7 +460,6 @@ class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
 
 @unittest.skipIf(ssl is None, 'No ssl module')
 @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
-@unittest.skipIf(os.environ.get('APPVEYOR'), 'XXX: issue 32458')
 class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
 
     def new_loop(self):
diff --git a/Misc/NEWS.d/next/Library/2018-05-26-13-09-34.bpo-33654.IbYWxA.rst b/Misc/NEWS.d/next/Library/2018-05-26-13-09-34.bpo-33654.IbYWxA.rst
new file mode 100644 (file)
index 0000000..3ae506d
--- /dev/null
@@ -0,0 +1,3 @@
+Fix transport.set_protocol() to support switching between asyncio.Protocol
+and asyncio.BufferedProtocol.  Fix loop.start_tls() to work with
+asyncio.BufferedProtocols.