]> granicus.if.org Git - python/commitdiff
asyncio: Improve default writelines().
authorGuido van Rossum <guido@python.org>
Tue, 3 Dec 2013 02:36:30 +0000 (18:36 -0800)
committerGuido van Rossum <guido@python.org>
Tue, 3 Dec 2013 02:36:30 +0000 (18:36 -0800)
Lib/asyncio/transports.py
Lib/test/test_asyncio/test_transports.py

index 86b850e918a94677f6ae794e96bad61771ef2c88..c2feb93d0ac513e646c4df66c0516aa0160b6e75 100644 (file)
@@ -1,5 +1,9 @@
 """Abstract Transport class."""
 
+import sys
+
+PY34 = sys.version_info >= (3, 4)
+
 __all__ = ['ReadTransport', 'WriteTransport', 'Transport']
 
 
@@ -85,11 +89,15 @@ class WriteTransport(BaseTransport):
     def writelines(self, list_of_data):
         """Write a list (or any iterable) of data bytes to the transport.
 
-        The default implementation just calls write() for each item in
-        the list/iterable.
+        The default implementation concatenates the arguments and
+        calls write() on the result.
         """
-        for data in list_of_data:
-            self.write(data)
+        if not PY34:
+            # In Python 3.3, bytes.join() doesn't handle memoryview.
+            list_of_data = (
+                bytes(data) if isinstance(data, memoryview) else data
+                for data in list_of_data)
+        self.write(b''.join(list_of_data))
 
     def write_eof(self):
         """Close the write end after flushing buffered data.
index f96445c19c85eef5f1aca57e5348a456ac39a93f..29393b527df851dbe93bec4fcf4463b809155615 100644 (file)
@@ -24,12 +24,18 @@ class TransportTests(unittest.TestCase):
         transport = transports.Transport()
         transport.write = unittest.mock.Mock()
 
-        transport.writelines(['line1', 'line2', 'line3'])
-        self.assertEqual(3, transport.write.call_count)
+        transport.writelines([b'line1',
+                              bytearray(b'line2'),
+                              memoryview(b'line3')])
+        self.assertEqual(1, transport.write.call_count)
+        transport.write.assert_called_with(b'line1line2line3')
 
     def test_not_implemented(self):
         transport = transports.Transport()
 
+        self.assertRaises(NotImplementedError,
+                          transport.set_write_buffer_limits)
+        self.assertRaises(NotImplementedError, transport.get_write_buffer_size)
         self.assertRaises(NotImplementedError, transport.write, 'data')
         self.assertRaises(NotImplementedError, transport.write_eof)
         self.assertRaises(NotImplementedError, transport.can_write_eof)