]> granicus.if.org Git - python/commitdiff
asyncio: Write flow control for proactor event loop.
authorGuido van Rossum <guido@python.org>
Wed, 4 Dec 2013 20:12:07 +0000 (12:12 -0800)
committerGuido van Rossum <guido@python.org>
Wed, 4 Dec 2013 20:12:07 +0000 (12:12 -0800)
Lib/asyncio/proactor_events.py
Lib/test/test_asyncio/test_proactor_events.py

index ce226b9ba50d14c48dba036d41432eff8787411a..979bc25fed99ced86890b60521f1f9a097f6a358 100644 (file)
@@ -24,12 +24,14 @@ class _ProactorBasePipeTransport(transports.BaseTransport):
         self._sock = sock
         self._protocol = protocol
         self._server = server
-        self._buffer = []
+        self._buffer = None  # None or bytearray.
         self._read_fut = None
         self._write_fut = None
         self._conn_lost = 0
         self._closing = False  # Set when close() called.
         self._eof_written = False
+        self._protocol_paused = False
+        self.set_write_buffer_limits()
         if self._server is not None:
             self._server.attach(self)
         self._loop.call_soon(self._protocol.connection_made, self)
@@ -63,7 +65,7 @@ class _ProactorBasePipeTransport(transports.BaseTransport):
         if self._read_fut:
             self._read_fut.cancel()
         self._write_fut = self._read_fut = None
-        self._buffer = []
+        self._buffer = None
         self._loop.call_soon(self._call_connection_lost, exc)
 
     def _call_connection_lost(self, exc):
@@ -82,6 +84,53 @@ class _ProactorBasePipeTransport(transports.BaseTransport):
                 server.detach(self)
                 self._server = None
 
+    # XXX The next four methods are nearly identical to corresponding
+    # ones in _SelectorTransport.  Maybe refactor buffer management to
+    # share the implementations?  (Also these are really only needed
+    # by _ProactorWritePipeTransport but since _buffer is defined on
+    # the base class I am putting it here for now.)
+
+    def _maybe_pause_protocol(self):
+        size = self.get_write_buffer_size()
+        if size <= self._high_water:
+            return
+        if not self._protocol_paused:
+            self._protocol_paused = True
+            try:
+                self._protocol.pause_writing()
+            except Exception:
+                logger.exception('pause_writing() failed')
+
+    def _maybe_resume_protocol(self):
+        if (self._protocol_paused and
+            self.get_write_buffer_size() <= self._low_water):
+            self._protocol_paused = False
+            try:
+                self._protocol.resume_writing()
+            except Exception:
+                logger.exception('resume_writing() failed')
+
+    def set_write_buffer_limits(self, high=None, low=None):
+        if high is None:
+            if low is None:
+                high = 64*1024
+            else:
+                high = 4*low
+        if low is None:
+            low = high // 4
+        if not high >= low >= 0:
+            raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
+                             (high, low))
+        self._high_water = high
+        self._low_water = low
+
+    def get_write_buffer_size(self):
+        # NOTE: This doesn't take into account data already passed to
+        # send() even if send() hasn't finished yet.
+        if not self._buffer:
+            return 0
+        return len(self._buffer)
+
 
 class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
                                  transports.ReadTransport):
@@ -95,12 +144,15 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
         self._loop.call_soon(self._loop_reading)
 
     def pause_reading(self):
-        assert not self._closing, 'Cannot pause_reading() when closing'
-        assert not self._paused, 'Already paused'
+        if self._closing:
+            raise RuntimeError('Cannot pause_reading() when closing')
+        if self._paused:
+            raise RuntimeError('Already paused')
         self._paused = True
 
     def resume_reading(self):
-        assert self._paused, 'Not paused'
+        if not self._paused:
+            raise RuntimeError('Not paused')
         self._paused = False
         if self._closing:
             return
@@ -155,9 +207,11 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport,
     """Transport for write pipes."""
 
     def write(self, data):
-        assert isinstance(data, bytes), repr(data)
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError('data argument must be byte-ish (%r)',
+                            type(data))
         if self._eof_written:
-            raise IOError('write_eof() already called')
+            raise RuntimeError('write_eof() already called')
 
         if not data:
             return
@@ -167,26 +221,53 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport,
                 logger.warning('socket.send() raised exception.')
             self._conn_lost += 1
             return
-        self._buffer.append(data)
-        if self._write_fut is None:
-            self._loop_writing()
 
-    def _loop_writing(self, f=None):
+        # Observable states:
+        # 1. IDLE: _write_fut and _buffer both None
+        # 2. WRITING: _write_fut set; _buffer None
+        # 3. BACKED UP: _write_fut set; _buffer a bytearray
+        # We always copy the data, so the caller can't modify it
+        # while we're still waiting for the I/O to happen.
+        if self._write_fut is None:  # IDLE -> WRITING
+            assert self._buffer is None
+            # Pass a copy, except if it's already immutable.
+            self._loop_writing(data=bytes(data))
+            # XXX Should we pause the protocol at this point
+            # if len(data) > self._high_water?  (That would
+            # require keeping track of the number of bytes passed
+            # to a send() that hasn't finished yet.)
+        elif not self._buffer:  # WRITING -> BACKED UP
+            # Make a mutable copy which we can extend.
+            self._buffer = bytearray(data)
+            self._maybe_pause_protocol()
+        else:  # BACKED UP
+            # Append to buffer (also copies).
+            self._buffer.extend(data)
+            self._maybe_pause_protocol()
+
+    def _loop_writing(self, f=None, data=None):
         try:
             assert f is self._write_fut
             self._write_fut = None
             if f:
                 f.result()
-            data = b''.join(self._buffer)
-            self._buffer = []
+            if data is None:
+                data = self._buffer
+                self._buffer = None
             if not data:
                 if self._closing:
                     self._loop.call_soon(self._call_connection_lost, None)
                 if self._eof_written:
                     self._sock.shutdown(socket.SHUT_WR)
-                return
-            self._write_fut = self._loop._proactor.send(self._sock, data)
-            self._write_fut.add_done_callback(self._loop_writing)
+            else:
+                self._write_fut = self._loop._proactor.send(self._sock, data)
+                self._write_fut.add_done_callback(self._loop_writing)
+            # Now that we've reduced the buffer size, tell the
+            # protocol to resume writing if it was paused.  Note that
+            # we do this last since the callback is called immediately
+            # and it may add more data to the buffer (even causing the
+            # protocol to be paused again).
+            self._maybe_resume_protocol()
         except ConnectionResetError as exc:
             self._force_close(exc)
         except OSError as exc:
@@ -330,7 +411,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
         self._csock.send(b'x')
 
     def _start_serving(self, protocol_factory, sock, ssl=None, server=None):
-        assert not ssl, 'IocpEventLoop is incompatible with SSL.'
+        if ssl:
+            raise ValueError('IocpEventLoop is incompatible with SSL.')
 
         def loop(f=None):
             try:
index 5a2a51c42e6e8246a8cfde953315b49f5cf85dff..9964f425d21d9d7a2cd975b004cd35ce5ba0c4fd 100644 (file)
@@ -111,8 +111,8 @@ class ProactorSocketTransportTests(unittest.TestCase):
         tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
         tr._loop_writing = unittest.mock.Mock()
         tr.write(b'data')
-        self.assertEqual(tr._buffer, [b'data'])
-        self.assertTrue(tr._loop_writing.called)
+        self.assertEqual(tr._buffer, None)
+        tr._loop_writing.assert_called_with(data=b'data')
 
     def test_write_no_data(self):
         tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
@@ -124,12 +124,12 @@ class ProactorSocketTransportTests(unittest.TestCase):
         tr._write_fut = unittest.mock.Mock()
         tr._loop_writing = unittest.mock.Mock()
         tr.write(b'data')
-        self.assertEqual(tr._buffer, [b'data'])
+        self.assertEqual(tr._buffer, b'data')
         self.assertFalse(tr._loop_writing.called)
 
     def test_loop_writing(self):
         tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
-        tr._buffer = [b'da', b'ta']
+        tr._buffer = bytearray(b'data')
         tr._loop_writing()
         self.loop._proactor.send.assert_called_with(self.sock, b'data')
         self.loop._proactor.send.return_value.add_done_callback.\
@@ -150,7 +150,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
         tr.write(b'data')
         tr.write(b'data')
         tr.write(b'data')
-        self.assertEqual(tr._buffer, [])
+        self.assertEqual(tr._buffer, None)
         m_log.warning.assert_called_with('socket.send() raised exception.')
 
     def test_loop_writing_stop(self):
@@ -226,7 +226,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
         write_fut.cancel.assert_called_with()
         test_utils.run_briefly(self.loop)
         self.protocol.connection_lost.assert_called_with(None)
-        self.assertEqual([], tr._buffer)
+        self.assertEqual(None, tr._buffer)
         self.assertEqual(tr._conn_lost, 1)
 
     def test_force_close_idempotent(self):
@@ -243,7 +243,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
 
         test_utils.run_briefly(self.loop)
         self.protocol.connection_lost.assert_called_with(None)
-        self.assertEqual([], tr._buffer)
+        self.assertEqual(None, tr._buffer)
 
     def test_call_connection_lost(self):
         tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)