]> granicus.if.org Git - python/commitdiff
Issue #19009
author Kristján Valur Jónsson <sweskman@gmail.com>
Wed, 19 Mar 2014 10:07:26 +0000 (10:07 +0000)
committer Kristján Valur Jónsson <sweskman@gmail.com>
Wed, 19 Mar 2014 10:07:26 +0000 (10:07 +0000)
Enhance HTTPResponse.readline() performance

Lib/http/client.py
Lib/test/test_httplib.py

index 763a903d42b64264b9f63923b2673146619664f6..8ee2101708d68255fdcdd076f6cd242cf9a3d5bb 100644 (file)
@@ -271,7 +271,7 @@ def parse_headers(fp, _class=HTTPMessage):
     return email.parser.Parser(_class=_class).parsestr(hstring)
 
 
-class HTTPResponse(io.RawIOBase):
+class HTTPResponse(io.BufferedIOBase):
 
     # See RFC 2616 sec 19.6 and RFC 1945 sec 6 for details.
 
@@ -496,9 +496,10 @@ class HTTPResponse(io.RawIOBase):
             return b""
 
         if amt is not None:
-            # Amount is given, so call base class version
-            # (which is implemented in terms of self.readinto)
-            return super(HTTPResponse, self).read(amt)
+            # Amount is given, implement using readinto
+            b = bytearray(amt)
+            n = self.readinto(b)
+            return memoryview(b)[:n].tobytes()
         else:
             # Amount is not given (unbounded read) so we must check self.length
             # and self.chunked
@@ -578,71 +579,67 @@ class HTTPResponse(io.RawIOBase):
             if line in (b'\r\n', b'\n', b''):
                 break
 
+    def _get_chunk_left(self):
+        # return self.chunk_left, reading a new chunk if necessary.
+        # chunk_left == 0: at the end of the current chunk, need to close it
+        # chunk_left == None: No current chunk, should read next.
+        # This function returns non-zero or None if the last chunk has
+        # been read.
+        chunk_left = self.chunk_left
+        if not chunk_left: # Can be 0 or None
+            if chunk_left is not None:
+                # We are at the end of chunk. Discard chunk end
+                self._safe_read(2)  # toss the CRLF at the end of the chunk
+            try:
+                chunk_left = self._read_next_chunk_size()
+            except ValueError:
+                raise IncompleteRead(b'')
+            if chunk_left == 0:
+                # last chunk: 1*("0") [ chunk-extension ] CRLF
+                self._read_and_discard_trailer()
+                # we read everything; close the "file"
+                self._close_conn()
+                chunk_left = None
+            self.chunk_left = chunk_left
+        return chunk_left
+
     def _readall_chunked(self):
         assert self.chunked != _UNKNOWN
-        chunk_left = self.chunk_left
         value = []
-        while True:
-            if chunk_left is None:
-                try:
-                    chunk_left = self._read_next_chunk_size()
-                    if chunk_left == 0:
-                        break
-                except ValueError:
-                    raise IncompleteRead(b''.join(value))
-            value.append(self._safe_read(chunk_left))
-
-            # we read the whole chunk, get another
-            self._safe_read(2)      # toss the CRLF at the end of the chunk
-            chunk_left = None
-
-        self._read_and_discard_trailer()
-
-        # we read everything; close the "file"
-        self._close_conn()
-
-        return b''.join(value)
+        try:
+            while True:
+                chunk_left = self._get_chunk_left()
+                if chunk_left is None:
+                    break
+                value.append(self._safe_read(chunk_left))
+                self.chunk_left = 0
+            return b''.join(value)
+        except IncompleteRead:
+            raise IncompleteRead(b''.join(value))
 
     def _readinto_chunked(self, b):
         assert self.chunked != _UNKNOWN
-        chunk_left = self.chunk_left
-
         total_bytes = 0
         mvb = memoryview(b)
-        while True:
-            if chunk_left is None:
-                try:
-                    chunk_left = self._read_next_chunk_size()
-                    if chunk_left == 0:
-                        break
-                except ValueError:
-                    raise IncompleteRead(bytes(b[0:total_bytes]))
-
-            if len(mvb) < chunk_left:
-                n = self._safe_readinto(mvb)
-                self.chunk_left = chunk_left - n
-                return total_bytes + n
-            elif len(mvb) == chunk_left:
-                n = self._safe_readinto(mvb)
-                self._safe_read(2)  # toss the CRLF at the end of the chunk
-                self.chunk_left = None
-                return total_bytes + n
-            else:
-                temp_mvb = mvb[0:chunk_left]
+        try:
+            while True:
+                chunk_left = self._get_chunk_left()
+                if chunk_left is None:
+                    return total_bytes
+
+                if len(mvb) <= chunk_left:
+                    n = self._safe_readinto(mvb)
+                    self.chunk_left = chunk_left - n
+                    return total_bytes + n
+
+                temp_mvb = mvb[:chunk_left]
                 n = self._safe_readinto(temp_mvb)
                 mvb = mvb[n:]
                 total_bytes += n
+                self.chunk_left = 0
 
-            # we read the whole chunk, get another
-            self._safe_read(2)      # toss the CRLF at the end of the chunk
-            chunk_left = None
-
-        self._read_and_discard_trailer()
-
-        # we read everything; close the "file"
-        self._close_conn()
-
-        return total_bytes
+        except IncompleteRead:
+            raise IncompleteRead(bytes(b[0:total_bytes]))
 
     def _safe_read(self, amt):
         """Read the number of bytes requested, compensating for partial reads.
@@ -683,6 +680,73 @@ class HTTPResponse(io.RawIOBase):
             total_bytes += n
         return total_bytes
 
+    def read1(self, n=-1):
+        """Read with at most one underlying system call.  If at least one
+        byte is buffered, return that instead.
+        """
+        if self.fp is None or self._method == "HEAD":
+            return b""
+        if self.chunked:
+            return self._read1_chunked(n)
+        try:
+            result = self.fp.read1(n)
+        except ValueError:
+            if n >= 0:
+                raise
+            # some implementations, like BufferedReader, don't support -1
+            # Read an arbitrarily selected largeish chunk.
+            result = self.fp.read1(16*1024)
+        if not result and n:
+            self._close_conn()
+        return result
+
+    def peek(self, n=-1):
+        # Having this enables IOBase.readline() to read more than one
+        # byte at a time
+        if self.fp is None or self._method == "HEAD":
+            return b""
+        if self.chunked:
+            return self._peek_chunked(n)
+        return self.fp.peek(n)
+
+    def readline(self, limit=-1):
+        if self.fp is None or self._method == "HEAD":
+            return b""
+        if self.chunked:
+            # Fall back to IOBase readline which uses peek() and read()
+            return super().readline(limit)
+        result = self.fp.readline(limit)
+        if not result and limit:
+            self._close_conn()
+        return result
+
+    def _read1_chunked(self, n):
+        # Strictly speaking, _get_chunk_left() may cause more than one read,
+        # but that is ok, since that is to satisfy the chunked protocol.
+        chunk_left = self._get_chunk_left()
+        if chunk_left is None or n == 0:
+            return b''
+        if not (0 <= n <= chunk_left):
+            n = chunk_left # if n is negative or larger than chunk_left
+        read = self.fp.read1(n)
+        self.chunk_left -= len(read)
+        if not read:
+            raise IncompleteRead(b"")
+        return read
+
+    def _peek_chunked(self, n):
+        # Strictly speaking, _get_chunk_left() may cause more than one read,
+        # but that is ok, since that is to satisfy the chunked protocol.
+        try:
+            chunk_left = self._get_chunk_left()
+        except IncompleteRead:
+            return b'' # peek doesn't worry about protocol
+        if chunk_left is None:
+            return b'' # eof
+        # peek is allowed to return more than requested.  Just request the
+        # entire chunk, and truncate what we get.
+        return self.fp.peek(chunk_left)[:chunk_left]
+
     def fileno(self):
         return self.fp.fileno()
 
index 30b6c0cfcbbe28b1ff432a8f1a04bd9cdf714eb9..69aa381918388a5b1c9bb189158bafdaa4a6dd9d 100644 (file)
@@ -18,6 +18,26 @@ CERT_fakehostname = os.path.join(here, 'keycert2.pem')
 # Root cert file (CA) for svn.python.org's cert
 CACERT_svn_python_org = os.path.join(here, 'https_svn_python_org_root.pem')
 
+# constants for testing chunked encoding
+chunked_start = (
+    'HTTP/1.1 200 OK\r\n'
+    'Transfer-Encoding: chunked\r\n\r\n'
+    'a\r\n'
+    'hello worl\r\n'
+    '3\r\n'
+    'd! \r\n'
+    '8\r\n'
+    'and now \r\n'
+    '22\r\n'
+    'for something completely different\r\n'
+)
+chunked_expected = b'hello world! and now for something completely different'
+chunk_extension = ";foo=bar"
+last_chunk = "0\r\n"
+last_chunk_extended = "0" + chunk_extension + "\r\n"
+trailers = "X-Dummy: foo\r\nX-Dumm2: bar\r\n"
+chunked_end = "\r\n"
+
 HOST = support.HOST
 
 class FakeSocket:
@@ -36,7 +56,10 @@ class FakeSocket:
     def makefile(self, mode, bufsize=None):
         if mode != 'r' and mode != 'rb':
             raise client.UnimplementedFileMode()
-        return self.fileclass(self.text)
+        # keep the file around so we can check how much was read from it
+        self.file = self.fileclass(self.text)
+        self.file.close = lambda:None #nerf close ()
+        return self.file
 
 class EPipeSocket(FakeSocket):
 
@@ -430,20 +453,8 @@ class BasicTest(TestCase):
             conn.request('POST', 'test', conn)
 
     def test_chunked(self):
-        chunked_start = (
-            'HTTP/1.1 200 OK\r\n'
-            'Transfer-Encoding: chunked\r\n\r\n'
-            'a\r\n'
-            'hello worl\r\n'
-            '3\r\n'
-            'd! \r\n'
-            '8\r\n'
-            'and now \r\n'
-            '22\r\n'
-            'for something completely different\r\n'
-        )
-        expected = b'hello world! and now for something completely different'
-        sock = FakeSocket(chunked_start + '0\r\n')
+        expected = chunked_expected
+        sock = FakeSocket(chunked_start + last_chunk + chunked_end)
         resp = client.HTTPResponse(sock, method="GET")
         resp.begin()
         self.assertEqual(resp.read(), expected)
@@ -451,7 +462,7 @@ class BasicTest(TestCase):
 
         # Various read sizes
         for n in range(1, 12):
-            sock = FakeSocket(chunked_start + '0\r\n')
+            sock = FakeSocket(chunked_start + last_chunk + chunked_end)
             resp = client.HTTPResponse(sock, method="GET")
             resp.begin()
             self.assertEqual(resp.read(n) + resp.read(n) + resp.read(), expected)
@@ -474,23 +485,12 @@ class BasicTest(TestCase):
                 resp.close()
 
     def test_readinto_chunked(self):
-        chunked_start = (
-            'HTTP/1.1 200 OK\r\n'
-            'Transfer-Encoding: chunked\r\n\r\n'
-            'a\r\n'
-            'hello worl\r\n'
-            '3\r\n'
-            'd! \r\n'
-            '8\r\n'
-            'and now \r\n'
-            '22\r\n'
-            'for something completely different\r\n'
-        )
-        expected = b'hello world! and now for something completely different'
+
+        expected = chunked_expected
         nexpected = len(expected)
         b = bytearray(128)
 
-        sock = FakeSocket(chunked_start + '0\r\n')
+        sock = FakeSocket(chunked_start + last_chunk + chunked_end)
         resp = client.HTTPResponse(sock, method="GET")
         resp.begin()
         n = resp.readinto(b)
@@ -500,7 +500,7 @@ class BasicTest(TestCase):
 
         # Various read sizes
         for n in range(1, 12):
-            sock = FakeSocket(chunked_start + '0\r\n')
+            sock = FakeSocket(chunked_start + last_chunk + chunked_end)
             resp = client.HTTPResponse(sock, method="GET")
             resp.begin()
             m = memoryview(b)
@@ -536,7 +536,7 @@ class BasicTest(TestCase):
             '1\r\n'
             'd\r\n'
         )
-        sock = FakeSocket(chunked_start + '0\r\n')
+        sock = FakeSocket(chunked_start + last_chunk + chunked_end)
         resp = client.HTTPResponse(sock, method="HEAD")
         resp.begin()
         self.assertEqual(resp.read(), b'')
@@ -556,7 +556,7 @@ class BasicTest(TestCase):
             '1\r\n'
             'd\r\n'
         )
-        sock = FakeSocket(chunked_start + '0\r\n')
+        sock = FakeSocket(chunked_start + last_chunk + chunked_end)
         resp = client.HTTPResponse(sock, method="HEAD")
         resp.begin()
         b = bytearray(5)
@@ -631,6 +631,7 @@ class BasicTest(TestCase):
             + '0' * 65536 + 'a\r\n'
             'hello world\r\n'
             '0\r\n'
+            '\r\n'
         )
         resp = client.HTTPResponse(FakeSocket(body))
         resp.begin()
@@ -670,6 +671,239 @@ class BasicTest(TestCase):
         conn.request('POST', '/', body)
         self.assertGreater(sock.sendall_calls, 1)
 
+    def test_chunked_extension(self):
+        extra = '3;foo=bar\r\n' + 'abc\r\n'
+        expected = chunked_expected + b'abc'
+
+        sock = FakeSocket(chunked_start + extra + last_chunk_extended + chunked_end)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        self.assertEqual(resp.read(), expected)
+        resp.close()
+
+    def test_chunked_missing_end(self):
+        """some servers may serve up a short chunked encoding stream"""
+        expected = chunked_expected
+        sock = FakeSocket(chunked_start + last_chunk)  #no terminating crlf
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        self.assertEqual(resp.read(), expected)
+        resp.close()
+
+    def test_chunked_trailers(self):
+        """See that trailers are read and ignored"""
+        expected = chunked_expected
+        sock = FakeSocket(chunked_start + last_chunk + trailers + chunked_end)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        self.assertEqual(resp.read(), expected)
+        # we should have reached the end of the file
+        self.assertEqual(sock.file.read(100), b"") #we read to the end
+        resp.close()
+
+    def test_chunked_sync(self):
+        """Check that we don't read past the end of the chunked-encoding stream"""
+        expected = chunked_expected
+        extradata = "extradata"
+        sock = FakeSocket(chunked_start + last_chunk + trailers + chunked_end + extradata)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        self.assertEqual(resp.read(), expected)
+        # the file should now have our extradata ready to be read
+        self.assertEqual(sock.file.read(100), extradata.encode("ascii")) #we read to the end
+        resp.close()
+
+    def test_content_length_sync(self):
+        """Check that we don't read past the end of the Content-Length stream"""
+        extradata = "extradata"
+        expected = b"Hello123\r\n"
+        sock = FakeSocket('HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nHello123\r\n' + extradata)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        self.assertEqual(resp.read(), expected)
+        # the file should now have our extradata ready to be read
+        self.assertEqual(sock.file.read(100), extradata.encode("ascii")) #we read to the end
+        resp.close()
+
+class ExtendedReadTest(TestCase):
+    """
+    Test peek(), read1(), readline()
+    """
+    lines = (
+        'HTTP/1.1 200 OK\r\n'
+        '\r\n'
+        'hello world!\n'
+        'and now \n'
+        'for something completely different\n'
+        'foo'
+        )
+    lines_expected = lines[lines.find('hello'):].encode("ascii")
+    lines_chunked = (
+        'HTTP/1.1 200 OK\r\n'
+        'Transfer-Encoding: chunked\r\n\r\n'
+        'a\r\n'
+        'hello worl\r\n'
+        '3\r\n'
+        'd!\n\r\n'
+        '9\r\n'
+        'and now \n\r\n'
+        '23\r\n'
+        'for something completely different\n\r\n'
+        '3\r\n'
+        'foo\r\n'
+        '0\r\n' # terminating chunk
+        '\r\n'  # end of trailers
+    )
+
+    def setUp(self):
+        sock = FakeSocket(self.lines)
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        resp.fp = io.BufferedReader(resp.fp)
+        self.resp = resp
+
+
+
+    def test_peek(self):
+        resp = self.resp
+        # patch up the buffered peek so that it returns not too much stuff
+        oldpeek = resp.fp.peek
+        def mypeek(n=-1):
+            p = oldpeek(n)
+            if n >= 0:
+                return p[:n]
+            return p[:10]
+        resp.fp.peek = mypeek
+
+        all = []
+        while True:
+            # try a short peek
+            p = resp.peek(3)
+            if p:
+                self.assertGreater(len(p), 0)
+                # then unbounded peek
+                p2 = resp.peek()
+                self.assertGreaterEqual(len(p2), len(p))
+                self.assertTrue(p2.startswith(p))
+                next = resp.read(len(p2))
+                self.assertEqual(next, p2)
+            else:
+                next = resp.read()
+                self.assertFalse(next)
+            all.append(next)
+            if not next:
+                break
+        self.assertEqual(b"".join(all), self.lines_expected)
+
+    def test_readline(self):
+        resp = self.resp
+        self._verify_readline(self.resp.readline, self.lines_expected)
+
+    def _verify_readline(self, readline, expected):
+        all = []
+        while True:
+            # short readlines
+            line = readline(5)
+            if line and line != b"foo":
+                if len(line) < 5:
+                    self.assertTrue(line.endswith(b"\n"))
+            all.append(line)
+            if not line:
+                break
+        self.assertEqual(b"".join(all), expected)
+
+    def test_read1(self):
+        resp = self.resp
+        def r():
+            res = resp.read1(4)
+            self.assertLessEqual(len(res), 4)
+            return res
+        readliner = Readliner(r)
+        self._verify_readline(readliner.readline, self.lines_expected)
+
+    def test_read1_unbounded(self):
+        resp = self.resp
+        all = []
+        while True:
+            data = resp.read1()
+            if not data:
+                break
+            all.append(data)
+        self.assertEqual(b"".join(all), self.lines_expected)
+
+    def test_read1_bounded(self):
+        resp = self.resp
+        all = []
+        while True:
+            data = resp.read1(10)
+            if not data:
+                break
+            self.assertLessEqual(len(data), 10)
+            all.append(data)
+        self.assertEqual(b"".join(all), self.lines_expected)
+
+    def test_read1_0(self):
+        self.assertEqual(self.resp.read1(0), b"")
+
+    def test_peek_0(self):
+        p = self.resp.peek(0)
+        self.assertLessEqual(0, len(p))
+
+class ExtendedReadTestChunked(ExtendedReadTest):
+    """
+    Test peek(), read1(), readline() in chunked mode
+    """
+    lines = (
+        'HTTP/1.1 200 OK\r\n'
+        'Transfer-Encoding: chunked\r\n\r\n'
+        'a\r\n'
+        'hello worl\r\n'
+        '3\r\n'
+        'd!\n\r\n'
+        '9\r\n'
+        'and now \n\r\n'
+        '23\r\n'
+        'for something completely different\n\r\n'
+        '3\r\n'
+        'foo\r\n'
+        '0\r\n' # terminating chunk
+        '\r\n'  # end of trailers
+    )
+
+
+class Readliner:
+    """
+    a simple readline class that uses an arbitrary read function and buffering
+    """
+    def __init__(self, readfunc):
+        self.readfunc = readfunc
+        self.remainder = b""
+
+    def readline(self, limit):
+        data = []
+        datalen = 0
+        read = self.remainder
+        try:
+            while True:
+                idx = read.find(b'\n')
+                if idx != -1:
+                    break
+                if datalen + len(read) >= limit:
+                    idx = limit - datalen - 1
+                # read more data
+                data.append(read)
+                read = self.readfunc()
+                if not read:
+                    idx = 0 #eof condition
+                    break
+            idx += 1
+            data.append(read[:idx])
+            self.remainder = read[idx:]
+            return b"".join(data)
+        except:
+            self.remainder = b"".join(data)
+            raise
+
 class OfflineTest(TestCase):
     def test_responses(self):
         self.assertEqual(client.responses[client.NOT_FOUND], "Not Found")
@@ -973,7 +1207,8 @@ class HTTPResponseTest(TestCase):
 def test_main(verbose=None):
     support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
                          HTTPSTest, RequestBodyTest, SourceAddressTest,
-                         HTTPResponseTest)
+                         HTTPResponseTest, ExtendedReadTest,
+                         ExtendedReadTestChunked)
 
 if __name__ == '__main__':
     test_main()