]> granicus.if.org Git - python/commitdiff
Issue #9003: http.client.HTTPSConnection, urllib.request.HTTPSHandler and
authorAntoine Pitrou <solipsis@pitrou.net>
Wed, 13 Oct 2010 10:36:15 +0000 (10:36 +0000)
committerAntoine Pitrou <solipsis@pitrou.net>
Wed, 13 Oct 2010 10:36:15 +0000 (10:36 +0000)
urllib.request.urlopen now take optional arguments to allow for
server certificate checking, as recommended in public uses of HTTPS.

Doc/library/http.client.rst
Doc/library/urllib.request.rst
Lib/http/client.py
Lib/test/keycert2.pem [new file with mode: 0644]
Lib/test/make_ssl_certs.py
Lib/test/ssl_servers.py [new file with mode: 0644]
Lib/test/test_httplib.py
Lib/test/test_ssl.py
Lib/test/test_urllib2_localnet.py
Lib/urllib/request.py
Misc/NEWS

index 6c96731844526cd2c235c83d6ea132d7237b063b..90e16a7325ec6ec6f6690ed5fa8c0c3d40a0c95b 100644 (file)
@@ -50,19 +50,31 @@ The module provides the following classes:
       *source_address* was added.
 
 
-.. class:: HTTPSConnection(host, port=None, key_file=None, cert_file=None, strict=None[, timeout[, source_address]])
+.. class:: HTTPSConnection(host, port=None, key_file=None, cert_file=None, strict=None[, timeout[, source_address]], *, context=None, check_hostname=None)
 
    A subclass of :class:`HTTPConnection` that uses SSL for communication with
-   secure servers.  Default port is ``443``.  *key_file* is the name of a PEM
-   formatted file that contains your private key, and *cert_file* is a PEM
-   formatted certificate chain file; both can be used for authenticating
-   yourself against the server.
-
-   .. warning::
-      This does not do any verification of the server's certificate.
+   secure servers.  Default port is ``443``.  If *context* is specified, it
+   must be a :class:`ssl.SSLContext` instance describing the various SSL
+   options.  If *context* is specified and has a :attr:`~ssl.SSLContext.verify_mode`
+   of either :data:`~ssl.CERT_OPTIONAL` or :data:`~ssl.CERT_REQUIRED`, then
+   by default *host* is matched against the host name(s) allowed by the
+   server's certificate.  If you want to change that behaviour, you can
+   explicitly set *check_hostname* to False.
+
+   *key_file* and *cert_file* are deprecated, please use
+   :meth:`ssl.SSLContext.load_cert_chain` instead.
+
+   If you access arbitrary hosts on the Internet, it is recommended to
+   require certificate checking and feed the *context* with a set of
+   trusted CA certificates::
+
+      context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+      context.verify_mode = ssl.CERT_REQUIRED
+      context.load_verify_locations('/etc/pki/tls/certs/ca-bundle.crt')
+      h = client.HTTPSConnection('svn.python.org', 443, context=context)
 
    .. versionchanged:: 3.2
-      *source_address* was added.
+      *source_address*, *context* and *check_hostname* were added.
 
 
 .. class:: HTTPResponse(sock, debuglevel=0, strict=0, method=None, url=None)
index a697bdd33b722215f059cd45e9c5d3047b3efb80..21c1c2fd31cc3621cd516518d742528dba1008d3 100644 (file)
@@ -15,14 +15,11 @@ authentication, redirections, cookies and more.
 The :mod:`urllib.request` module defines the following functions:
 
 
-.. function:: urlopen(url, data=None[, timeout])
+.. function:: urlopen(url, data=None[, timeout], *, cafile=None, capath=None)
 
    Open the URL *url*, which can be either a string or a
    :class:`Request` object.
 
-   .. warning::
-      HTTPS requests do not do any verification of the server's certificate.
-
    *data* may be a string specifying additional data to send to the
    server, or ``None`` if no such data is needed.  Currently HTTP
    requests are the only ones that use *data*; the HTTP request will
@@ -38,6 +35,16 @@ The :mod:`urllib.request` module defines the following functions:
    the global default timeout setting will be used).  This actually
    only works for HTTP, HTTPS and FTP connections.
 
+   The optional *cafile* and *capath* parameters specify a set of trusted
+   CA certificates for HTTPS requests.  *cafile* should point to a single
+   file containing a bundle of CA certificates, whereas *capath* should
+   point to a directory of hashed certificate files.  More information can
+   be found in :meth:`ssl.SSLContext.load_verify_locations`.
+
+   .. warning::
+      If neither *cafile* nor *capath* is specified, an HTTPS request
+      will not do any verification of the server's certificate.
+
    This function returns a file-like object with two additional methods from
    the :mod:`urllib.response` module
 
@@ -62,6 +69,9 @@ The :mod:`urllib.request` module defines the following functions:
    Proxy handling, which was done by passing a dictionary parameter to
    ``urllib.urlopen``, can be obtained by using :class:`ProxyHandler` objects.
 
+   .. versionchanged:: 3.2
+      *cafile* and *capath* were added.
+
 .. function:: install_opener(opener)
 
    Install an :class:`OpenerDirector` instance as the default global opener.
@@ -421,9 +431,13 @@ The following classes are provided:
    A class to handle opening of HTTP URLs.
 
 
-.. class:: HTTPSHandler()
+.. class:: HTTPSHandler(debuglevel=0, context=None, check_hostname=None)
+
+   A class to handle opening of HTTPS URLs.  *context* and *check_hostname*
+   have the same meaning as in :class:`http.client.HTTPSConnection`.
 
-   A class to handle opening of HTTPS URLs.
+   .. versionchanged:: 3.2
+      *context* and *check_hostname* were added.
 
 
 .. class:: FileHandler()
index 355b36d1cdef287f834dfa534b6d8cf7054ca335..13dc3594ebd39e6a1adabe70313765195f4a7b71 100644 (file)
@@ -1047,13 +1047,29 @@ else:
 
         default_port = HTTPS_PORT
 
+        # XXX Should key_file and cert_file be deprecated in favour of context?
+
         def __init__(self, host, port=None, key_file=None, cert_file=None,
                      strict=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
-                     source_address=None):
+                     source_address=None, *, context=None, check_hostname=None):
             super(HTTPSConnection, self).__init__(host, port, strict, timeout,
                                                   source_address)
             self.key_file = key_file
             self.cert_file = cert_file
+            if context is None:
+                # Some reasonable defaults
+                context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+                context.options |= ssl.OP_NO_SSLv2
+            will_verify = context.verify_mode != ssl.CERT_NONE
+            if check_hostname is None:
+                check_hostname = will_verify
+            elif check_hostname and not will_verify:
+                raise ValueError("check_hostname needs a SSL context with "
+                                 "either CERT_OPTIONAL or CERT_REQUIRED")
+            if key_file or cert_file:
+                context.load_cert_chain(certfile, keyfile)
+            self._context = context
+            self._check_hostname = check_hostname
 
         def connect(self):
             "Connect to a host on a given (SSL) port."
@@ -1065,7 +1081,14 @@ else:
                 self.sock = sock
                 self._tunnel()
 
-            self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file)
+            self.sock = self._context.wrap_socket(sock)
+            try:
+                if self._check_hostname:
+                    ssl.match_hostname(self.sock.getpeercert(), self.host)
+            except Exception:
+                self.sock.shutdown(socket.SHUT_RDWR)
+                self.sock.close()
+                raise
 
     __all__.append("HTTPSConnection")
 
diff --git a/Lib/test/keycert2.pem b/Lib/test/keycert2.pem
new file mode 100644 (file)
index 0000000..e8a9e08
--- /dev/null
@@ -0,0 +1,31 @@
+-----BEGIN PRIVATE KEY-----
+MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBAJnsJZVrppL+W5I9
+zGQrrawWwE5QJpBK9nWw17mXrZ03R1cD9BamLGivVISbPlRlAVnZBEyh1ATpsB7d
+CUQ+WHEvALquvx4+Yw5l+fXeiYRjrLRBYZuVy8yNtXzU3iWcGObcYRkUdiXdOyP7
+sLF2YZHRvQZpzgDBKkrraeQ81w21AgMBAAECgYBEm7n07FMHWlE+0kT0sXNsLYfy
+YE+QKZnJw9WkaDN+zFEEPELkhZVt5BjsMraJr6v2fIEqF0gGGJPkbenffVq2B5dC
+lWUOxvJHufMK4sM3Cp6s/gOp3LP+QkzVnvJSfAyZU6l+4PGX5pLdUsXYjPxgzjzL
+S36tF7/2Uv1WePyLUQJBAMsPhYzUXOPRgmbhcJiqi9A9c3GO8kvSDYTCKt3VMnqz
+HBn6MQ4VQasCD1F+7jWTI0FU/3vdw8non/Fj8hhYqZcCQQDCDRdvmZqDiZnpMqDq
+L6ZSrLTVtMvZXZbgwForaAD9uHj51TME7+eYT7EG2YCgJTXJ4YvRJEnPNyskwdKt
+vTSTAkEAtaaN/vyemEJ82BIGStwONNw0ILsSr5cZ9tBHzqiA/tipY+e36HRFiXhP
+QcU9zXlxyWkDH8iz9DSAmE2jbfoqwwJANlMJ65E543cjIlitGcKLMnvtCCLcKpb7
+xSG0XJB6Lo11OKPJ66jp0gcFTSCY1Lx2CXVd+gfJrfwI1Pp562+bhwJBAJ9IfDPU
+R8OpO9v1SGd8x33Owm7uXOpB9d63/T70AD1QOXjKUC4eXYbt0WWfWuny/RNPRuyh
+w7DXSfUF+kPKolU=
+-----END PRIVATE KEY-----
+-----BEGIN CERTIFICATE-----
+MIICXTCCAcagAwIBAgIJAIO3upAG445fMA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNV
+BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u
+IFNvZnR3YXJlIEZvdW5kYXRpb24xFTATBgNVBAMTDGZha2Vob3N0bmFtZTAeFw0x
+MDEwMDkxNTAxMDBaFw0yMDEwMDYxNTAxMDBaMGIxCzAJBgNVBAYTAlhZMRcwFQYD
+VQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZv
+dW5kYXRpb24xFTATBgNVBAMTDGZha2Vob3N0bmFtZTCBnzANBgkqhkiG9w0BAQEF
+AAOBjQAwgYkCgYEAmewllWumkv5bkj3MZCutrBbATlAmkEr2dbDXuZetnTdHVwP0
+FqYsaK9UhJs+VGUBWdkETKHUBOmwHt0JRD5YcS8Auq6/Hj5jDmX59d6JhGOstEFh
+m5XLzI21fNTeJZwY5txhGRR2Jd07I/uwsXZhkdG9BmnOAMEqSutp5DzXDbUCAwEA
+AaMbMBkwFwYDVR0RBBAwDoIMZmFrZWhvc3RuYW1lMA0GCSqGSIb3DQEBBQUAA4GB
+AH+iMClLLGSaKWgwXsmdVo4FhTZZHo8Uprrtg3N9FxEeE50btpDVQysgRt5ias3K
+m+bME9zbKwvbVWD5zZdjus4pDgzwF/iHyccL8JyYhxOvS/9zmvAtFXj/APIIbZFp
+IT75d9f88ScIGEtknZQejnrdhB64tYki/EqluiuKBqKD
+-----END CERTIFICATE-----
index 3e25fc218825886f7373de99aa6e152a512d330d..48d2e57f4be7e9d833c7bb9ca88a5ca869fc6d29 100644 (file)
@@ -57,3 +57,8 @@ if __name__ == '__main__':
     with open('keycert.pem', 'w') as f:
         f.write(key)
         f.write(cert)
+    # For certificate matching tests
+    cert, key = make_cert_key('fakehostname')
+    with open('keycert2.pem', 'w') as f:
+        f.write(key)
+        f.write(cert)
diff --git a/Lib/test/ssl_servers.py b/Lib/test/ssl_servers.py
new file mode 100644 (file)
index 0000000..d0736b1
--- /dev/null
@@ -0,0 +1,119 @@
+import os
+import sys
+import ssl
+import threading
+import urllib.parse
+# Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
+from http.server import HTTPServer as _HTTPServer, SimpleHTTPRequestHandler
+
+from test import support
+
+here = os.path.dirname(__file__)
+
+HOST = support.HOST
+CERTFILE = os.path.join(here, 'keycert.pem')
+
+# This one's based on HTTPServer, which is based on SocketServer
+
+class HTTPSServer(_HTTPServer):
+
+    def __init__(self, server_address, handler_class, context):
+        _HTTPServer.__init__(self, server_address, handler_class)
+        self.context = context
+
+    def __str__(self):
+        return ('<%s %s:%s>' %
+                (self.__class__.__name__,
+                 self.server_name,
+                 self.server_port))
+
+    def get_request(self):
+        # override this to wrap socket with SSL
+        sock, addr = self.socket.accept()
+        sslconn = self.context.wrap_socket(sock, server_side=True)
+        return sslconn, addr
+
+class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
+    # need to override translate_path to get a known root,
+    # instead of using os.curdir, since the test could be
+    # run from anywhere
+
+    server_version = "TestHTTPS/1.0"
+    root = here
+    # Avoid hanging when a request gets interrupted by the client
+    timeout = 5
+
+    def translate_path(self, path):
+        """Translate a /-separated PATH to the local filename syntax.
+
+        Components that mean special things to the local file system
+        (e.g. drive or directory names) are ignored.  (XXX They should
+        probably be diagnosed.)
+
+        """
+        # abandon query parameters
+        path = urllib.parse.urlparse(path)[2]
+        path = os.path.normpath(urllib.parse.unquote(path))
+        words = path.split('/')
+        words = filter(None, words)
+        path = self.root
+        for word in words:
+            drive, word = os.path.splitdrive(word)
+            head, word = os.path.split(word)
+            path = os.path.join(path, word)
+        return path
+
+    def log_message(self, format, *args):
+        # we override this to suppress logging unless "verbose"
+        if support.verbose:
+            sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
+                             (self.server.server_address,
+                              self.server.server_port,
+                              self.request.cipher(),
+                              self.log_date_time_string(),
+                              format%args))
+
+class HTTPSServerThread(threading.Thread):
+
+    def __init__(self, context, host=HOST, handler_class=None):
+        self.flag = None
+        self.server = HTTPSServer((host, 0),
+                                  handler_class or RootedHTTPRequestHandler,
+                                  context)
+        self.port = self.server.server_port
+        threading.Thread.__init__(self)
+        self.daemon = True
+
+    def __str__(self):
+        return "<%s %s>" % (self.__class__.__name__, self.server)
+
+    def start(self, flag=None):
+        self.flag = flag
+        threading.Thread.start(self)
+
+    def run(self):
+        if self.flag:
+            self.flag.set()
+        self.server.serve_forever(0.05)
+
+    def stop(self):
+        self.server.shutdown()
+
+
+def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None):
+    # we assume the certfile contains both private key and certificate
+    context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+    context.load_cert_chain(certfile)
+    server = HTTPSServerThread(context, host, handler_class)
+    flag = threading.Event()
+    server.start(flag)
+    flag.wait()
+    def cleanup():
+        if support.verbose:
+            sys.stdout.write('stopping HTTPS server\n')
+        server.stop()
+        if support.verbose:
+            sys.stdout.write('joining HTTPS thread\n')
+        server.join()
+    case.addCleanup(cleanup)
+    return server
index ab2a3e63735d8c94af046969a96beb21e9fe9672..62b444056ac1fcbc9cca7f28008e81104931f8c6 100644 (file)
@@ -1,6 +1,7 @@
 import errno
 from http import client
 import io
+import os
 import array
 import socket
 
@@ -9,6 +10,14 @@ TestCase = unittest.TestCase
 
 from test import support
 
+here = os.path.dirname(__file__)
+# Self-signed cert file for 'localhost'
+CERT_localhost = os.path.join(here, 'keycert.pem')
+# Self-signed cert file for 'fakehostname'
+CERT_fakehostname = os.path.join(here, 'keycert2.pem')
+# Root cert file (CA) for svn.python.org's cert
+CACERT_svn_python_org = os.path.join(here, 'https_svn_python_org_root.pem')
+
 HOST = support.HOST
 
 class FakeSocket:
@@ -370,14 +379,97 @@ class TimeoutTest(TestCase):
         self.assertEqual(httpConn.sock.gettimeout(), 30)
         httpConn.close()
 
-class HTTPSTimeoutTest(TestCase):
-# XXX Here should be tests for HTTPS, there isn't any right now!
+
+class HTTPSTest(TestCase):
+
+    def setUp(self):
+        if not hasattr(client, 'HTTPSConnection'):
+            self.skipTest('ssl support required')
+
+    def make_server(self, certfile):
+        from test.ssl_servers import make_https_server
+        return make_https_server(self, certfile)
 
     def test_attributes(self):
-        # simple test to check it's storing it
-        if hasattr(client, 'HTTPSConnection'):
-            h = client.HTTPSConnection(HOST, TimeoutTest.PORT, timeout=30)
-            self.assertEqual(h.timeout, 30)
+        # simple test to check it's storing the timeout
+        h = client.HTTPSConnection(HOST, TimeoutTest.PORT, timeout=30)
+        self.assertEqual(h.timeout, 30)
+
+    def _check_svn_python_org(self, resp):
+        # Just a simple check that everything went fine
+        server_string = resp.getheader('server')
+        self.assertIn('Apache', server_string)
+
+    def test_networked(self):
+        # Default settings: no cert verification is done
+        support.requires('network')
+        with support.transient_internet('svn.python.org'):
+            h = client.HTTPSConnection('svn.python.org', 443)
+            h.request('GET', '/')
+            resp = h.getresponse()
+            self._check_svn_python_org(resp)
+
+    def test_networked_good_cert(self):
+        # We feed a CA cert that validates the server's cert
+        import ssl
+        support.requires('network')
+        with support.transient_internet('svn.python.org'):
+            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+            context.verify_mode = ssl.CERT_REQUIRED
+            context.load_verify_locations(CACERT_svn_python_org)
+            h = client.HTTPSConnection('svn.python.org', 443, context=context)
+            h.request('GET', '/')
+            resp = h.getresponse()
+            self._check_svn_python_org(resp)
+
+    def test_networked_bad_cert(self):
+        # We feed a "CA" cert that is unrelated to the server's cert
+        import ssl
+        support.requires('network')
+        with support.transient_internet('svn.python.org'):
+            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+            context.verify_mode = ssl.CERT_REQUIRED
+            context.load_verify_locations(CERT_localhost)
+            h = client.HTTPSConnection('svn.python.org', 443, context=context)
+            with self.assertRaises(ssl.SSLError):
+                h.request('GET', '/')
+
+    def test_local_good_hostname(self):
+        # The (valid) cert validates the HTTP hostname
+        import ssl
+        from test.ssl_servers import make_https_server
+        server = make_https_server(self, CERT_localhost)
+        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+        context.verify_mode = ssl.CERT_REQUIRED
+        context.load_verify_locations(CERT_localhost)
+        h = client.HTTPSConnection('localhost', server.port, context=context)
+        h.request('GET', '/nonexistent')
+        resp = h.getresponse()
+        self.assertEqual(resp.status, 404)
+
+    def test_local_bad_hostname(self):
+        # The (valid) cert doesn't validate the HTTP hostname
+        import ssl
+        from test.ssl_servers import make_https_server
+        server = make_https_server(self, CERT_fakehostname)
+        context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+        context.verify_mode = ssl.CERT_REQUIRED
+        context.load_verify_locations(CERT_fakehostname)
+        h = client.HTTPSConnection('localhost', server.port, context=context)
+        with self.assertRaises(ssl.CertificateError):
+            h.request('GET', '/')
+        # Same with explicit check_hostname=True
+        h = client.HTTPSConnection('localhost', server.port, context=context,
+                                   check_hostname=True)
+        with self.assertRaises(ssl.CertificateError):
+            h.request('GET', '/')
+        # With check_hostname=False, the mismatching is ignored
+        h = client.HTTPSConnection('localhost', server.port, context=context,
+                                   check_hostname=False)
+        h.request('GET', '/nonexistent')
+        resp = h.getresponse()
+        self.assertEqual(resp.status, 404)
+
 
 class RequestBodyTest(TestCase):
     """Test cases where a request includes a message body."""
@@ -488,7 +580,7 @@ class HTTPResponseTest(TestCase):
 
 def test_main(verbose=None):
     support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
-                         HTTPSTimeoutTest, RequestBodyTest, SourceAddressTest,
+                         HTTPSTest, RequestBodyTest, SourceAddressTest,
                          HTTPResponseTest)
 
 if __name__ == '__main__':
index bf08f6fe1eb5dec412ef4d10e0fbf6e168284e1b..ca74e7132e7a5c8466ea40232f0dafd03a431c5e 100644 (file)
@@ -11,15 +11,13 @@ import os
 import errno
 import pprint
 import tempfile
-import urllib.parse, urllib.request
+import urllib.request
 import traceback
 import asyncore
 import weakref
 import platform
 import functools
 
-from http.server import HTTPServer, SimpleHTTPRequestHandler
-
 # Optionally test SSL support, if we have it in the tested platform
 skip_expected = False
 try:
@@ -605,6 +603,8 @@ except ImportError:
 else:
     _have_threads = True
 
+    from test.ssl_servers import make_https_server
+
     class ThreadedEchoServer(threading.Thread):
 
         class ConnectionHandler(threading.Thread):
@@ -774,98 +774,6 @@ else:
         def stop(self):
             self.active = False
 
-    class OurHTTPSServer(threading.Thread):
-
-        # This one's based on HTTPServer, which is based on SocketServer
-
-        class HTTPSServer(HTTPServer):
-
-            def __init__(self, server_address, RequestHandlerClass, certfile):
-                HTTPServer.__init__(self, server_address, RequestHandlerClass)
-                # we assume the certfile contains both private key and certificate
-                self.certfile = certfile
-                self.allow_reuse_address = True
-
-            def __str__(self):
-                return ('<%s %s:%s>' %
-                        (self.__class__.__name__,
-                         self.server_name,
-                         self.server_port))
-
-            def get_request(self):
-                # override this to wrap socket with SSL
-                sock, addr = self.socket.accept()
-                sslconn = ssl.wrap_socket(sock, server_side=True,
-                                          certfile=self.certfile)
-                return sslconn, addr
-
-        class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
-            # need to override translate_path to get a known root,
-            # instead of using os.curdir, since the test could be
-            # run from anywhere
-
-            server_version = "TestHTTPS/1.0"
-
-            root = None
-
-            def translate_path(self, path):
-                """Translate a /-separated PATH to the local filename syntax.
-
-                Components that mean special things to the local file system
-                (e.g. drive or directory names) are ignored.  (XXX They should
-                probably be diagnosed.)
-
-                """
-                # abandon query parameters
-                path = urllib.parse.urlparse(path)[2]
-                path = os.path.normpath(urllib.parse.unquote(path))
-                words = path.split('/')
-                words = filter(None, words)
-                path = self.root
-                for word in words:
-                    drive, word = os.path.splitdrive(word)
-                    head, word = os.path.split(word)
-                    if word in self.root: continue
-                    path = os.path.join(path, word)
-                return path
-
-            def log_message(self, format, *args):
-                # we override this to suppress logging unless "verbose"
-
-                if support.verbose:
-                    sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
-                                     (self.server.server_address,
-                                      self.server.server_port,
-                                      self.request.cipher(),
-                                      self.log_date_time_string(),
-                                      format%args))
-
-
-        def __init__(self, certfile):
-            self.flag = None
-            self.RootedHTTPRequestHandler.root = os.path.split(CERTFILE)[0]
-            self.server = self.HTTPSServer(
-                (HOST, 0), self.RootedHTTPRequestHandler, certfile)
-            self.port = self.server.server_port
-            threading.Thread.__init__(self)
-            self.daemon = True
-
-        def __str__(self):
-            return "<%s %s>" % (self.__class__.__name__, self.server)
-
-        def start(self, flag=None):
-            self.flag = flag
-            threading.Thread.start(self)
-
-        def run(self):
-            if self.flag:
-                self.flag.set()
-            self.server.serve_forever(0.05)
-
-        def stop(self):
-            self.server.shutdown()
-
-
     class AsyncoreEchoServer(threading.Thread):
 
         # this one's based on asyncore.dispatcher
@@ -1349,22 +1257,18 @@ else:
 
         def test_socketserver(self):
             """Using a SocketServer to create and manage SSL connections."""
-            server = OurHTTPSServer(CERTFILE)
-            flag = threading.Event()
-            server.start(flag)
-            # wait for it to start
-            flag.wait()
+            server = make_https_server(self, CERTFILE)
             # try to connect
+            if support.verbose:
+                sys.stdout.write('\n')
+            with open(CERTFILE, 'rb') as f:
+                d1 = f.read()
+            d2 = ''
+            # now fetch the same data from the HTTPS server
+            url = 'https://%s:%d/%s' % (
+                HOST, server.port, os.path.split(CERTFILE)[1])
+            f = urllib.request.urlopen(url)
             try:
-                if support.verbose:
-                    sys.stdout.write('\n')
-                with open(CERTFILE, 'rb') as f:
-                    d1 = f.read()
-                d2 = ''
-                # now fetch the same data from the HTTPS server
-                url = 'https://%s:%d/%s' % (
-                    HOST, server.port, os.path.split(CERTFILE)[1])
-                f = urllib.request.urlopen(url)
                 dlen = f.info().get("content-length")
                 if dlen and (int(dlen) > 0):
                     d2 = f.read(int(dlen))
@@ -1372,15 +1276,9 @@ else:
                         sys.stdout.write(
                             " client: read %d bytes from remote server '%s'\n"
                             % (len(d2), server))
-                f.close()
-                self.assertEqual(d1, d2)
             finally:
-                if support.verbose:
-                    sys.stdout.write('stopping server\n')
-                server.stop()
-                if support.verbose:
-                    sys.stdout.write('joining thread\n')
-                server.join()
+                f.close()
+            self.assertEqual(d1, d2)
 
         def test_asyncore_server(self):
             """Check the example asyncore integration."""
index de8a521033fd8e3f3ab9588b4f0af092265def9b..872b2be7bef3ec2a7e1ea856aec8959d77ec3809 100644 (file)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import os
 import email
 import urllib.parse
 import urllib.request
@@ -9,6 +10,13 @@ import hashlib
 from test import support
 threading = support.import_module('threading')
 
+
+here = os.path.dirname(__file__)
+# Self-signed cert file for 'localhost'
+CERT_localhost = os.path.join(here, 'keycert.pem')
+# Self-signed cert file for 'fakehostname'
+CERT_fakehostname = os.path.join(here, 'keycert2.pem')
+
 # Loopback http server infrastructure
 
 class LoopbackHttpServer(http.server.HTTPServer):
@@ -23,7 +31,7 @@ class LoopbackHttpServer(http.server.HTTPServer):
 
         # Set the timeout of our listening socket really low so
         # that we can stop the server easily.
-        self.socket.settimeout(1.0)
+        self.socket.settimeout(0.1)
 
     def get_request(self):
         """HTTPServer method, overridden."""
@@ -221,15 +229,7 @@ class FakeProxyHandler(http.server.BaseHTTPRequestHandler):
 
 # Test cases
 
-class BaseTestCase(unittest.TestCase):
-    def setUp(self):
-        self._threads = support.threading_setup()
-
-    def tearDown(self):
-        support.threading_cleanup(*self._threads)
-
-
-class ProxyAuthTests(BaseTestCase):
+class ProxyAuthTests(unittest.TestCase):
     URL = "http://localhost"
 
     USER = "tester"
@@ -340,7 +340,7 @@ def GetRequestHandler(responses):
     return FakeHTTPRequestHandler
 
 
-class TestUrlopen(BaseTestCase):
+class TestUrlopen(unittest.TestCase):
     """Tests urllib.request.urlopen using the network.
 
     These tests are not exhaustive.  Assuming that testing using files does a
@@ -358,9 +358,9 @@ class TestUrlopen(BaseTestCase):
             self.server.stop()
         super(TestUrlopen, self).tearDown()
 
-    def urlopen(self, url, data=None):
+    def urlopen(self, url, data=None, **kwargs):
         l = []
-        f = urllib.request.urlopen(url, data)
+        f = urllib.request.urlopen(url, data, **kwargs)
         try:
             # Exercise various methods
             l.extend(f.readlines(200))
@@ -383,6 +383,17 @@ class TestUrlopen(BaseTestCase):
         handler.port = port
         return handler
 
+    def start_https_server(self, responses=None, certfile=CERT_localhost):
+        if not hasattr(urllib.request, 'HTTPSHandler'):
+            self.skipTest('ssl support required')
+        from test.ssl_servers import make_https_server
+        if responses is None:
+            responses = [(200, [], b"we care a bit")]
+        handler = GetRequestHandler(responses)
+        server = make_https_server(self, certfile=certfile, handler_class=handler)
+        handler.port = server.port
+        return handler
+
     def test_redirection(self):
         expected_response = b"We got here..."
         responses = [
@@ -440,6 +451,28 @@ class TestUrlopen(BaseTestCase):
         self.assertEqual(data, expected_response)
         self.assertEqual(handler.requests, ["/bizarre", b"get=with_feeling"])
 
+    def test_https(self):
+        handler = self.start_https_server()
+        data = self.urlopen("https://localhost:%s/bizarre" % handler.port)
+        self.assertEqual(data, b"we care a bit")
+
+    def test_https_with_cafile(self):
+        handler = self.start_https_server(certfile=CERT_localhost)
+        import ssl
+        # Good cert
+        data = self.urlopen("https://localhost:%s/bizarre" % handler.port,
+                            cafile=CERT_localhost)
+        self.assertEqual(data, b"we care a bit")
+        # Bad cert
+        with self.assertRaises(urllib.error.URLError) as cm:
+            self.urlopen("https://localhost:%s/bizarre" % handler.port,
+                         cafile=CERT_fakehostname)
+        # Good cert, but mismatching hostname
+        handler = self.start_https_server(certfile=CERT_fakehostname)
+        with self.assertRaises(ssl.CertificateError) as cm:
+            self.urlopen("https://localhost:%s/bizarre" % handler.port,
+                         cafile=CERT_fakehostname)
+
     def test_sending_headers(self):
         handler = self.start_server()
         req = urllib.request.Request("http://localhost:%s/" % handler.port,
@@ -521,6 +554,8 @@ class TestUrlopen(BaseTestCase):
                              (index, len(lines[index]), len(line)))
         self.assertEqual(index + 1, len(lines))
 
+
+@support.reap_threads
 def test_main():
     support.run_unittest(ProxyAuthTests, TestUrlopen)
 
index 6c6450f8c0e8625ce68cd2dd429e7fcf5e81a675..e2845c9b5e6cd42535ec4c83ff1ddddf543265e9 100644 (file)
@@ -114,11 +114,27 @@ else:
 __version__ = sys.version[:3]
 
 _opener = None
-def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
+            *, cafile=None, capath=None):
     global _opener
-    if _opener is None:
-        _opener = build_opener()
-    return _opener.open(url, data, timeout)
+    if cafile or capath:
+        if not _have_ssl:
+            raise ValueError('SSL support not available')
+        context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        context.options |= ssl.OP_NO_SSLv2
+        if cafile or capath:
+            context.verify_mode = ssl.CERT_REQUIRED
+            context.load_verify_locations(cafile, capath)
+            check_hostname = True
+        else:
+            check_hostname = False
+        https_handler = HTTPSHandler(context=context, check_hostname=check_hostname)
+        opener = build_opener(https_handler)
+    elif _opener is None:
+        _opener = opener = build_opener()
+    else:
+        opener = _opener
+    return opener.open(url, data, timeout)
 
 def install_opener(opener):
     global _opener
@@ -1053,7 +1069,7 @@ class AbstractHTTPHandler(BaseHandler):
 
         return request
 
-    def do_open(self, http_class, req):
+    def do_open(self, http_class, req, **http_conn_args):
         """Return an HTTPResponse object for the request, using http_class.
 
         http_class must implement the HTTPConnection API from http.client.
@@ -1062,7 +1078,8 @@ class AbstractHTTPHandler(BaseHandler):
         if not host:
             raise URLError('no host given')
 
-        h = http_class(host, timeout=req.timeout) # will parse host:port
+        # will parse host:port
+        h = http_class(host, timeout=req.timeout, **http_conn_args)
 
         headers = dict(req.unredirected_hdrs)
         headers.update(dict((k, v) for k, v in req.headers.items()
@@ -1114,10 +1131,18 @@ class HTTPHandler(AbstractHTTPHandler):
     http_request = AbstractHTTPHandler.do_request_
 
 if hasattr(http.client, 'HTTPSConnection'):
+    import ssl
+
     class HTTPSHandler(AbstractHTTPHandler):
 
+        def __init__(self, debuglevel=0, context=None, check_hostname=None):
+            AbstractHTTPHandler.__init__(self, debuglevel)
+            self._context = context
+            self._check_hostname = check_hostname
+
         def https_open(self, req):
-            return self.do_open(http.client.HTTPSConnection, req)
+            return self.do_open(http.client.HTTPSConnection, req,
+                context=self._context, check_hostname=self._check_hostname)
 
         https_request = AbstractHTTPHandler.do_request_
 
index 0b7d77cdf8fe0c3d4c9622edbba9616275ea7463..e7a02daa8c9dd87dbbdee1a6da2409db6dc0d329 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -13,6 +13,10 @@ Core and Builtins
 Library
 -------
 
+- Issue #9003: http.client.HTTPSConnection, urllib.request.HTTPSHandler and
+  urllib.request.urlopen now take optional arguments to allow for
+  server certificate checking, as recommended in public uses of HTTPS.
+
 - Issue #6612: Fix site and sysconfig to catch os.getcwd() error, eg. if the
   current directory was deleted. Patch written by W. Trevor King.