]> granicus.if.org Git - python/commitdiff
asyncio.streams: Use bytebuffer in StreamReader; Add assertion in feed_data
authorYury Selivanov <yselivanov@sprymix.com>
Wed, 5 Feb 2014 23:11:13 +0000 (18:11 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Wed, 5 Feb 2014 23:11:13 +0000 (18:11 -0500)
Lib/asyncio/streams.py
Lib/test/test_asyncio/test_streams.py

index 06f052a2dc16bfae4b4f3957eed30df804fd491b..3da1d10facb3b4610d173da53128e34684d68764 100644 (file)
@@ -4,8 +4,6 @@ __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
            'open_connection', 'start_server', 'IncompleteReadError',
            ]
 
-import collections
-
 from . import events
 from . import futures
 from . import protocols
@@ -259,9 +257,7 @@ class StreamReader:
         if loop is None:
             loop = events.get_event_loop()
         self._loop = loop
-        # TODO: Use a bytearray for a buffer, like the transport.
-        self._buffer = collections.deque()  # Deque of bytes objects.
-        self._byte_count = 0  # Bytes in buffer.
+        self._buffer = bytearray()
         self._eof = False  # Whether we're done.
         self._waiter = None  # A future.
         self._exception = None
@@ -285,7 +281,7 @@ class StreamReader:
         self._transport = transport
 
     def _maybe_resume_transport(self):
-        if self._paused and self._byte_count <= self._limit:
+        if self._paused and len(self._buffer) <= self._limit:
             self._paused = False
             self._transport.resume_reading()
 
@@ -298,11 +294,12 @@ class StreamReader:
                 waiter.set_result(True)
 
     def feed_data(self, data):
+        assert not self._eof, 'feed_data after feed_eof'
+
         if not data:
             return
 
-        self._buffer.append(data)
-        self._byte_count += len(data)
+        self._buffer.extend(data)
 
         waiter = self._waiter
         if waiter is not None:
@@ -312,7 +309,7 @@ class StreamReader:
 
         if (self._transport is not None and
             not self._paused and
-            self._byte_count > 2*self._limit):
+            len(self._buffer) > 2*self._limit):
             try:
                 self._transport.pause_reading()
             except NotImplementedError:
@@ -338,28 +335,22 @@ class StreamReader:
         if self._exception is not None:
             raise self._exception
 
-        parts = []
-        parts_size = 0
+        line = bytearray()
         not_enough = True
 
         while not_enough:
             while self._buffer and not_enough:
-                data = self._buffer.popleft()
-                ichar = data.find(b'\n')
+                ichar = self._buffer.find(b'\n')
                 if ichar < 0:
-                    parts.append(data)
-                    parts_size += len(data)
+                    line.extend(self._buffer)
+                    self._buffer.clear()
                 else:
                     ichar += 1
-                    head, tail = data[:ichar], data[ichar:]
-                    if tail:
-                        self._buffer.appendleft(tail)
+                    line.extend(self._buffer[:ichar])
+                    del self._buffer[:ichar]
                     not_enough = False
-                    parts.append(head)
-                    parts_size += len(head)
 
-                if parts_size > self._limit:
-                    self._byte_count -= parts_size
+                if len(line) > self._limit:
                     self._maybe_resume_transport()
                     raise ValueError('Line is too long')
 
@@ -373,11 +364,8 @@ class StreamReader:
                 finally:
                     self._waiter = None
 
-        line = b''.join(parts)
-        self._byte_count -= parts_size
         self._maybe_resume_transport()
-
-        return line
+        return bytes(line)
 
     @tasks.coroutine
     def read(self, n=-1):
@@ -395,36 +383,23 @@ class StreamReader:
                 finally:
                     self._waiter = None
         else:
-            if not self._byte_count and not self._eof:
+            if not self._buffer and not self._eof:
                 self._waiter = self._create_waiter('read')
                 try:
                     yield from self._waiter
                 finally:
                     self._waiter = None
 
-        if n < 0 or self._byte_count <= n:
-            data = b''.join(self._buffer)
+        if n < 0 or len(self._buffer) <= n:
+            data = bytes(self._buffer)
             self._buffer.clear()
-            self._byte_count = 0
-            self._maybe_resume_transport()
-            return data
-
-        parts = []
-        parts_bytes = 0
-        while self._buffer and parts_bytes < n:
-            data = self._buffer.popleft()
-            data_bytes = len(data)
-            if n < parts_bytes + data_bytes:
-                data_bytes = n - parts_bytes
-                data, rest = data[:data_bytes], data[data_bytes:]
-                self._buffer.appendleft(rest)
-
-            parts.append(data)
-            parts_bytes += data_bytes
-            self._byte_count -= data_bytes
-            self._maybe_resume_transport()
-
-        return b''.join(parts)
+        else:
+            # n > 0 and len(self._buffer) > n
+            data = bytes(self._buffer[:n])
+            del self._buffer[:n]
+
+        self._maybe_resume_transport()
+        return data
 
     @tasks.coroutine
     def readexactly(self, n):
index 01d565cd8b324ccb8a307b0bc47570dab0c7c6bc..83474a87eee962aea5cf84fdee5ff52d66db60f0 100644 (file)
@@ -79,13 +79,13 @@ class StreamReaderTests(unittest.TestCase):
         stream = asyncio.StreamReader(loop=self.loop)
 
         stream.feed_data(b'')
-        self.assertEqual(0, stream._byte_count)
+        self.assertEqual(b'', stream._buffer)
 
-    def test_feed_data_byte_count(self):
+    def test_feed_nonempty_data(self):
         stream = asyncio.StreamReader(loop=self.loop)
 
         stream.feed_data(self.DATA)
-        self.assertEqual(len(self.DATA), stream._byte_count)
+        self.assertEqual(self.DATA, stream._buffer)
 
     def test_read_zero(self):
         # Read zero bytes.
@@ -94,7 +94,7 @@ class StreamReaderTests(unittest.TestCase):
 
         data = self.loop.run_until_complete(stream.read(0))
         self.assertEqual(b'', data)
-        self.assertEqual(len(self.DATA), stream._byte_count)
+        self.assertEqual(self.DATA, stream._buffer)
 
     def test_read(self):
         # Read bytes.
@@ -107,7 +107,7 @@ class StreamReaderTests(unittest.TestCase):
 
         data = self.loop.run_until_complete(read_task)
         self.assertEqual(self.DATA, data)
-        self.assertFalse(stream._byte_count)
+        self.assertEqual(b'', stream._buffer)
 
     def test_read_line_breaks(self):
         # Read bytes without line breaks.
@@ -118,7 +118,7 @@ class StreamReaderTests(unittest.TestCase):
         data = self.loop.run_until_complete(stream.read(5))
 
         self.assertEqual(b'line1', data)
-        self.assertEqual(5, stream._byte_count)
+        self.assertEqual(b'line2', stream._buffer)
 
     def test_read_eof(self):
         # Read bytes, stop at eof.
@@ -131,7 +131,7 @@ class StreamReaderTests(unittest.TestCase):
 
         data = self.loop.run_until_complete(read_task)
         self.assertEqual(b'', data)
-        self.assertFalse(stream._byte_count)
+        self.assertEqual(b'', stream._buffer)
 
     def test_read_until_eof(self):
         # Read all bytes until eof.
@@ -147,7 +147,7 @@ class StreamReaderTests(unittest.TestCase):
         data = self.loop.run_until_complete(read_task)
 
         self.assertEqual(b'chunk1\nchunk2', data)
-        self.assertFalse(stream._byte_count)
+        self.assertEqual(b'', stream._buffer)
 
     def test_read_exception(self):
         stream = asyncio.StreamReader(loop=self.loop)
@@ -161,7 +161,8 @@ class StreamReaderTests(unittest.TestCase):
             ValueError, self.loop.run_until_complete, stream.read(2))
 
     def test_readline(self):
-        # Read one line.
+        # Read one line. 'readline' will need to wait for the data
+        # to come from 'cb'
         stream = asyncio.StreamReader(loop=self.loop)
         stream.feed_data(b'chunk1 ')
         read_task = asyncio.Task(stream.readline(), loop=self.loop)
@@ -174,30 +175,40 @@ class StreamReaderTests(unittest.TestCase):
 
         line = self.loop.run_until_complete(read_task)
         self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
-        self.assertEqual(len(b'\n chunk4')-1, stream._byte_count)
+        self.assertEqual(b' chunk4', stream._buffer)
 
     def test_readline_limit_with_existing_data(self):
-        stream = asyncio.StreamReader(3, loop=self.loop)
+        # Read one line. The data is in StreamReader's buffer
+        # before the event loop is run.
+
+        stream = asyncio.StreamReader(limit=3, loop=self.loop)
         stream.feed_data(b'li')
         stream.feed_data(b'ne1\nline2\n')
 
         self.assertRaises(
             ValueError, self.loop.run_until_complete, stream.readline())
-        self.assertEqual([b'line2\n'], list(stream._buffer))
+        # The buffer should contain the remaining data after exception
+        self.assertEqual(b'line2\n', stream._buffer)
 
-        stream = asyncio.StreamReader(3, loop=self.loop)
+        stream = asyncio.StreamReader(limit=3, loop=self.loop)
         stream.feed_data(b'li')
         stream.feed_data(b'ne1')
         stream.feed_data(b'li')
 
         self.assertRaises(
             ValueError, self.loop.run_until_complete, stream.readline())
-        self.assertEqual([b'li'], list(stream._buffer))
-        self.assertEqual(2, stream._byte_count)
+        # No b'\n' at the end. The 'limit' is set to 3. So before
+        # waiting for the new data in buffer, 'readline' will consume
+        # the entire buffer, and since the length of the consumed data
+        # is more than 3, it will raise a ValudError. The buffer is
+        # expected to be empty now.
+        self.assertEqual(b'', stream._buffer)
 
     def test_readline_limit(self):
-        stream = asyncio.StreamReader(7, loop=self.loop)
+        # Read one line. StreamReaders are fed with data after
+        # their 'readline' methods are called.
 
+        stream = asyncio.StreamReader(limit=7, loop=self.loop)
         def cb():
             stream.feed_data(b'chunk1')
             stream.feed_data(b'chunk2')
@@ -207,10 +218,25 @@ class StreamReaderTests(unittest.TestCase):
 
         self.assertRaises(
             ValueError, self.loop.run_until_complete, stream.readline())
-        self.assertEqual([b'chunk3\n'], list(stream._buffer))
-        self.assertEqual(7, stream._byte_count)
+        # The buffer had just one line of data, and after raising
+        # a ValueError it should be empty.
+        self.assertEqual(b'', stream._buffer)
+
+        stream = asyncio.StreamReader(limit=7, loop=self.loop)
+        def cb():
+            stream.feed_data(b'chunk1')
+            stream.feed_data(b'chunk2\n')
+            stream.feed_data(b'chunk3\n')
+            stream.feed_eof()
+        self.loop.call_soon(cb)
+
+        self.assertRaises(
+            ValueError, self.loop.run_until_complete, stream.readline())
+        self.assertEqual(b'chunk3\n', stream._buffer)
 
-    def test_readline_line_byte_count(self):
+    def test_readline_nolimit_nowait(self):
+        # All needed data for the first 'readline' call will be
+        # in the buffer.
         stream = asyncio.StreamReader(loop=self.loop)
         stream.feed_data(self.DATA[:6])
         stream.feed_data(self.DATA[6:])
@@ -218,7 +244,7 @@ class StreamReaderTests(unittest.TestCase):
         line = self.loop.run_until_complete(stream.readline())
 
         self.assertEqual(b'line1\n', line)
-        self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count)
+        self.assertEqual(b'line2\nline3\n', stream._buffer)
 
     def test_readline_eof(self):
         stream = asyncio.StreamReader(loop=self.loop)
@@ -244,9 +270,7 @@ class StreamReaderTests(unittest.TestCase):
         data = self.loop.run_until_complete(stream.read(7))
 
         self.assertEqual(b'line2\nl', data)
-        self.assertEqual(
-            len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
-            stream._byte_count)
+        self.assertEqual(b'ine3\n', stream._buffer)
 
     def test_readline_exception(self):
         stream = asyncio.StreamReader(loop=self.loop)
@@ -258,6 +282,7 @@ class StreamReaderTests(unittest.TestCase):
         stream.set_exception(ValueError())
         self.assertRaises(
             ValueError, self.loop.run_until_complete, stream.readline())
+        self.assertEqual(b'', stream._buffer)
 
     def test_readexactly_zero_or_less(self):
         # Read exact number of bytes (zero or less).
@@ -266,11 +291,11 @@ class StreamReaderTests(unittest.TestCase):
 
         data = self.loop.run_until_complete(stream.readexactly(0))
         self.assertEqual(b'', data)
-        self.assertEqual(len(self.DATA), stream._byte_count)
+        self.assertEqual(self.DATA, stream._buffer)
 
         data = self.loop.run_until_complete(stream.readexactly(-1))
         self.assertEqual(b'', data)
-        self.assertEqual(len(self.DATA), stream._byte_count)
+        self.assertEqual(self.DATA, stream._buffer)
 
     def test_readexactly(self):
         # Read exact number of bytes.
@@ -287,7 +312,7 @@ class StreamReaderTests(unittest.TestCase):
 
         data = self.loop.run_until_complete(read_task)
         self.assertEqual(self.DATA + self.DATA, data)
-        self.assertEqual(len(self.DATA), stream._byte_count)
+        self.assertEqual(self.DATA, stream._buffer)
 
     def test_readexactly_eof(self):
         # Read exact number of bytes (eof).
@@ -306,7 +331,7 @@ class StreamReaderTests(unittest.TestCase):
         self.assertEqual(cm.exception.expected, n)
         self.assertEqual(str(cm.exception),
                          '18 bytes read on a total of 36 expected bytes')
-        self.assertFalse(stream._byte_count)
+        self.assertEqual(b'', stream._buffer)
 
     def test_readexactly_exception(self):
         stream = asyncio.StreamReader(loop=self.loop)