]> granicus.if.org Git - python/commitdiff
bpo-33734: asyncio/ssl: a bunch of bugfixes (#7321)
authorYury Selivanov <yury@magic.io>
Mon, 4 Jun 2018 15:32:35 +0000 (11:32 -0400)
committerGitHub <noreply@github.com>
Mon, 4 Jun 2018 15:32:35 +0000 (11:32 -0400)
* Fix AttributeError (not all SSL exceptions have 'errno' attribute)

* Increase default handshake timeout from 10 to 60 seconds
* Make sure start_tls can be cancelled correctly
* Make sure any error in SSLProtocol gets propagated (instead of just being logged)

Doc/library/asyncio-eventloop.rst
Lib/asyncio/base_events.py
Lib/asyncio/constants.py
Lib/asyncio/events.py
Lib/asyncio/sslproto.py
Lib/test/test_asyncio/test_sslproto.py
Lib/test/test_asyncio/utils.py
Misc/NEWS.d/next/Library/2018-06-01-10-55-48.bpo-33734.x1W9x0.rst [new file with mode: 0644]

index 9d7f2362b3d19b1dc5f041eb50bc2d25263ddca3..a38dab0a7251ece1b5fac42a863f421916a3a1a1 100644 (file)
@@ -351,7 +351,7 @@ Creating connections
 
    * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
      to wait for the SSL handshake to complete before aborting the connection.
-     ``10.0`` seconds if ``None`` (default).
+     ``60.0`` seconds if ``None`` (default).
 
    .. versionadded:: 3.7
 
@@ -497,7 +497,7 @@ Creating listening connections
 
    * *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
      for the SSL handshake to complete before aborting the connection.
-     ``10.0`` seconds if ``None`` (default).
+     ``60.0`` seconds if ``None`` (default).
 
    * *start_serving* set to ``True`` (the default) causes the created server
      to start accepting connections immediately.  When set to ``False``,
@@ -559,7 +559,7 @@ Creating listening connections
 
    * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
      wait for the SSL handshake to complete before aborting the connection.
-     ``10.0`` seconds if ``None`` (default).
+     ``60.0`` seconds if ``None`` (default).
 
    When completed it returns a ``(transport, protocol)`` pair.
 
@@ -628,7 +628,7 @@ TLS Upgrade
 
    * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
      wait for the SSL handshake to complete before aborting the connection.
-     ``10.0`` seconds if ``None`` (default).
+     ``60.0`` seconds if ``None`` (default).
 
    .. versionadded:: 3.7
 
index 61938e90c375dfb3a8f51475f685b93e808e4fc5..34cc6252e77cb63c773d125eb6dbc98992a6cea9 100644 (file)
@@ -1114,7 +1114,12 @@ class BaseEventLoop(events.AbstractEventLoop):
         self.call_soon(ssl_protocol.connection_made, transport)
         self.call_soon(transport.resume_reading)
 
-        await waiter
+        try:
+            await waiter
+        except Exception:
+            transport.close()
+            raise
+
         return ssl_protocol._app_transport
 
     async def create_datagram_endpoint(self, protocol_factory,
index d7ba496942896999524c370dfae9ee56a6d652a4..33feed60e55b008104ec01f0ce6cbb9d573a8eb3 100644 (file)
@@ -12,7 +12,8 @@ ACCEPT_RETRY_DELAY = 1
 DEBUG_STACK_DEPTH = 10
 
 # Number of seconds to wait for SSL handshake to complete
-SSL_HANDSHAKE_TIMEOUT = 10.0
+# The default timeout matches that of Nginx.
+SSL_HANDSHAKE_TIMEOUT = 60.0
 
 # Used in sendfile fallback code.  We use fallback for platforms
 # that don't support sendfile, or for TLS connections.
index 40946bbf65299ddd263c2f47d7a628f89b6feed4..e4e632206af1bc5411094d51de533201feabe25f 100644 (file)
@@ -352,8 +352,7 @@ class AbstractEventLoop:
 
         ssl_handshake_timeout is the time in seconds that an SSL server
         will wait for completion of the SSL handshake before aborting the
-        connection. Default is 10s, longer timeouts may increase vulnerability
-        to DoS attacks (see https://support.f5.com/csp/article/K13834)
+        connection. Default is 60s.
 
         start_serving set to True (default) causes the created server
         to start accepting connections immediately.  When set to False,
@@ -411,7 +410,7 @@ class AbstractEventLoop:
         accepted connections.
 
         ssl_handshake_timeout is the time in seconds that an SSL server
-        will wait for the SSL handshake to complete (defaults to 10s).
+        will wait for the SSL handshake to complete (defaults to 60s).
 
         start_serving set to True (default) causes the created server
         to start accepting connections immediately.  When set to False,
index a6d382ecd3de634ea2bdc2d6d8c3e6f6ff220b90..8515ec5eebd32e4cd0a5a3c4d7b2b4fa5e0ea619 100644 (file)
@@ -214,13 +214,14 @@ class _SSLPipe(object):
                 # Drain possible plaintext data after close_notify.
                 appdata.append(self._incoming.read())
         except (ssl.SSLError, ssl.CertificateError) as exc:
-            if getattr(exc, 'errno', None) not in (
+            exc_errno = getattr(exc, 'errno', None)
+            if exc_errno not in (
                     ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
                     ssl.SSL_ERROR_SYSCALL):
                 if self._state == _DO_HANDSHAKE and self._handshake_cb:
                     self._handshake_cb(exc)
                 raise
-            self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
+            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
 
         # Check for record level data that needs to be sent back.
         # Happens for the initial handshake and renegotiations.
@@ -263,13 +264,14 @@ class _SSLPipe(object):
                 # It is not allowed to call write() after unwrap() until the
                 # close_notify is acknowledged. We return the condition to the
                 # caller as a short write.
+                exc_errno = getattr(exc, 'errno', None)
                 if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
-                    exc.errno = ssl.SSL_ERROR_WANT_READ
-                if exc.errno not in (ssl.SSL_ERROR_WANT_READ,
+                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
+                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
                                      ssl.SSL_ERROR_WANT_WRITE,
                                      ssl.SSL_ERROR_SYSCALL):
                     raise
-                self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ)
+                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
 
             # See if there's any record level data back for us.
             if self._outgoing.pending:
@@ -488,6 +490,12 @@ class SSLProtocol(protocols.Protocol):
         if self._session_established:
             self._session_established = False
             self._loop.call_soon(self._app_protocol.connection_lost, exc)
+        else:
+            # Most likely an exception occurred while in SSL handshake.
+            # Just mark the app transport as closed so that its __del__
+            # doesn't complain.
+            if self._app_transport is not None:
+                self._app_transport._closed = True
         self._transport = None
         self._app_transport = None
         self._wakeup_waiter(exc)
@@ -515,11 +523,8 @@ class SSLProtocol(protocols.Protocol):
 
         try:
             ssldata, appdata = self._sslpipe.feed_ssldata(data)
-        except ssl.SSLError as e:
-            if self._loop.get_debug():
-                logger.warning('%r: SSL error %s (reason %s)',
-                               self, e.errno, e.reason)
-            self._abort()
+        except Exception as e:
+            self._fatal_error(e, 'SSL error in data received')
             return
 
         for chunk in ssldata:
@@ -602,8 +607,12 @@ class SSLProtocol(protocols.Protocol):
 
     def _check_handshake_timeout(self):
         if self._in_handshake is True:
-            logger.warning("%r stalled during handshake", self)
-            self._abort()
+            msg = (
+                f"SSL handshake is taking longer than "
+                f"{self._ssl_handshake_timeout} seconds: "
+                f"aborting the connection"
+            )
+            self._fatal_error(ConnectionAbortedError(msg))
 
     def _on_handshake_complete(self, handshake_exc):
         self._in_handshake = False
@@ -615,21 +624,13 @@ class SSLProtocol(protocols.Protocol):
                 raise handshake_exc
 
             peercert = sslobj.getpeercert()
-        except BaseException as exc:
-            if self._loop.get_debug():
-                if isinstance(exc, ssl.CertificateError):
-                    logger.warning("%r: SSL handshake failed "
-                                   "on verifying the certificate",
-                                   self, exc_info=True)
-                else:
-                    logger.warning("%r: SSL handshake failed",
-                                   self, exc_info=True)
-            self._transport.close()
-            if isinstance(exc, Exception):
-                self._wakeup_waiter(exc)
-                return
+        except Exception as exc:
+            if isinstance(exc, ssl.CertificateError):
+                msg = 'SSL handshake failed on verifying the certificate'
             else:
-                raise
+                msg = 'SSL handshake failed'
+            self._fatal_error(exc, msg)
+            return
 
         if self._loop.get_debug():
             dt = self._loop.time() - self._handshake_start_time
@@ -686,18 +687,14 @@ class SSLProtocol(protocols.Protocol):
                 # delete it and reduce the outstanding buffer size.
                 del self._write_backlog[0]
                 self._write_buffer_size -= len(data)
-        except BaseException as exc:
+        except Exception as exc:
             if self._in_handshake:
-                # BaseExceptions will be re-raised in _on_handshake_complete.
+                # Exceptions will be re-raised in _on_handshake_complete.
                 self._on_handshake_complete(exc)
             else:
                 self._fatal_error(exc, 'Fatal error on SSL transport')
-            if not isinstance(exc, Exception):
-                # BaseException
-                raise
 
     def _fatal_error(self, exc, message='Fatal error on transport'):
-        # Should be called from exception handler only.
         if isinstance(exc, base_events._FATAL_ERROR_IGNORE):
             if self._loop.get_debug():
                 logger.debug("%r: %s", self, message, exc_info=True)
index fb823f8b0ca43525a9b493a0944f15297e00a4a0..d02d441a8309cd2855e980307c6cf59b22875f0c 100644 (file)
@@ -53,35 +53,6 @@ class SslProtoHandshakeTests(test_utils.TestCase):
             ssl_proto.connection_made(transport)
         return transport
 
-    def test_cancel_handshake(self):
-        # Python issue #23197: cancelling a handshake must not raise an
-        # exception or log an error, even if the handshake failed
-        waiter = asyncio.Future(loop=self.loop)
-        ssl_proto = self.ssl_protocol(waiter=waiter)
-        handshake_fut = asyncio.Future(loop=self.loop)
-
-        def do_handshake(callback):
-            exc = Exception()
-            callback(exc)
-            handshake_fut.set_result(None)
-            return []
-
-        waiter.cancel()
-        self.connection_made(ssl_proto, do_handshake=do_handshake)
-
-        with test_utils.disable_logger():
-            self.loop.run_until_complete(handshake_fut)
-
-    def test_handshake_timeout(self):
-        # bpo-29970: Check that a connection is aborted if handshake is not
-        # completed in timeout period, instead of remaining open indefinitely
-        ssl_proto = self.ssl_protocol()
-        transport = self.connection_made(ssl_proto)
-
-        with test_utils.disable_logger():
-            self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
-        self.assertTrue(transport.abort.called)
-
     def test_handshake_timeout_zero(self):
         sslcontext = test_utils.dummy_ssl_context()
         app_proto = mock.Mock()
@@ -392,6 +363,67 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
                 asyncio.wait_for(client(srv.addr),
                                  loop=self.loop, timeout=self.TIMEOUT))
 
+    def test_start_tls_slow_client_cancel(self):
+        HELLO_MSG = b'1' * self.PAYLOAD_SIZE
+
+        client_context = test_utils.simple_client_sslcontext()
+        server_waits_on_handshake = self.loop.create_future()
+
+        def serve(sock):
+            sock.settimeout(self.TIMEOUT)
+
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+
+            try:
+                self.loop.call_soon_threadsafe(
+                    server_waits_on_handshake.set_result, None)
+                data = sock.recv_all(1024 * 1024)
+            except ConnectionAbortedError:
+                pass
+            finally:
+                sock.close()
+
+        class ClientProto(asyncio.Protocol):
+            def __init__(self, on_data, on_eof):
+                self.on_data = on_data
+                self.on_eof = on_eof
+                self.con_made_cnt = 0
+
+            def connection_made(proto, tr):
+                proto.con_made_cnt += 1
+                # Ensure connection_made gets called only once.
+                self.assertEqual(proto.con_made_cnt, 1)
+
+            def data_received(self, data):
+                self.on_data.set_result(data)
+
+            def eof_received(self):
+                self.on_eof.set_result(True)
+
+        async def client(addr):
+            await asyncio.sleep(0.5, loop=self.loop)
+
+            on_data = self.loop.create_future()
+            on_eof = self.loop.create_future()
+
+            tr, proto = await self.loop.create_connection(
+                lambda: ClientProto(on_data, on_eof), *addr)
+
+            tr.write(HELLO_MSG)
+
+            await server_waits_on_handshake
+
+            with self.assertRaises(asyncio.TimeoutError):
+                await asyncio.wait_for(
+                    self.loop.start_tls(tr, proto, client_context),
+                    0.5,
+                    loop=self.loop)
+
+        with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
+            self.loop.run_until_complete(
+                asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
+
     def test_start_tls_server_1(self):
         HELLO_MSG = b'1' * self.PAYLOAD_SIZE
 
@@ -481,6 +513,156 @@ class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
 
         self.loop.run_until_complete(main())
 
+    def test_handshake_timeout(self):
+        # bpo-29970: Check that a connection is aborted if handshake is not
+        # completed in timeout period, instead of remaining open indefinitely
+        client_sslctx = test_utils.simple_client_sslcontext()
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+        server_side_aborted = False
+
+        def server(sock):
+            nonlocal server_side_aborted
+            try:
+                sock.recv_all(1024 * 1024)
+            except ConnectionAbortedError:
+                server_side_aborted = True
+            finally:
+                sock.close()
+
+        async def client(addr):
+            await asyncio.wait_for(
+                self.loop.create_connection(
+                    asyncio.Protocol,
+                    *addr,
+                    ssl=client_sslctx,
+                    server_hostname='',
+                    ssl_handshake_timeout=10.0),
+                0.5,
+                loop=self.loop)
+
+        with self.tcp_server(server,
+                             max_clients=1,
+                             backlog=1) as srv:
+
+            with self.assertRaises(asyncio.TimeoutError):
+                self.loop.run_until_complete(client(srv.addr))
+
+        self.assertTrue(server_side_aborted)
+
+        # Python issue #23197: cancelling a handshake must not raise an
+        # exception or log an error, even if the handshake failed
+        self.assertEqual(messages, [])
+
+    def test_create_connection_ssl_slow_handshake(self):
+        client_sslctx = test_utils.simple_client_sslcontext()
+
+        messages = []
+        self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+        def server(sock):
+            try:
+                sock.recv_all(1024 * 1024)
+            except ConnectionAbortedError:
+                pass
+            finally:
+                sock.close()
+
+        async def client(addr):
+            reader, writer = await asyncio.open_connection(
+                *addr,
+                ssl=client_sslctx,
+                server_hostname='',
+                loop=self.loop,
+                ssl_handshake_timeout=1.0)
+
+        with self.tcp_server(server,
+                             max_clients=1,
+                             backlog=1) as srv:
+
+            with self.assertRaisesRegex(
+                    ConnectionAbortedError,
+                    r'SSL handshake.*is taking longer'):
+
+                self.loop.run_until_complete(client(srv.addr))
+
+        self.assertEqual(messages, [])
+
+    def test_create_connection_ssl_failed_certificate(self):
+        self.loop.set_exception_handler(lambda loop, ctx: None)
+
+        sslctx = test_utils.simple_server_sslcontext()
+        client_sslctx = test_utils.simple_client_sslcontext(
+            disable_verify=False)
+
+        def server(sock):
+            try:
+                sock.start_tls(
+                    sslctx,
+                    server_side=True)
+            except ssl.SSLError:
+                pass
+            finally:
+                sock.close()
+
+        async def client(addr):
+            reader, writer = await asyncio.open_connection(
+                *addr,
+                ssl=client_sslctx,
+                server_hostname='',
+                loop=self.loop,
+                ssl_handshake_timeout=1.0)
+
+        with self.tcp_server(server,
+                             max_clients=1,
+                             backlog=1) as srv:
+
+            with self.assertRaises(ssl.SSLCertVerificationError):
+                self.loop.run_until_complete(client(srv.addr))
+
+    def test_start_tls_client_corrupted_ssl(self):
+        self.loop.set_exception_handler(lambda loop, ctx: None)
+
+        sslctx = test_utils.simple_server_sslcontext()
+        client_sslctx = test_utils.simple_client_sslcontext()
+
+        def server(sock):
+            orig_sock = sock.dup()
+            try:
+                sock.start_tls(
+                    sslctx,
+                    server_side=True)
+                sock.sendall(b'A\n')
+                sock.recv_all(1)
+                orig_sock.send(b'please corrupt the SSL connection')
+            except ssl.SSLError:
+                pass
+            finally:
+                sock.close()
+
+        async def client(addr):
+            reader, writer = await asyncio.open_connection(
+                *addr,
+                ssl=client_sslctx,
+                server_hostname='',
+                loop=self.loop)
+
+            self.assertEqual(await reader.readline(), b'A\n')
+            writer.write(b'B')
+            with self.assertRaises(ssl.SSLError):
+                await reader.readline()
+            return 'OK'
+
+        with self.tcp_server(server,
+                             max_clients=1,
+                             backlog=1) as srv:
+
+            res = self.loop.run_until_complete(client(srv.addr))
+
+        self.assertEqual(res, 'OK')
+
 
 @unittest.skipIf(ssl is None, 'No ssl module')
 class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
index 96dfe2f85b4de111f360c8351cec796c7618f6f4..5362591b5d7380297348169deb741384aec48b4c 100644 (file)
@@ -77,10 +77,11 @@ def simple_server_sslcontext():
     return server_context
 
 
-def simple_client_sslcontext():
+def simple_client_sslcontext(*, disable_verify=True):
     client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
     client_context.check_hostname = False
-    client_context.verify_mode = ssl.CERT_NONE
+    if disable_verify:
+        client_context.verify_mode = ssl.CERT_NONE
     return client_context
 
 
diff --git a/Misc/NEWS.d/next/Library/2018-06-01-10-55-48.bpo-33734.x1W9x0.rst b/Misc/NEWS.d/next/Library/2018-06-01-10-55-48.bpo-33734.x1W9x0.rst
new file mode 100644 (file)
index 0000000..305d40e
--- /dev/null
@@ -0,0 +1 @@
+asyncio/ssl: Fix AttributeError, increase default handshake timeout