Issue #13464: Add a readinto() method to http.client.HTTPResponse.
authorAntoine Pitrou <solipsis@pitrou.net>
Tue, 6 Dec 2011 21:33:57 +0000 (22:33 +0100)
committerAntoine Pitrou <solipsis@pitrou.net>
Tue, 6 Dec 2011 21:33:57 +0000 (22:33 +0100)
Patch by Jon Kuhn.

Doc/library/http.client.rst
Lib/http/client.py
Lib/test/test_httplib.py
Misc/ACKS
Misc/NEWS

index c1ce15bddc3622eac403e3d8b5774c0110018058..7fce91fb87c259bd3b0c373dde6af9772077ba81 100644 (file)
@@ -502,6 +502,12 @@ statement.
 
    Reads and returns the response body, or up to the next *amt* bytes.
 
+.. method:: HTTPResponse.readinto(b)
+
+   Reads up to the next len(b) bytes of the response body into the buffer *b*.
+   Returns the number of bytes read.
+
+   .. versionadded:: 3.3
 
 .. method:: HTTPResponse.getheader(name, default=None)
 
index 88da55054ea40216acf9ec4a69361e6b3242efca..70625699b9e0df5f33f8216d9c80cbb08a6d4976 100644 (file)
@@ -485,11 +485,17 @@ class HTTPResponse(io.RawIOBase):
             self.close()
             return b""
 
-        if self.chunked:
-            return self._read_chunked(amt)
+        if amt is not None:
+            # Amount is given, so call base class version
+            # (which is implemented in terms of self.readinto)
+            return super(HTTPResponse, self).read(amt)
+        else:
+            # Amount is not given (unbounded read) so we must check self.length
+            # and self.chunked
+
+            if self.chunked:
+                return self._readall_chunked()
 
-        if amt is None:
-            # unbounded read
             if self.length is None:
                 s = self.fp.read()
             else:
@@ -498,78 +504,127 @@ class HTTPResponse(io.RawIOBase):
             self.close()        # we read everything
             return s
 
+    def readinto(self, b):
+        if self.fp is None:
+            return 0
+
+        if self._method == "HEAD":
+            self.close()
+            return 0
+
+        if self.chunked:
+            return self._readinto_chunked(b)
+
         if self.length is not None:
-            if amt > self.length:
+            if len(b) > self.length:
                 # clip the read to the "end of response"
-                amt = self.length
+                b = memoryview(b)[0:self.length]
 
         # we do not use _safe_read() here because this may be a .will_close
         # connection, and the user is reading more bytes than will be provided
         # (for example, reading in 1k chunks)
-        s = self.fp.read(amt)
+        n = self.fp.readinto(b)
         if self.length is not None:
-            self.length -= len(s)
+            self.length -= n
             if not self.length:
                 self.close()
-        return s
+        return n
+
+    def _read_next_chunk_size(self):
+        # Read the next chunk size from the file
+        line = self.fp.readline(_MAXLINE + 1)
+        if len(line) > _MAXLINE:
+            raise LineTooLong("chunk size")
+        i = line.find(b";")
+        if i >= 0:
+            line = line[:i] # strip chunk-extensions
+        try:
+            return int(line, 16)
+        except ValueError:
+            # close the connection as protocol synchronisation is
+            # probably lost
+            self.close()
+            raise
 
-    def _read_chunked(self, amt):
+    def _read_and_discard_trailer(self):
+        # read and discard trailer up to the CRLF terminator
+        ### note: we shouldn't have any trailers!
+        while True:
+            line = self.fp.readline(_MAXLINE + 1)
+            if len(line) > _MAXLINE:
+                raise LineTooLong("trailer line")
+            if not line:
+                # a vanishingly small number of sites EOF without
+                # sending the trailer
+                break
+            if line == b"\r\n":
+                break
+
+    def _readall_chunked(self):
         assert self.chunked != _UNKNOWN
         chunk_left = self.chunk_left
         value = []
         while True:
             if chunk_left is None:
-                line = self.fp.readline(_MAXLINE + 1)
-                if len(line) > _MAXLINE:
-                    raise LineTooLong("chunk size")
-                i = line.find(b";")
-                if i >= 0:
-                    line = line[:i] # strip chunk-extensions
                 try:
-                    chunk_left = int(line, 16)
+                    chunk_left = self._read_next_chunk_size()
+                    if chunk_left == 0:
+                        break
                 except ValueError:
-                    # close the connection as protocol synchronisation is
-                    # probably lost
-                    self.close()
                     raise IncompleteRead(b''.join(value))
-                if chunk_left == 0:
-                    break
-            if amt is None:
-                value.append(self._safe_read(chunk_left))
-            elif amt < chunk_left:
-                value.append(self._safe_read(amt))
-                self.chunk_left = chunk_left - amt
-                return b''.join(value)
-            elif amt == chunk_left:
-                value.append(self._safe_read(amt))
+            value.append(self._safe_read(chunk_left))
+
+            # we read the whole chunk, get another
+            self._safe_read(2)      # toss the CRLF at the end of the chunk
+            chunk_left = None
+
+        self._read_and_discard_trailer()
+
+        # we read everything; close the "file"
+        self.close()
+
+        return b''.join(value)
+
+    def _readinto_chunked(self, b):
+        assert self.chunked != _UNKNOWN
+        chunk_left = self.chunk_left
+
+        total_bytes = 0
+        mvb = memoryview(b)
+        while True:
+            if chunk_left is None:
+                try:
+                    chunk_left = self._read_next_chunk_size()
+                    if chunk_left == 0:
+                        break
+                except ValueError:
+                    raise IncompleteRead(bytes(b[0:total_bytes]))
+                    
+            if len(mvb) < chunk_left:
+                n = self._safe_readinto(mvb)
+                self.chunk_left = chunk_left - n
+                return n
+            elif len(mvb) == chunk_left:
+                n = self._safe_readinto(mvb)
                 self._safe_read(2)  # toss the CRLF at the end of the chunk
                 self.chunk_left = None
-                return b''.join(value)
+                return n
             else:
-                value.append(self._safe_read(chunk_left))
-                amt -= chunk_left
+                temp_mvb = mvb[0:chunk_left]
+                n = self._safe_readinto(temp_mvb)
+                mvb = mvb[n:]
+                total_bytes += n
 
             # we read the whole chunk, get another
             self._safe_read(2)      # toss the CRLF at the end of the chunk
             chunk_left = None
 
-        # read and discard trailer up to the CRLF terminator
-        ### note: we shouldn't have any trailers!
-        while True:
-            line = self.fp.readline(_MAXLINE + 1)
-            if len(line) > _MAXLINE:
-                raise LineTooLong("trailer line")
-            if not line:
-                # a vanishingly small number of sites EOF without
-                # sending the trailer
-                break
-            if line == b"\r\n":
-                break
+        self._read_and_discard_trailer()
 
         # we read everything; close the "file"
         self.close()
 
-        return b''.join(value)
+        return total_bytes
 
     def _safe_read(self, amt):
         """Read the number of bytes requested, compensating for partial reads.
@@ -594,6 +649,22 @@ class HTTPResponse(io.RawIOBase):
             amt -= len(chunk)
         return b"".join(s)
 
+    def _safe_readinto(self, b):
+        """Same as _safe_read, but for reading into a buffer."""
+        total_bytes = 0
+        mvb = memoryview(b)
+        while total_bytes < len(b):
+            if MAXAMOUNT < len(mvb):
+                temp_mvb = mvb[0:MAXAMOUNT]
+                n = self.fp.readinto(temp_mvb)
+            else:
+                n = self.fp.readinto(mvb)
+            if not n:
+                raise IncompleteRead(bytes(mvb[0:total_bytes]), len(b))
+            mvb = mvb[n:]
+            total_bytes += n
+        return total_bytes
+
     def fileno(self):
         return self.fp.fileno()
 
index a10c09dddcf829d1c31d3e15d8a0d47c0601e3d3..425c7160510c010ec10015c3989c2bda75eadff9 100644 (file)
@@ -158,6 +158,23 @@ class BasicTest(TestCase):
         self.assertEqual(resp.read(2), b'xt')
         self.assertTrue(resp.isclosed())
 
+    def test_partial_readintos(self):
+        # if we have a lenght, the system knows when to close itself
+        # same behaviour than when we read the whole thing with read()
+        body = "HTTP/1.1 200 Ok\r\nContent-Length: 4\r\n\r\nText"
+        sock = FakeSocket(body)
+        resp = client.HTTPResponse(sock)
+        resp.begin()
+        b = bytearray(2)
+        n = resp.readinto(b)
+        self.assertEqual(n, 2)
+        self.assertEqual(bytes(b), b'Te')
+        self.assertFalse(resp.isclosed())
+        n = resp.readinto(b)
+        self.assertEqual(n, 2)
+        self.assertEqual(bytes(b), b'xt')
+        self.assertTrue(resp.isclosed())
+
     def test_host_port(self):
         # Check invalid host_port
 
@@ -206,6 +223,21 @@ class BasicTest(TestCase):
         if resp.read():
             self.fail("Did not expect response from HEAD request")
 
+    def test_readinto_head(self):
+        # Test that the library doesn't attempt to read any data
+        # from a HEAD request.  (Tickles SF bug #622042.)
+        sock = FakeSocket(
+            'HTTP/1.1 200 OK\r\n'
+            'Content-Length: 14432\r\n'
+            '\r\n',
+            NoEOFStringIO)
+        resp = client.HTTPResponse(sock, method="HEAD")
+        resp.begin()
+        b = bytearray(5)
+        if resp.readinto(b) != 0:
+            self.fail("Did not expect response from HEAD request")
+        self.assertEqual(bytes(b), b'\x00'*5)
+
     def test_send_file(self):
         expected = (b'GET /foo HTTP/1.1\r\nHost: example.com\r\n'
                     b'Accept-Encoding: identity\r\nContent-Length:')
@@ -285,6 +317,40 @@ class BasicTest(TestCase):
             finally:
                 resp.close()
 
+    def test_readinto_chunked(self):
+        chunked_start = (
+            'HTTP/1.1 200 OK\r\n'
+            'Transfer-Encoding: chunked\r\n\r\n'
+            'a\r\n'
+            'hello worl\r\n'
+            '1\r\n'
+            'd\r\n'
+        )
+        sock = FakeSocket(chunked_start + '0\r\n')
+        resp = client.HTTPResponse(sock, method="GET")
+        resp.begin()
+        b = bytearray(16)
+        n = resp.readinto(b)
+        self.assertEqual(b[:11], b'hello world')
+        self.assertEqual(n, 11)
+        resp.close()
+
+        for x in ('', 'foo\r\n'):
+            sock = FakeSocket(chunked_start + x)
+            resp = client.HTTPResponse(sock, method="GET")
+            resp.begin()
+            try:
+                b = bytearray(16)
+                n = resp.readinto(b)
+            except client.IncompleteRead as i:
+                self.assertEqual(i.partial, b'hello world')
+                self.assertEqual(repr(i),'IncompleteRead(11 bytes read)')
+                self.assertEqual(str(i),'IncompleteRead(11 bytes read)')
+            else:
+                self.fail('IncompleteRead expected')
+            finally:
+                resp.close()
+
     def test_chunked_head(self):
         chunked_start = (
             'HTTP/1.1 200 OK\r\n'
@@ -302,6 +368,26 @@ class BasicTest(TestCase):
         self.assertEqual(resp.reason, 'OK')
         self.assertTrue(resp.isclosed())
 
+    def test_readinto_chunked_head(self):
+        chunked_start = (
+            'HTTP/1.1 200 OK\r\n'
+            'Transfer-Encoding: chunked\r\n\r\n'
+            'a\r\n'
+            'hello world\r\n'
+            '1\r\n'
+            'd\r\n'
+        )
+        sock = FakeSocket(chunked_start + '0\r\n')
+        resp = client.HTTPResponse(sock, method="HEAD")
+        resp.begin()
+        b = bytearray(5)
+        n = resp.readinto(b)
+        self.assertEqual(n, 0)
+        self.assertEqual(bytes(b), b'\x00'*5)
+        self.assertEqual(resp.status, 200)
+        self.assertEqual(resp.reason, 'OK')
+        self.assertTrue(resp.isclosed())
+
     def test_negative_content_length(self):
         sock = FakeSocket(
             'HTTP/1.1 200 OK\r\nContent-Length: -1\r\n\r\nHello\r\n')
index 048e5495fc7635fc49eee4008df1ba8d09f76241..0beb5c9002d3c0a227dee941412c04017077ec4a 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -547,6 +547,7 @@ Hannu Krosing
 Andrej Krpic
 Ivan Krstić
 Andrew Kuchling
+Jon Kuhn
 Vladimir Kushnir
 Ross Lagerwall
 Cameron Laird
index c98c9561321ac33cec85d49fb9b2c1d3c5963a00..2444ced68697beda24c1b1645fcb5836d142753a 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -406,6 +406,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #13464: Add a readinto() method to http.client.HTTPResponse.  Patch
+  by Jon Kuhn.
+
 - tarfile.py: Correctly detect bzip2 compressed streams with blocksizes
   other than 900k.