]> granicus.if.org Git - python/commitdiff
allow a SSLContext to be given to ftplib.FTP_TLS
authorBenjamin Peterson <benjamin@python.org>
Sun, 4 Jan 2015 21:36:31 +0000 (15:36 -0600)
committerBenjamin Peterson <benjamin@python.org>
Sun, 4 Jan 2015 21:36:31 +0000 (15:36 -0600)
Doc/library/ftplib.rst
Lib/ftplib.py
Lib/test/test_ftplib.py
Misc/NEWS

index ec07c8785444c8c52efef2c1ff027647a24b5901..cce2ef1c6d2a813ac5f73d7e57783e9ef0f130a8 100644 (file)
@@ -55,18 +55,26 @@ The module defines the following items:
       *timeout* was added.
 
 
-.. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, timeout]]]]]]])
+.. class:: FTP_TLS([host[, user[, passwd[, acct[, keyfile[, certfile[, context[, timeout]]]]]]]])
 
-   A :class:`FTP` subclass which adds TLS support to FTP as described in
+    A :class:`FTP` subclass which adds TLS support to FTP as described in
    :rfc:`4217`.
    Connect as usual to port 21 implicitly securing the FTP control connection
    before authenticating. Securing the data connection requires the user to
-   explicitly ask for it by calling the :meth:`prot_p` method.
-   *keyfile* and *certfile* are optional -- they can contain a PEM formatted
-   private key and certificate chain file name for the SSL connection.
+   explicitly ask for it by calling the :meth:`prot_p` method.  *context*
+   is a :class:`ssl.SSLContext` object which allows bundling SSL configuration
+   options, certificates and private keys into a single (potentially
+   long-lived) structure.  Please read :ref:`ssl-security` for best practices.
+
+   *keyfile* and *certfile* are a legacy alternative to *context* -- they
+   can point to PEM-formatted private key and certificate chain files
+   (respectively) for the SSL connection.
 
    .. versionadded:: 2.7
 
+   .. versionchanged:: 2.7.10
+      The *context* parameter was added.
+
    Here's a sample session using the :class:`FTP_TLS` class:
 
    >>> from ftplib import FTP_TLS
index 0a69b7a6619fd543b38cf1bed716aedf5be1b10e..e1c3d99d1e3973bf29f70c859456a9ec11cc2292 100644 (file)
@@ -641,9 +641,21 @@ else:
         ssl_version = ssl.PROTOCOL_SSLv23
 
         def __init__(self, host='', user='', passwd='', acct='', keyfile=None,
-                     certfile=None, timeout=_GLOBAL_DEFAULT_TIMEOUT):
+                     certfile=None, context=None,
+                     timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
+            if context is not None and keyfile is not None:
+                raise ValueError("context and keyfile arguments are mutually "
+                                 "exclusive")
+            if context is not None and certfile is not None:
+                raise ValueError("context and certfile arguments are mutually "
+                                 "exclusive")
             self.keyfile = keyfile
             self.certfile = certfile
+            if context is None:
+                context = ssl._create_stdlib_context(self.ssl_version,
+                                                     certfile=certfile,
+                                                     keyfile=keyfile)
+            self.context = context
             self._prot_p = False
             FTP.__init__(self, host, user, passwd, acct, timeout)
 
@@ -660,8 +672,8 @@ else:
                 resp = self.voidcmd('AUTH TLS')
             else:
                 resp = self.voidcmd('AUTH SSL')
-            self.sock = ssl.wrap_socket(self.sock, self.keyfile, self.certfile,
-                                        ssl_version=self.ssl_version)
+            self.sock = self.context.wrap_socket(self.sock,
+                                                 server_hostname=self.host)
             self.file = self.sock.makefile(mode='rb')
             return resp
 
@@ -692,8 +704,8 @@ else:
         def ntransfercmd(self, cmd, rest=None):
             conn, size = FTP.ntransfercmd(self, cmd, rest)
             if self._prot_p:
-                conn = ssl.wrap_socket(conn, self.keyfile, self.certfile,
-                                       ssl_version=self.ssl_version)
+                conn = self.context.wrap_socket(conn,
+                                                server_hostname=self.host)
             return conn, size
 
         def retrbinary(self, cmd, callback, blocksize=8192, rest=None):
index 4c229c0c4bcfa2a434f25276469d61e7c532a304..cc1c19bc46877a8504534fda2d50b97081bd1897 100644 (file)
@@ -20,7 +20,7 @@ from test import test_support
 from test.test_support import HOST, HOSTv6
 threading = test_support.import_module('threading')
 
-
+TIMEOUT = 3
 # the dummy data returned by server over the data channel when
 # RETR, LIST and NLST commands are issued
 RETR_DATA = 'abcde12345\r\n' * 1000
@@ -223,6 +223,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
         self.active = False
         self.active_lock = threading.Lock()
         self.host, self.port = self.socket.getsockname()[:2]
+        self.handler_instance = None
 
     def start(self):
         assert not self.active
@@ -246,8 +247,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
 
     def handle_accept(self):
         conn, addr = self.accept()
-        self.handler = self.handler(conn)
-        self.close()
+        self.handler_instance = self.handler(conn)
 
     def handle_connect(self):
         self.close()
@@ -262,7 +262,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
 
 if ssl is not None:
 
-    CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem")
+    CERTFILE = os.path.join(os.path.dirname(__file__), "keycert3.pem")
+    CAFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem")
 
     class SSLConnection(object, asyncore.dispatcher):
         """An asyncore.dispatcher subclass supporting TLS/SSL."""
@@ -271,23 +272,25 @@ if ssl is not None:
         _ssl_closing = False
 
         def secure_connection(self):
-            self.socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
-                                          certfile=CERTFILE, server_side=True,
-                                          do_handshake_on_connect=False,
-                                          ssl_version=ssl.PROTOCOL_SSLv23)
+            socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
+                                     certfile=CERTFILE, server_side=True,
+                                     do_handshake_on_connect=False,
+                                     ssl_version=ssl.PROTOCOL_SSLv23)
+            self.del_channel()
+            self.set_socket(socket)
             self._ssl_accepting = True
 
         def _do_ssl_handshake(self):
             try:
                 self.socket.do_handshake()
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
                     return
                 elif err.args[0] == ssl.SSL_ERROR_EOF:
                     return self.handle_close()
                 raise
-            except socket.error, err:
+            except socket.error as err:
                 if err.args[0] == errno.ECONNABORTED:
                     return self.handle_close()
             else:
@@ -297,18 +300,21 @@ if ssl is not None:
             self._ssl_closing = True
             try:
                 self.socket = self.socket.unwrap()
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
                     return
-            except socket.error, err:
+            except socket.error as err:
                 # Any "socket error" corresponds to a SSL_ERROR_SYSCALL return
                 # from OpenSSL's SSL_shutdown(), corresponding to a
                 # closed socket condition. See also:
                 # http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html
                 pass
             self._ssl_closing = False
-            super(SSLConnection, self).close()
+            if getattr(self, '_ccc', False) is False:
+                super(SSLConnection, self).close()
+            else:
+                pass
 
         def handle_read_event(self):
             if self._ssl_accepting:
@@ -329,7 +335,7 @@ if ssl is not None:
         def send(self, data):
             try:
                 return super(SSLConnection, self).send(data)
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN,
                                    ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
@@ -339,13 +345,13 @@ if ssl is not None:
         def recv(self, buffer_size):
             try:
                 return super(SSLConnection, self).recv(buffer_size)
-            except ssl.SSLError, err:
+            except ssl.SSLError as err:
                 if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
                                    ssl.SSL_ERROR_WANT_WRITE):
-                    return ''
+                    return b''
                 if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
                     self.handle_close()
-                    return ''
+                    return b''
                 raise
 
         def handle_error(self):
@@ -355,6 +361,8 @@ if ssl is not None:
             if (isinstance(self.socket, ssl.SSLSocket) and
                 self.socket._sslobj is not None):
                 self._do_ssl_shutdown()
+            else:
+                super(SSLConnection, self).close()
 
 
     class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler):
@@ -462,12 +470,12 @@ class TestFTPClass(TestCase):
 
     def test_rename(self):
         self.client.rename('a', 'b')
-        self.server.handler.next_response = '200'
+        self.server.handler_instance.next_response = '200'
         self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b')
 
     def test_delete(self):
         self.client.delete('foo')
-        self.server.handler.next_response = '199'
+        self.server.handler_instance.next_response = '199'
         self.assertRaises(ftplib.error_reply, self.client.delete, 'foo')
 
     def test_size(self):
@@ -515,7 +523,7 @@ class TestFTPClass(TestCase):
     def test_storbinary(self):
         f = StringIO.StringIO(RETR_DATA)
         self.client.storbinary('stor', f)
-        self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
+        self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
         # test new callback arg
         flag = []
         f.seek(0)
@@ -527,12 +535,12 @@ class TestFTPClass(TestCase):
         for r in (30, '30'):
             f.seek(0)
             self.client.storbinary('stor', f, rest=r)
-            self.assertEqual(self.server.handler.rest, str(r))
+            self.assertEqual(self.server.handler_instance.rest, str(r))
 
     def test_storlines(self):
         f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n'))
         self.client.storlines('stor', f)
-        self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
+        self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
         # test new callback arg
         flag = []
         f.seek(0)
@@ -551,14 +559,14 @@ class TestFTPClass(TestCase):
     def test_makeport(self):
         self.client.makeport()
         # IPv4 is in use, just make sure send_eprt has not been used
-        self.assertEqual(self.server.handler.last_received_cmd, 'port')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'port')
 
     def test_makepasv(self):
         host, port = self.client.makepasv()
         conn = socket.create_connection((host, port), 10)
         conn.close()
         # IPv4 is in use, just make sure send_epsv has not been used
-        self.assertEqual(self.server.handler.last_received_cmd, 'pasv')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv')
 
     def test_line_too_long(self):
         self.assertRaises(ftplib.Error, self.client.sendcmd,
@@ -600,13 +608,13 @@ class TestIPv6Environment(TestCase):
 
     def test_makeport(self):
         self.client.makeport()
-        self.assertEqual(self.server.handler.last_received_cmd, 'eprt')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'eprt')
 
     def test_makepasv(self):
         host, port = self.client.makepasv()
         conn = socket.create_connection((host, port), 10)
         conn.close()
-        self.assertEqual(self.server.handler.last_received_cmd, 'epsv')
+        self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv')
 
     def test_transfer(self):
         def retr():
@@ -642,7 +650,7 @@ class TestTLS_FTPClass(TestCase):
     def setUp(self):
         self.server = DummyTLS_FTPServer((HOST, 0))
         self.server.start()
-        self.client = ftplib.FTP_TLS(timeout=10)
+        self.client = ftplib.FTP_TLS(timeout=TIMEOUT)
         self.client.connect(self.server.host, self.server.port)
 
     def tearDown(self):
@@ -695,6 +703,59 @@ class TestTLS_FTPClass(TestCase):
         finally:
             self.client.ssl_version = ssl.PROTOCOL_TLSv1
 
+    def test_context(self):
+        self.client.quit()
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+        self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE,
+                          context=ctx)
+        self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
+                          context=ctx)
+        self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
+                          keyfile=CERTFILE, context=ctx)
+
+        self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
+        self.client.connect(self.server.host, self.server.port)
+        self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
+        self.client.auth()
+        self.assertIs(self.client.sock.context, ctx)
+        self.assertIsInstance(self.client.sock, ssl.SSLSocket)
+
+        self.client.prot_p()
+        sock = self.client.transfercmd('list')
+        try:
+            self.assertIs(sock.context, ctx)
+            self.assertIsInstance(sock, ssl.SSLSocket)
+        finally:
+            sock.close()
+
+    def test_check_hostname(self):
+        self.client.quit()
+        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+        ctx.verify_mode = ssl.CERT_REQUIRED
+        ctx.check_hostname = True
+        ctx.load_verify_locations(CAFILE)
+        self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
+
+        # 127.0.0.1 doesn't match SAN
+        self.client.connect(self.server.host, self.server.port)
+        with self.assertRaises(ssl.CertificateError):
+            self.client.auth()
+        # exception quits connection
+
+        self.client.connect(self.server.host, self.server.port)
+        self.client.prot_p()
+        with self.assertRaises(ssl.CertificateError):
+            self.client.transfercmd("list").close()
+        self.client.quit()
+
+        self.client.connect("localhost", self.server.port)
+        self.client.auth()
+        self.client.quit()
+
+        self.client.connect("localhost", self.server.port)
+        self.client.prot_p()
+        self.client.transfercmd("list").close()
+
 
 class TestTimeouts(TestCase):
 
index 47a73fe599aa2e416b60e7242e67127814fbc055..da3d60297b508a9f9d80982123ddaa0483f8700e 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -15,6 +15,8 @@ Core and Builtins
 Library
 -------
 
+- Backport the context argument to ftplib.FTP_TLS.
+
 - Issue #23111: Maximize compatibility in protocol versions of ftplib.FTP_TLS.
 
 - Issue #23112: Fix SimpleHTTPServer to correctly carry the query string and