asyncio: WriteTransport.set_write_buffer_size to call _maybe_pause_protocol
authorYury Selivanov <yselivanov@sprymix.com>
Wed, 19 Feb 2014 16:10:52 +0000 (11:10 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Wed, 19 Feb 2014 16:10:52 +0000 (11:10 -0500)
Lib/asyncio/transports.py
Lib/test/test_asyncio/test_transports.py

index 5b975aa7315c269ade88a5495aede2e1d67c46b7..5f674f99d773f2a09252c7a27de89ab0e6dc318b 100644 (file)
@@ -241,7 +241,7 @@ class _FlowControlMixin(Transport):
     def __init__(self, extra=None):
         super().__init__(extra)
         self._protocol_paused = False
-        self.set_write_buffer_limits()
+        self._set_write_buffer_limits()
 
     def _maybe_pause_protocol(self):
         size = self.get_write_buffer_size()
@@ -273,7 +273,7 @@ class _FlowControlMixin(Transport):
                     'protocol': self._protocol,
                 })
 
-    def set_write_buffer_limits(self, high=None, low=None):
+    def _set_write_buffer_limits(self, high=None, low=None):
         if high is None:
             if low is None:
                 high = 64*1024
@@ -287,5 +287,9 @@ class _FlowControlMixin(Transport):
         self._high_water = high
         self._low_water = low
 
+    def set_write_buffer_limits(self, high=None, low=None):
+        self._set_write_buffer_limits(high=high, low=low)
+        self._maybe_pause_protocol()
+
     def get_write_buffer_size(self):
         raise NotImplementedError
index d16db8074b6203ced306ee669a6a4dbaff13b6c6..4c645268d945a788037572bffc2b81791f8d67c6 100644 (file)
@@ -4,6 +4,7 @@ import unittest
 import unittest.mock
 
 import asyncio
+from asyncio import transports
 
 
 class TransportTests(unittest.TestCase):
@@ -60,6 +61,28 @@ class TransportTests(unittest.TestCase):
         self.assertRaises(NotImplementedError, transport.terminate)
         self.assertRaises(NotImplementedError, transport.kill)
 
+    def test_flowcontrol_mixin_set_write_limits(self):
+
+        class MyTransport(transports._FlowControlMixin,
+                          transports.Transport):
+
+            def get_write_buffer_size(self):
+                return 512
+
+        transport = MyTransport()
+        transport._protocol = unittest.mock.Mock()
+
+        self.assertFalse(transport._protocol_paused)
+
+        with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
+            transport.set_write_buffer_limits(high=0, low=1)
+
+        transport.set_write_buffer_limits(high=1024, low=128)
+        self.assertFalse(transport._protocol_paused)
+
+        transport.set_write_buffer_limits(high=256, low=128)
+        self.assertTrue(transport._protocol_paused)
+
 
 if __name__ == '__main__':
     unittest.main()