]> granicus.if.org Git - python/commitdiff
Issue #26050: Add asyncio.StreamReader.readuntil() method.
authorYury Selivanov <yselivanov@sprymix.com>
Mon, 11 Jan 2016 17:28:19 +0000 (12:28 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Mon, 11 Jan 2016 17:28:19 +0000 (12:28 -0500)
Patch by Марк Коренберг.

Lib/asyncio/streams.py
Lib/test/test_asyncio/test_streams.py
Misc/NEWS

index 9097e38271de8d998d5ee56e51c26a01e93104eb..0008d514508a90472f2fbda2ce5f6f49de5ad2df 100644 (file)
@@ -3,6 +3,7 @@
 __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
            'open_connection', 'start_server',
            'IncompleteReadError',
+           'LimitOverrunError',
            ]
 
 import socket
@@ -27,15 +28,28 @@ class IncompleteReadError(EOFError):
     Incomplete read error. Attributes:
 
     - partial: read bytes string before the end of stream was reached
-    - expected: total number of expected bytes
+    - expected: total number of expected bytes (or None if unknown)
     """
     def __init__(self, partial, expected):
-        EOFError.__init__(self, "%s bytes read on a total of %s expected bytes"
-                                % (len(partial), expected))
+        super().__init__("%d bytes read on a total of %r expected bytes"
+                         % (len(partial), expected))
         self.partial = partial
         self.expected = expected
 
 
+class LimitOverrunError(Exception):
+    """Reached buffer limit while looking for the separator.
+
+    Attributes:
+    - message: error message
+    - consumed: total number of bytes that should be consumed
+    """
+    def __init__(self, message, consumed):
+        super().__init__(message)
+        self.message = message
+        self.consumed = consumed
+
+
 @coroutine
 def open_connection(host=None, port=None, *,
                     loop=None, limit=_DEFAULT_LIMIT, **kwds):
@@ -318,6 +332,10 @@ class StreamReader:
     def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
         # The line length limit is  a security feature;
         # it also doubles as half the buffer limit.
+
+        if limit <= 0:
+            raise ValueError('Limit cannot be <= 0')
+
         self._limit = limit
         if loop is None:
             self._loop = events.get_event_loop()
@@ -361,7 +379,7 @@ class StreamReader:
                 waiter.set_exception(exc)
 
     def _wakeup_waiter(self):
-        """Wakeup read() or readline() function waiting for data or EOF."""
+        """Wakeup read*() functions waiting for data or EOF."""
         waiter = self._waiter
         if waiter is not None:
             self._waiter = None
@@ -409,7 +427,10 @@ class StreamReader:
 
     @coroutine
     def _wait_for_data(self, func_name):
-        """Wait until feed_data() or feed_eof() is called."""
+        """Wait until feed_data() or feed_eof() is called.
+
+        If stream was paused, automatically resume it.
+        """
         # StreamReader uses a future to link the protocol feed_data() method
         # to a read coroutine. Running two read coroutines at the same time
         # would have an unexpected behaviour. It would not possible to know
@@ -418,6 +439,13 @@ class StreamReader:
             raise RuntimeError('%s() called while another coroutine is '
                                'already waiting for incoming data' % func_name)
 
+        assert not self._eof, '_wait_for_data after EOF'
+
+        # Waiting for data while paused will make deadlock, so prevent it.
+        if self._paused:
+            self._paused = False
+            self._transport.resume_reading()
+
         self._waiter = futures.Future(loop=self._loop)
         try:
             yield from self._waiter
@@ -426,43 +454,150 @@ class StreamReader:
 
     @coroutine
     def readline(self):
+        """Read chunk of data from the stream until newline (b'\n') is found.
+
+        On success, return chunk that ends with newline. If only partial
+        line can be read due to EOF, return incomplete line without
+        terminating newline. When EOF was reached while no bytes read, empty
+        bytes object is returned.
+
+        If limit is reached, ValueError will be raised. In that case, if
+        newline was found, complete line including newline will be removed
+        from internal buffer. Else, internal buffer will be cleared. Limit is
+        compared against part of the line without newline.
+
+        If stream was paused, this function will automatically resume it if
+        needed.
+        """
+        sep = b'\n'
+        seplen = len(sep)
+        try:
+            line = yield from self.readuntil(sep)
+        except IncompleteReadError as e:
+            return e.partial
+        except LimitOverrunError as e:
+            if self._buffer.startswith(sep, e.consumed):
+                del self._buffer[:e.consumed + seplen]
+            else:
+                self._buffer.clear()
+            self._maybe_resume_transport()
+            raise ValueError(e.args[0])
+        return line
+
+    @coroutine
+    def readuntil(self, separator=b'\n'):
+        """Read chunk of data from the stream until `separator` is found.
+
+        On success, chunk and its separator will be removed from internal buffer
+        (i.e. consumed). Returned chunk will include separator at the end.
+
+        Configured stream limit is used to check result. Limit means maximal
+        length of chunk that can be returned, not counting the separator.
+
+        If EOF occurs and complete separator still not found,
+        IncompleteReadError(<partial data>, None) will be raised and internal
+        buffer becomes empty. This partial data may contain a partial separator.
+
+        If chunk cannot be read due to overlimit, LimitOverrunError will be raised
+        and data will be left in internal buffer, so it can be read again, in
+        some different way.
+
+        If stream was paused, this function will automatically resume it if
+        needed.
+        """
+        seplen = len(separator)
+        if seplen == 0:
+            raise ValueError('Separator should be at least one-byte string')
+
         if self._exception is not None:
             raise self._exception
 
-        line = bytearray()
-        not_enough = True
-
-        while not_enough:
-            while self._buffer and not_enough:
-                ichar = self._buffer.find(b'\n')
-                if ichar < 0:
-                    line.extend(self._buffer)
-                    self._buffer.clear()
-                else:
-                    ichar += 1
-                    line.extend(self._buffer[:ichar])
-                    del self._buffer[:ichar]
-                    not_enough = False
-
-                if len(line) > self._limit:
-                    self._maybe_resume_transport()
-                    raise ValueError('Line is too long')
+        # Consume whole buffer except last bytes, which length is
+        # one less than seplen. Let's check corner cases with
+        # separator='SEPARATOR':
+        # * we have received almost complete separator (without last
+        #   byte). i.e buffer='some textSEPARATO'. In this case we
+        #   can safely consume len(separator) - 1 bytes.
+        # * last byte of buffer is first byte of separator, i.e.
+        #   buffer='abcdefghijklmnopqrS'. We may safely consume
+        #   everything except that last byte, but this require to
+        #   analyze bytes of buffer that match partial separator.
+        #   This is slow and/or require FSM. For this case our
+        #   implementation is not optimal, since require rescanning
+        #   of data that is known to not belong to separator. In
+        #   real world, separator will not be so long to notice
+        #   performance problems. Even when reading MIME-encoded
+        #   messages :)
+
+        # `offset` is the number of bytes from the beginning of the buffer where
+        # is no occurrence of `separator`.
+        offset = 0
+
+        # Loop until we find `separator` in the buffer, exceed the buffer size,
+        # or an EOF has happened.
+        while True:
+            buflen = len(self._buffer)
+
+            # Check if we now have enough data in the buffer for `separator` to
+            # fit.
+            if buflen - offset >= seplen:
+                isep = self._buffer.find(separator, offset)
+
+                if isep != -1:
+                    # `separator` is in the buffer. `isep` will be used later to
+                    # retrieve the data.
+                    break
+
+                # see upper comment for explanation.
+                offset = buflen + 1 - seplen
+                if offset > self._limit:
+                    raise LimitOverrunError('Separator is not found, and chunk exceed the limit', offset)
 
+            # Complete message (with full separator) may be present in buffer
+            # even when EOF flag is set. This may happen when the last chunk
+            # adds data which makes separator be found. That's why we check for
+            # EOF *ater* inspecting the buffer.
             if self._eof:
-                break
+                chunk = bytes(self._buffer)
+                self._buffer.clear()
+                raise IncompleteReadError(chunk, None)
+
+            # _wait_for_data() will resume reading if stream was paused.
+            yield from self._wait_for_data('readuntil')
 
-            if not_enough:
-                yield from self._wait_for_data('readline')
+        if isep > self._limit:
+            raise LimitOverrunError('Separator is found, but chunk is longer than limit', isep)
 
+        chunk = self._buffer[:isep + seplen]
+        del self._buffer[:isep + seplen]
         self._maybe_resume_transport()
-        return bytes(line)
+        return bytes(chunk)
 
     @coroutine
     def read(self, n=-1):
+        """Read up to `n` bytes from the stream.
+
+        If n is not provided, or set to -1, read until EOF and return all read
+        bytes. If the EOF was received and the internal buffer is empty, return
+        an empty bytes object.
+
+        If n is zero, return empty bytes object immediatelly.
+
+        If n is positive, this function try to read `n` bytes, and may return
+        less or equal bytes than requested, but at least one byte. If EOF was
+        received before any byte is read, this function returns empty byte
+        object.
+
+        Returned value is not limited with limit, configured at stream creation.
+
+        If stream was paused, this function will automatically resume it if
+        needed.
+        """
+
         if self._exception is not None:
             raise self._exception
 
-        if not n:
+        if n == 0:
             return b''
 
         if n < 0:
@@ -477,29 +612,41 @@ class StreamReader:
                     break
                 blocks.append(block)
             return b''.join(blocks)
-        else:
-            if not self._buffer and not self._eof:
-                yield from self._wait_for_data('read')
 
-        if n < 0 or len(self._buffer) <= n:
-            data = bytes(self._buffer)
-            self._buffer.clear()
-        else:
-            # n > 0 and len(self._buffer) > n
-            data = bytes(self._buffer[:n])
-            del self._buffer[:n]
+        if not self._buffer and not self._eof:
+            yield from self._wait_for_data('read')
+
+        # This will work right even if buffer is less than n bytes
+        data = bytes(self._buffer[:n])
+        del self._buffer[:n]
 
         self._maybe_resume_transport()
         return data
 
     @coroutine
     def readexactly(self, n):
+        """Read exactly `n` bytes.
+
+        Raise an `IncompleteReadError` if EOF is reached before `n` bytes can be
+        read. The `IncompleteReadError.partial` attribute of the exception will
+        contain the partial read bytes.
+
+        if n is zero, return empty bytes object.
+
+        Returned value is not limited with limit, configured at stream creation.
+
+        If stream was paused, this function will automatically resume it if
+        needed.
+        """
         if n < 0:
             raise ValueError('readexactly size can not be less than zero')
 
         if self._exception is not None:
             raise self._exception
 
+        if n == 0:
+            return b''
+
         # There used to be "optimized" code here.  It created its own
         # Future and waited until self._buffer had at least the n
         # bytes, then called read(n).  Unfortunately, this could pause
@@ -516,6 +663,8 @@ class StreamReader:
             blocks.append(block)
             n -= len(block)
 
+        assert n == 0
+
         return b''.join(blocks)
 
     if compat.PY35:
index 3b115b14bec86edab94423af6bd9d21bedb61c20..1783d5f6306ade668840d6aa44f36ff8aaaa1f55 100644 (file)
@@ -203,6 +203,20 @@ class StreamReaderTests(test_utils.TestCase):
         self.assertRaises(
             ValueError, self.loop.run_until_complete, stream.read(2))
 
+    def test_invalid_limit(self):
+        with self.assertRaisesRegex(ValueError, 'imit'):
+            asyncio.StreamReader(limit=0, loop=self.loop)
+
+        with self.assertRaisesRegex(ValueError, 'imit'):
+            asyncio.StreamReader(limit=-1, loop=self.loop)
+
+    def test_read_limit(self):
+        stream = asyncio.StreamReader(limit=3, loop=self.loop)
+        stream.feed_data(b'chunk')
+        data = self.loop.run_until_complete(stream.read(5))
+        self.assertEqual(b'chunk', data)
+        self.assertEqual(b'', stream._buffer)
+
     def test_readline(self):
         # Read one line. 'readline' will need to wait for the data
         # to come from 'cb'
@@ -292,6 +306,23 @@ class StreamReaderTests(test_utils.TestCase):
             ValueError, self.loop.run_until_complete, stream.readline())
         self.assertEqual(b'chunk3\n', stream._buffer)
 
+        # check strictness of the limit
+        stream = asyncio.StreamReader(limit=7, loop=self.loop)
+        stream.feed_data(b'1234567\n')
+        line = self.loop.run_until_complete(stream.readline())
+        self.assertEqual(b'1234567\n', line)
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'12345678\n')
+        with self.assertRaises(ValueError) as cm:
+            self.loop.run_until_complete(stream.readline())
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'12345678')
+        with self.assertRaises(ValueError) as cm:
+            self.loop.run_until_complete(stream.readline())
+        self.assertEqual(b'', stream._buffer)
+
     def test_readline_nolimit_nowait(self):
         # All needed data for the first 'readline' call will be
         # in the buffer.
@@ -342,6 +373,92 @@ class StreamReaderTests(test_utils.TestCase):
             ValueError, self.loop.run_until_complete, stream.readline())
         self.assertEqual(b'', stream._buffer)
 
+    def test_readuntil_separator(self):
+        stream = asyncio.StreamReader(loop=self.loop)
+        with self.assertRaisesRegex(ValueError, 'Separator should be'):
+            self.loop.run_until_complete(stream.readuntil(separator=b''))
+
+    def test_readuntil_multi_chunks(self):
+        stream = asyncio.StreamReader(loop=self.loop)
+
+        stream.feed_data(b'lineAAA')
+        data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA'))
+        self.assertEqual(b'lineAAA', data)
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'lineAAA')
+        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
+        self.assertEqual(b'lineAAA', data)
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'lineAAAxxx')
+        data = self.loop.run_until_complete(stream.readuntil(b'AAA'))
+        self.assertEqual(b'lineAAA', data)
+        self.assertEqual(b'xxx', stream._buffer)
+
+    def test_readuntil_multi_chunks_1(self):
+        stream = asyncio.StreamReader(loop=self.loop)
+
+        stream.feed_data(b'QWEaa')
+        stream.feed_data(b'XYaa')
+        stream.feed_data(b'a')
+        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+        self.assertEqual(b'QWEaaXYaaa', data)
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'QWEaa')
+        stream.feed_data(b'XYa')
+        stream.feed_data(b'aa')
+        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+        self.assertEqual(b'QWEaaXYaaa', data)
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'aaa')
+        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+        self.assertEqual(b'aaa', data)
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'Xaaa')
+        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+        self.assertEqual(b'Xaaa', data)
+        self.assertEqual(b'', stream._buffer)
+
+        stream.feed_data(b'XXX')
+        stream.feed_data(b'a')
+        stream.feed_data(b'a')
+        stream.feed_data(b'a')
+        data = self.loop.run_until_complete(stream.readuntil(b'aaa'))
+        self.assertEqual(b'XXXaaa', data)
+        self.assertEqual(b'', stream._buffer)
+
+    def test_readuntil_eof(self):
+        stream = asyncio.StreamReader(loop=self.loop)
+        stream.feed_data(b'some dataAA')
+        stream.feed_eof()
+
+        with self.assertRaises(asyncio.IncompleteReadError) as cm:
+            self.loop.run_until_complete(stream.readuntil(b'AAA'))
+        self.assertEqual(cm.exception.partial, b'some dataAA')
+        self.assertIsNone(cm.exception.expected)
+        self.assertEqual(b'', stream._buffer)
+
+    def test_readuntil_limit_found_sep(self):
+        stream = asyncio.StreamReader(loop=self.loop, limit=3)
+        stream.feed_data(b'some dataAA')
+
+        with self.assertRaisesRegex(asyncio.LimitOverrunError,
+                                    'not found') as cm:
+            self.loop.run_until_complete(stream.readuntil(b'AAA'))
+
+        self.assertEqual(b'some dataAA', stream._buffer)
+
+        stream.feed_data(b'A')
+        with self.assertRaisesRegex(asyncio.LimitOverrunError,
+                                    'is found') as cm:
+            self.loop.run_until_complete(stream.readuntil(b'AAA'))
+
+        self.assertEqual(b'some dataAAA', stream._buffer)
+
     def test_readexactly_zero_or_less(self):
         # Read exact number of bytes (zero or less).
         stream = asyncio.StreamReader(loop=self.loop)
@@ -372,6 +489,13 @@ class StreamReaderTests(test_utils.TestCase):
         self.assertEqual(self.DATA + self.DATA, data)
         self.assertEqual(self.DATA, stream._buffer)
 
+    def test_readexactly_limit(self):
+        stream = asyncio.StreamReader(limit=3, loop=self.loop)
+        stream.feed_data(b'chunk')
+        data = self.loop.run_until_complete(stream.readexactly(5))
+        self.assertEqual(b'chunk', data)
+        self.assertEqual(b'', stream._buffer)
+
     def test_readexactly_eof(self):
         # Read exact number of bytes (eof).
         stream = asyncio.StreamReader(loop=self.loop)
@@ -657,7 +781,9 @@ os.close(fd)
 
         @asyncio.coroutine
         def client(host, port):
-            reader, writer = yield from asyncio.open_connection(host, port, loop=self.loop)
+            reader, writer = yield from asyncio.open_connection(
+                host, port, loop=self.loop)
+
             while True:
                 writer.write(b"foo\n")
                 yield from writer.drain()
index 9b677e522c3ea31489b8e39df04aee0a0f08f17d..a54067198f5eebda4b543b62ad60a520afdc31b0 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -30,6 +30,9 @@ Library
 
 - Add asyncio.timeout() context manager.
 
+- Issue #26050: Add asyncio.StreamReader.readuntil() method.
+  Patch by Марк Коренберг.
+
 
 What's New in Python 3.4.4?
 ===========================