From 2b57016458bc8a4400c03e204b431e2723a0e579 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 1 Nov 2013 14:18:02 -0700 Subject: [PATCH] asyncio: Refactor ssl transport ready loop (Nikolay Kim). --- Lib/asyncio/selector_events.py | 94 ++++++------ Lib/test/test_asyncio/test_selector_events.py | 134 +++++++++++------- 2 files changed, 136 insertions(+), 92 deletions(-) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 44430b2168..a975dbb78e 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -286,7 +286,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: # Jump to the except clause below. - raise OSError(err, 'Connect call failed') + raise OSError(err, 'Connect call failed %s' % (address,)) except (BlockingIOError, InterruptedError): self.add_writer(fd, self._sock_connect, fut, True, sock, address) except Exception as exc: @@ -413,7 +413,7 @@ class _SelectorTransport(transports.Transport): try: self._protocol.pause_writing() except Exception: - tulip_log.exception('pause_writing() failed') + logger.exception('pause_writing() failed') def _maybe_resume_protocol(self): if (self._protocol_paused and @@ -422,7 +422,7 @@ class _SelectorTransport(transports.Transport): try: self._protocol.resume_writing() except Exception: - tulip_log.exception('resume_writing() failed') + logger.exception('resume_writing() failed') def set_write_buffer_limits(self, high=None, low=None): if high is None: @@ -635,15 +635,16 @@ class _SelectorSslTransport(_SelectorTransport): compression=self._sock.compression(), ) - self._loop.add_reader(self._sock_fd, self._on_ready) - self._loop.add_writer(self._sock_fd, self._on_ready) + self._read_wants_write = False + self._write_wants_read = False + self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) def pause_reading(self): # XXX This is a bit icky, given the comment at the top of - # _on_ready(). Is it possible to evoke a deadlock? I don't + # _read_ready(). Is it possible to evoke a deadlock? I don't # know, although it doesn't look like it; write() will still # accept more data for the buffer and eventually the app will # call resume_reading() again, and things will flow again. @@ -658,41 +659,55 @@ class _SelectorSslTransport(_SelectorTransport): self._paused = False if self._closing: return - self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_reader(self._sock_fd, self._read_ready) - def _on_ready(self): - # Because of renegotiations (?), there's no difference between - # readable and writable. We just try both. XXX This may be - # incorrect; we probably need to keep state about what we - # should do next. + def _read_ready(self): + if self._write_wants_read: + self._write_wants_read = False + self._write_ready() - # First try reading. - if not self._closing and not self._paused: - try: - data = self._sock.recv(self.max_size) - except (BlockingIOError, InterruptedError, - ssl.SSLWantReadError, ssl.SSLWantWriteError): - pass - except Exception as exc: - self._fatal_error(exc) + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + pass + except ssl.SSLWantWriteError: + self._read_wants_write = True + self._loop.remove_reader(self._sock_fd) + self._loop.add_writer(self._sock_fd, self._write_ready) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) else: - if data: - self._protocol.data_received(data) - else: - try: - self._protocol.eof_received() - finally: - self.close() + try: + self._protocol.eof_received() + finally: + self.close() + + def _write_ready(self): + if self._read_wants_write: + self._read_wants_write = False + self._read_ready() + + if not (self._paused or self._closing): + self._loop.add_reader(self._sock_fd, self._read_ready) - # Now try writing, if there's anything to write. if self._buffer: data = b''.join(self._buffer) self._buffer.clear() try: n = self._sock.send(data) except (BlockingIOError, InterruptedError, - ssl.SSLWantReadError, ssl.SSLWantWriteError): + ssl.SSLWantWriteError): n = 0 + except ssl.SSLWantReadError: + n = 0 + self._loop.remove_writer(self._sock_fd) + self._write_wants_read = True except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) @@ -701,11 +716,12 @@ class _SelectorSslTransport(_SelectorTransport): if n < len(data): self._buffer.append(data[n:]) - self._maybe_resume_protocol() # May append to buffer. + self._maybe_resume_protocol() # May append to buffer. - if self._closing and not self._buffer: + if not self._buffer: self._loop.remove_writer(self._sock_fd) - self._call_connection_lost(None) + if self._closing: + self._call_connection_lost(None) def write(self, data): assert isinstance(data, bytes), repr(type(data)) @@ -718,20 +734,16 @@ class _SelectorSslTransport(_SelectorTransport): self._conn_lost += 1 return - # We could optimize, but the callback can do this for now. + if not self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. self._buffer.append(data) self._maybe_pause_protocol() def can_write_eof(self): return False - def close(self): - if self._closing: - return - self._closing = True - self._conn_lost += 1 - self._loop.remove_reader(self._sock_fd) - class _SelectorDatagramTransport(_SelectorTransport): diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index fbd5d723d5..3b8238d557 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -1003,8 +1003,7 @@ class SelectorSslTransportTests(unittest.TestCase): self.loop, self.sock, self.protocol, self.sslcontext, waiter=waiter) self.assertTrue(self.sslsock.do_handshake.called) - self.loop.assert_reader(1, tr._on_ready) - self.loop.assert_writer(1, tr._on_ready) + self.loop.assert_reader(1, tr._read_ready) test_utils.run_briefly(self.loop) self.assertIsNone(waiter.result()) @@ -1047,13 +1046,13 @@ class SelectorSslTransportTests(unittest.TestCase): def test_pause_resume_reading(self): tr = self._make_one() self.assertFalse(tr._paused) - self.loop.assert_reader(1, tr._on_ready) + self.loop.assert_reader(1, tr._read_ready) tr.pause_reading() self.assertTrue(tr._paused) self.assertFalse(1 in self.loop.readers) tr.resume_reading() self.assertFalse(tr._paused) - self.loop.assert_reader(1, tr._on_ready) + self.loop.assert_reader(1, tr._read_ready) def test_write_no_data(self): transport = self._make_one() @@ -1084,140 +1083,173 @@ class SelectorSslTransportTests(unittest.TestCase): transport.write(b'data') m_log.warning.assert_called_with('socket.send() raised exception.') - def test_on_ready_recv(self): + def test_read_ready_recv(self): self.sslsock.recv.return_value = b'data' transport = self._make_one() - transport._on_ready() + transport._read_ready() self.assertTrue(self.sslsock.recv.called) self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) - def test_on_ready_recv_eof(self): + def test_read_ready_write_wants_read(self): + self.loop.add_writer = unittest.mock.Mock() + self.sslsock.recv.side_effect = BlockingIOError + transport = self._make_one() + transport._write_wants_read = True + transport._write_ready = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._read_ready() + + self.assertFalse(transport._write_wants_read) + transport._write_ready.assert_called_with() + self.loop.add_writer.assert_called_with( + transport._sock_fd, transport._write_ready) + + def test_read_ready_recv_eof(self): self.sslsock.recv.return_value = b'' transport = self._make_one() transport.close = unittest.mock.Mock() - transport._on_ready() + transport._read_ready() transport.close.assert_called_with() self.protocol.eof_received.assert_called_with() - def test_on_ready_recv_conn_reset(self): + def test_read_ready_recv_conn_reset(self): err = self.sslsock.recv.side_effect = ConnectionResetError() transport = self._make_one() transport._force_close = unittest.mock.Mock() - transport._on_ready() + transport._read_ready() transport._force_close.assert_called_with(err) - def test_on_ready_recv_retry(self): + def test_read_ready_recv_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError transport = self._make_one() - transport._on_ready() + transport._read_ready() self.assertTrue(self.sslsock.recv.called) self.assertFalse(self.protocol.data_received.called) - self.sslsock.recv.side_effect = ssl.SSLWantWriteError - transport._on_ready() - self.assertFalse(self.protocol.data_received.called) - self.sslsock.recv.side_effect = BlockingIOError - transport._on_ready() + transport._read_ready() self.assertFalse(self.protocol.data_received.called) self.sslsock.recv.side_effect = InterruptedError - transport._on_ready() + transport._read_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_read_ready_recv_write(self): + self.loop.remove_reader = unittest.mock.Mock() + self.loop.add_writer = unittest.mock.Mock() + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + transport = self._make_one() + transport._read_ready() self.assertFalse(self.protocol.data_received.called) + self.assertTrue(transport._read_wants_write) - def test_on_ready_recv_exc(self): + self.loop.remove_reader.assert_called_with(transport._sock_fd) + self.loop.add_writer.assert_called_with( + transport._sock_fd, transport._write_ready) + + def test_read_ready_recv_exc(self): err = self.sslsock.recv.side_effect = OSError() transport = self._make_one() transport._fatal_error = unittest.mock.Mock() - transport._on_ready() + transport._read_ready() transport._fatal_error.assert_called_with(err) - def test_on_ready_send(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport._buffer = collections.deque([b'data']) - transport._on_ready() + transport._write_ready() self.assertEqual(collections.deque(), transport._buffer) self.assertTrue(self.sslsock.send.called) - def test_on_ready_send_none(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_none(self): self.sslsock.send.return_value = 0 transport = self._make_one() transport._buffer = collections.deque([b'data1', b'data2']) - transport._on_ready() + transport._write_ready() self.assertTrue(self.sslsock.send.called) self.assertEqual(collections.deque([b'data1data2']), transport._buffer) - def test_on_ready_send_partial(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_partial(self): self.sslsock.send.return_value = 2 transport = self._make_one() transport._buffer = collections.deque([b'data1', b'data2']) - transport._on_ready() + transport._write_ready() self.assertTrue(self.sslsock.send.called) self.assertEqual(collections.deque([b'ta1data2']), transport._buffer) - def test_on_ready_send_closing_partial(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_closing_partial(self): self.sslsock.send.return_value = 2 transport = self._make_one() transport._buffer = collections.deque([b'data1', b'data2']) - transport._on_ready() + transport._write_ready() self.assertTrue(self.sslsock.send.called) self.assertFalse(self.sslsock.close.called) - def test_on_ready_send_closing(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_closing(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() transport._buffer = collections.deque([b'data']) - transport._on_ready() + transport._write_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) - def test_on_ready_send_closing_empty_buffer(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_closing_empty_buffer(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() transport._buffer = collections.deque() - transport._on_ready() + transport._write_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) - def test_on_ready_send_retry(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError - + def test_write_ready_send_retry(self): transport = self._make_one() transport._buffer = collections.deque([b'data']) - self.sslsock.send.side_effect = ssl.SSLWantReadError - transport._on_ready() - self.assertTrue(self.sslsock.send.called) - self.assertEqual(collections.deque([b'data']), transport._buffer) - self.sslsock.send.side_effect = ssl.SSLWantWriteError - transport._on_ready() + transport._write_ready() self.assertEqual(collections.deque([b'data']), transport._buffer) self.sslsock.send.side_effect = BlockingIOError() - transport._on_ready() + transport._write_ready() self.assertEqual(collections.deque([b'data']), transport._buffer) - def test_on_ready_send_exc(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_read(self): + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + + self.loop.remove_writer = unittest.mock.Mock() + self.sslsock.send.side_effect = ssl.SSLWantReadError + transport._write_ready() + self.assertFalse(self.protocol.data_received.called) + self.assertTrue(transport._write_wants_read) + self.loop.remove_writer.assert_called_with(transport._sock_fd) + + def test_write_ready_send_exc(self): err = self.sslsock.send.side_effect = OSError() transport = self._make_one() transport._buffer = collections.deque([b'data']) transport._fatal_error = unittest.mock.Mock() - transport._on_ready() + transport._write_ready() transport._fatal_error.assert_called_with(err) self.assertEqual(collections.deque(), transport._buffer) + def test_write_ready_read_wants_write(self): + self.loop.add_reader = unittest.mock.Mock() + self.sslsock.send.side_effect = BlockingIOError + transport = self._make_one() + transport._read_wants_write = True + transport._read_ready = unittest.mock.Mock() + transport._write_ready() + + self.assertFalse(transport._read_wants_write) + transport._read_ready.assert_called_with() + self.loop.add_reader.assert_called_with( + transport._sock_fd, transport._read_ready) + def test_write_eof(self): tr = self._make_one() self.assertFalse(tr.can_write_eof()) -- 2.40.0