]> granicus.if.org Git - python/commitdiff
Issue #7776: Fix ``Host:'' header and reconnection when using http.client.HTTPConnect...
authorSenthil Kumaran <senthil@uthcode.com>
Mon, 14 Apr 2014 17:07:56 +0000 (13:07 -0400)
committerSenthil Kumaran <senthil@uthcode.com>
Mon, 14 Apr 2014 17:07:56 +0000 (13:07 -0400)
Patch by Nikolaus Rath.

Lib/http/client.py
Lib/test/test_httplib.py
Misc/NEWS

index 12c1a5feddc2441c3ff6d010f30421ff679aebd8..d2013f2de79993bb39cc9d966c18bbc80d9a11b3 100644 (file)
@@ -747,14 +747,30 @@ class HTTPConnection:
         self._tunnel_port = None
         self._tunnel_headers = {}
 
-        self._set_hostport(host, port)
+        (self.host, self.port) = self._get_hostport(host, port)
+
+        # This is stored as an instance variable to allow unit
+        # tests to replace it with a suitable mockup
+        self._create_connection = socket.create_connection
 
     def set_tunnel(self, host, port=None, headers=None):
-        """ Sets up the host and the port for the HTTP CONNECT Tunnelling.
+        """Set up host and port for HTTP CONNECT tunnelling.
+
+        In a connection that uses HTTP CONNECT tunneling, the host passed to the
+        constructor is used as a proxy server that relays all communication to
+        the endpoint passed to `set_tunnel`. This done by sending an HTTP
+        CONNECT request to the proxy server when the connection is established.
 
-        The headers argument should be a mapping of extra HTTP headers
-        to send with the CONNECT request.
+        This method must be called before the HTML connection has been
+        established.
+
+        The headers argument should be a mapping of extra HTTP headers to send
+        with the CONNECT request.
         """
+
+        if self.sock:
+            raise RuntimeError("Can't set up tunnel for established connection")
+
         self._tunnel_host = host
         self._tunnel_port = port
         if headers:
@@ -762,7 +778,7 @@ class HTTPConnection:
         else:
             self._tunnel_headers.clear()
 
-    def _set_hostport(self, host, port):
+    def _get_hostport(self, host, port):
         if port is None:
             i = host.rfind(':')
             j = host.rfind(']')         # ipv6 addresses have [...]
@@ -779,15 +795,16 @@ class HTTPConnection:
                 port = self.default_port
             if host and host[0] == '[' and host[-1] == ']':
                 host = host[1:-1]
-        self.host = host
-        self.port = port
+
+        return (host, port)
 
     def set_debuglevel(self, level):
         self.debuglevel = level
 
     def _tunnel(self):
-        self._set_hostport(self._tunnel_host, self._tunnel_port)
-        connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port)
+        (host, port) = self._get_hostport(self._tunnel_host,
+                                          self._tunnel_port)
+        connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
         connect_bytes = connect_str.encode("ascii")
         self.send(connect_bytes)
         for header, value in self._tunnel_headers.items():
@@ -815,8 +832,9 @@ class HTTPConnection:
 
     def connect(self):
         """Connect to the host and port specified in __init__."""
-        self.sock = socket.create_connection((self.host,self.port),
-                                             self.timeout, self.source_address)
+        self.sock = self._create_connection((self.host,self.port),
+                                            self.timeout, self.source_address)
+
         if self._tunnel_host:
             self._tunnel()
 
@@ -985,22 +1003,29 @@ class HTTPConnection:
                         netloc_enc = netloc.encode("idna")
                     self.putheader('Host', netloc_enc)
                 else:
+                    if self._tunnel_host:
+                        host = self._tunnel_host
+                        port = self._tunnel_port
+                    else:
+                        host = self.host
+                        port = self.port
+
                     try:
-                        host_enc = self.host.encode("ascii")
+                        host_enc = host.encode("ascii")
                     except UnicodeEncodeError:
-                        host_enc = self.host.encode("idna")
+                        host_enc = host.encode("idna")
 
                     # As per RFC 273, IPv6 address should be wrapped with []
                     # when used as Host header
 
-                    if self.host.find(':') >= 0:
+                    if host.find(':') >= 0:
                         host_enc = b'[' + host_enc + b']'
 
-                    if self.port == self.default_port:
+                    if port == self.default_port:
                         self.putheader('Host', host_enc)
                     else:
                         host_enc = host_enc.decode("ascii")
-                        self.putheader('Host', "%s:%s" % (host_enc, self.port))
+                        self.putheader('Host', "%s:%s" % (host_enc, port))
 
             # note: we are assuming that clients will not attempt to set these
             #       headers since *this* library must deal with the
@@ -1193,19 +1218,19 @@ else:
         def connect(self):
             "Connect to a host on a given (SSL) port."
 
-            sock = socket.create_connection((self.host, self.port),
-                                            self.timeout, self.source_address)
+            super().connect()
 
             if self._tunnel_host:
-                self.sock = sock
-                self._tunnel()
+                server_hostname = self._tunnel_host
+            else:
+                server_hostname = self.host
+            sni_hostname = server_hostname if ssl.HAS_SNI else None
 
-            server_hostname = self.host if ssl.HAS_SNI else None
-            self.sock = self._context.wrap_socket(sock,
-                                                  server_hostname=server_hostname)
+            self.sock = self._context.wrap_socket(self.sock,
+                                                  server_hostname=sni_hostname)
             if not self._context.check_hostname and self._check_hostname:
                 try:
-                    ssl.match_hostname(self.sock.getpeercert(), self.host)
+                    ssl.match_hostname(self.sock.getpeercert(), server_hostname)
                 except Exception:
                     self.sock.shutdown(socket.SHUT_RDWR)
                     self.sock.close()
index 30b6c0cfcbbe28b1ff432a8f1a04bd9cdf714eb9..22f7329886e30867c683f416cfaddf4ae3e6cce9 100644 (file)
@@ -21,13 +21,15 @@ CACERT_svn_python_org = os.path.join(here, 'https_svn_python_org_root.pem')
 HOST = support.HOST
 
 class FakeSocket:
-    def __init__(self, text, fileclass=io.BytesIO):
+    def __init__(self, text, fileclass=io.BytesIO, host=None, port=None):
         if isinstance(text, str):
             text = text.encode("ascii")
         self.text = text
         self.fileclass = fileclass
         self.data = b''
         self.sendall_calls = 0
+        self.host = host
+        self.port = port
 
     def sendall(self, data):
         self.sendall_calls += 1
@@ -38,6 +40,9 @@ class FakeSocket:
             raise client.UnimplementedFileMode()
         return self.fileclass(self.text)
 
+    def close(self):
+        pass
+
 class EPipeSocket(FakeSocket):
 
     def __init__(self, text, pipe_trigger):
@@ -970,10 +975,51 @@ class HTTPResponseTest(TestCase):
         header = self.resp.getheader('No-Such-Header',default=42)
         self.assertEqual(header, 42)
 
+class TunnelTests(TestCase):
+
+    def test_connect(self):
+        response_text = (
+            'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT
+            'HTTP/1.1 200 OK\r\n' # Reply to HEAD
+            'Content-Length: 42\r\n\r\n'
+        )
+
+        def create_connection(address, timeout=None, source_address=None):
+            return FakeSocket(response_text, host=address[0],
+                              port=address[1])
+
+        conn = client.HTTPConnection('proxy.com')
+        conn._create_connection = create_connection
+
+        # Once connected, we shouldn't be able to tunnel anymore
+        conn.connect()
+        self.assertRaises(RuntimeError, conn.set_tunnel,
+                          'destination.com')
+
+        # But if we close the connection, we're good
+        conn.close()
+        conn.set_tunnel('destination.com')
+        conn.request('HEAD', '/', '')
+
+        self.assertEqual(conn.sock.host, 'proxy.com')
+        self.assertEqual(conn.sock.port, 80)
+        self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
+        self.assertTrue(b'Host: destination.com' in conn.sock.data)
+
+        # This test should be removed when CONNECT gets the HTTP/1.1 blessing
+        self.assertTrue(b'Host: proxy.com' not in conn.sock.data)
+
+        conn.close()
+        conn.request('PUT', '/', '')
+        self.assertEqual(conn.sock.host, 'proxy.com')
+        self.assertEqual(conn.sock.port, 80)
+        self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
+        self.assertTrue(b'Host: destination.com' in conn.sock.data)
+
 def test_main(verbose=None):
     support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
                          HTTPSTest, RequestBodyTest, SourceAddressTest,
-                         HTTPResponseTest)
+                         HTTPResponseTest, TunnelTests)
 
 if __name__ == '__main__':
     test_main()
index 240c6cc0ec796913604c7974bc8f6fee2d4804cb..9629b5e60fce7cafeeca0d821c68ca5325716eaa 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -33,6 +33,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #7776: Fix ``Host:'' header and reconnection when using
+  http.client.HTTPConnection.set_tunnel(). Patch by Nikolaus Rath.
+
 - Issue #20968: unittest.mock.MagicMock now supports division.
   Patch by Johannes Baiter.