bpo-29970: Add timeout for SSL handshake in asyncio
authorNeil Aspinall <mail@neilaspinall.co.uk>
Tue, 19 Dec 2017 19:45:42 +0000 (19:45 +0000)
committerAndrew Svetlov <andrew.svetlov@gmail.com>
Tue, 19 Dec 2017 19:45:42 +0000 (21:45 +0200)
10 seconds by default.

12 files changed:
Doc/library/asyncio-eventloop.rst
Lib/asyncio/base_events.py
Lib/asyncio/constants.py
Lib/asyncio/events.py
Lib/asyncio/proactor_events.py
Lib/asyncio/selector_events.py
Lib/asyncio/sslproto.py
Lib/asyncio/unix_events.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_sslproto.py
Misc/ACKS
Misc/NEWS.d/next/Library/2017-12-13-19-02-38.bpo-29970.uxVOpk.rst [new file with mode: 0644]

index 26798783fe7256e46b77ab8286230465f50b5620..d20e995d35355b072c9783cfc142639af99ad615 100644 (file)
@@ -261,7 +261,7 @@ Tasks
 Creating connections
 --------------------
 
-.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None)
+.. coroutinemethod:: AbstractEventLoop.create_connection(protocol_factory, host=None, port=None, \*, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, ssl_handshake_timeout=10.0)
 
    Create a streaming transport connection to a given Internet *host* and
    *port*: socket family :py:data:`~socket.AF_INET` or
@@ -325,6 +325,13 @@ Creating connections
      to bind the socket to locally.  The *local_host* and *local_port*
      are looked up using getaddrinfo(), similarly to *host* and *port*.
 
+   * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds
+     to wait for the SSL handshake to complete before aborting the connection.
+
+   .. versionadded:: 3.7
+
+      The *ssl_handshake_timeout* parameter.
+
    .. versionchanged:: 3.5
 
       On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
@@ -386,7 +393,7 @@ Creating connections
    :ref:`UDP echo server protocol <asyncio-udp-echo-server-protocol>` examples.
 
 
-.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None)
+.. coroutinemethod:: AbstractEventLoop.create_unix_connection(protocol_factory, path=None, \*, ssl=None, sock=None, server_hostname=None, ssl_handshake_timeout=10.0)
 
    Create UNIX connection: socket family :py:data:`~socket.AF_UNIX`, socket
    type :py:data:`~socket.SOCK_STREAM`. The :py:data:`~socket.AF_UNIX` socket
@@ -404,6 +411,10 @@ Creating connections
 
    Availability: UNIX.
 
+   .. versionadded:: 3.7
+
+      The *ssl_handshake_timeout* parameter.
+
    .. versionchanged:: 3.7
 
       The *path* parameter can now be a :class:`~pathlib.Path` object.
@@ -412,7 +423,7 @@ Creating connections
 Creating listening connections
 ------------------------------
 
-.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None)
+.. coroutinemethod:: AbstractEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None, ssl_handshake_timeout=10.0)
 
    Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to
    *host* and *port*.
@@ -456,6 +467,13 @@ Creating listening connections
      set this flag when being created. This option is not supported on
      Windows.
 
+   * *ssl_handshake_timeout* is (for an SSL server) the time in seconds to wait
+     for the SSL handshake to complete before aborting the connection.
+
+   .. versionadded:: 3.7
+
+      The *ssl_handshake_timeout* parameter.
+
    .. versionchanged:: 3.5
 
       On Windows with :class:`ProactorEventLoop`, SSL/TLS is now supported.
@@ -470,7 +488,7 @@ Creating listening connections
       The *host* parameter can now be a sequence of strings.
 
 
-.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None)
+.. coroutinemethod:: AbstractEventLoop.create_unix_server(protocol_factory, path=None, \*, sock=None, backlog=100, ssl=None, ssl_handshake_timeout=10.0)
 
    Similar to :meth:`AbstractEventLoop.create_server`, but specific to the
    socket family :py:data:`~socket.AF_UNIX`.
@@ -481,11 +499,15 @@ Creating listening connections
 
    Availability: UNIX.
 
+   .. versionadded:: 3.7
+
+      The *ssl_handshake_timeout* parameter.
+
    .. versionchanged:: 3.7
 
       The *path* parameter can now be a :class:`~pathlib.Path` object.
 
-.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None)
+.. coroutinemethod:: BaseEventLoop.connect_accepted_socket(protocol_factory, sock, \*, ssl=None, ssl_handshake_timeout=10.0)
 
    Handle an accepted connection.
 
@@ -500,8 +522,15 @@ Creating listening connections
    * *ssl* can be set to an :class:`~ssl.SSLContext` to enable SSL over the
      accepted connections.
 
+   * *ssl_handshake_timeout* is (for an SSL connection) the time in seconds to
+     wait for the SSL handshake to complete before aborting the connection.
+
    When completed it returns a ``(transport, protocol)`` pair.
 
+   .. versionadded:: 3.7
+
+      The *ssl_handshake_timeout* parameter.
+
    .. versionadded:: 3.5.3
 
 
index a7f8edd8cfd478342602e34c472aba708d328dcd..6246f4e221cb96801313bda00f22090df0137920 100644 (file)
@@ -29,6 +29,7 @@ import sys
 import warnings
 import weakref
 
+from . import constants
 from . import coroutines
 from . import events
 from . import futures
@@ -275,9 +276,11 @@ class BaseEventLoop(events.AbstractEventLoop):
         """Create socket transport."""
         raise NotImplementedError
 
-    def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
-                            *, server_side=False, server_hostname=None,
-                            extra=None, server=None):
+    def _make_ssl_transport(
+            self, rawsock, protocol, sslcontext, waiter=None,
+            *, server_side=False, server_hostname=None,
+            extra=None, server=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         """Create SSL transport."""
         raise NotImplementedError
 
@@ -635,10 +638,12 @@ class BaseEventLoop(events.AbstractEventLoop):
         return await self.run_in_executor(
             None, socket.getnameinfo, sockaddr, flags)
 
-    async def create_connection(self, protocol_factory, host=None, port=None,
-                                *, ssl=None, family=0,
-                                proto=0, flags=0, sock=None,
-                                local_addr=None, server_hostname=None):
+    async def create_connection(
+            self, protocol_factory, host=None, port=None,
+            *, ssl=None, family=0,
+            proto=0, flags=0, sock=None,
+            local_addr=None, server_hostname=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         """Connect to a TCP server.
 
         Create a streaming transport connection to a given Internet host and
@@ -751,7 +756,8 @@ class BaseEventLoop(events.AbstractEventLoop):
                     f'A Stream Socket was expected, got {sock!r}')
 
         transport, protocol = await self._create_connection_transport(
-            sock, protocol_factory, ssl, server_hostname)
+            sock, protocol_factory, ssl, server_hostname,
+            ssl_handshake_timeout=ssl_handshake_timeout)
         if self._debug:
             # Get the socket from the transport because SSL transport closes
             # the old socket and creates a new SSL socket
@@ -760,8 +766,10 @@ class BaseEventLoop(events.AbstractEventLoop):
                          sock, host, port, transport, protocol)
         return transport, protocol
 
-    async def _create_connection_transport(self, sock, protocol_factory, ssl,
-                                           server_hostname, server_side=False):
+    async def _create_connection_transport(
+            self, sock, protocol_factory, ssl,
+            server_hostname, server_side=False,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
 
         sock.setblocking(False)
 
@@ -771,7 +779,8 @@ class BaseEventLoop(events.AbstractEventLoop):
             sslcontext = None if isinstance(ssl, bool) else ssl
             transport = self._make_ssl_transport(
                 sock, protocol, sslcontext, waiter,
-                server_side=server_side, server_hostname=server_hostname)
+                server_side=server_side, server_hostname=server_hostname,
+                ssl_handshake_timeout=ssl_handshake_timeout)
         else:
             transport = self._make_socket_transport(sock, protocol, waiter)
 
@@ -929,15 +938,17 @@ class BaseEventLoop(events.AbstractEventLoop):
             raise OSError(f'getaddrinfo({host!r}) returned empty list')
         return infos
 
-    async def create_server(self, protocol_factory, host=None, port=None,
-                            *,
-                            family=socket.AF_UNSPEC,
-                            flags=socket.AI_PASSIVE,
-                            sock=None,
-                            backlog=100,
-                            ssl=None,
-                            reuse_address=None,
-                            reuse_port=None):
+    async def create_server(
+            self, protocol_factory, host=None, port=None,
+            *,
+            family=socket.AF_UNSPEC,
+            flags=socket.AI_PASSIVE,
+            sock=None,
+            backlog=100,
+            ssl=None,
+            reuse_address=None,
+            reuse_port=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         """Create a TCP server.
 
         The host parameter can be a string, in that case the TCP server is
@@ -1026,13 +1037,16 @@ class BaseEventLoop(events.AbstractEventLoop):
         for sock in sockets:
             sock.listen(backlog)
             sock.setblocking(False)
-            self._start_serving(protocol_factory, sock, ssl, server, backlog)
+            self._start_serving(protocol_factory, sock, ssl, server, backlog,
+                                ssl_handshake_timeout)
         if self._debug:
             logger.info("%r is serving", server)
         return server
 
-    async def connect_accepted_socket(self, protocol_factory, sock,
-                                      *, ssl=None):
+    async def connect_accepted_socket(
+            self, protocol_factory, sock,
+            *, ssl=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         """Handle an accepted connection.
 
         This is used by servers that accept connections outside of
@@ -1045,7 +1059,8 @@ class BaseEventLoop(events.AbstractEventLoop):
             raise ValueError(f'A Stream Socket was expected, got {sock!r}')
 
         transport, protocol = await self._create_connection_transport(
-            sock, protocol_factory, ssl, '', server_side=True)
+            sock, protocol_factory, ssl, '', server_side=True,
+            ssl_handshake_timeout=ssl_handshake_timeout)
         if self._debug:
             # Get the socket from the transport because SSL transport closes
             # the old socket and creates a new SSL socket
index 52169c3f8e5b128369e5f19d50b134dc94e479a5..0ad974ff2fb9d5654e1b514a20ce6bb34498e6de 100644 (file)
@@ -8,3 +8,6 @@ ACCEPT_RETRY_DELAY = 1
 # The larger the number, the slower the operation in debug mode
 # (see extract_stack() in format_helpers.py).
 DEBUG_STACK_DEPTH = 10
+
+# Number of seconds to wait for SSL handshake to complete
+SSL_HANDSHAKE_TIMEOUT = 10.0
index 974a4a22218fd5689cb4720981ecb5e6c080298c..c9033c020f368e5c3447b3ae6a88950a0272f35d 100644 (file)
@@ -250,16 +250,20 @@ class AbstractEventLoop:
     async def getnameinfo(self, sockaddr, flags=0):
         raise NotImplementedError
 
-    async def create_connection(self, protocol_factory, host=None, port=None,
-                                *, ssl=None, family=0, proto=0,
-                                flags=0, sock=None, local_addr=None,
-                                server_hostname=None):
-        raise NotImplementedError
-
-    async def create_server(self, protocol_factory, host=None, port=None,
-                            *, family=socket.AF_UNSPEC,
-                            flags=socket.AI_PASSIVE, sock=None, backlog=100,
-                            ssl=None, reuse_address=None, reuse_port=None):
+    async def create_connection(
+            self, protocol_factory, host=None, port=None,
+            *, ssl=None, family=0, proto=0,
+            flags=0, sock=None, local_addr=None,
+            server_hostname=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
+        raise NotImplementedError
+
+    async def create_server(
+            self, protocol_factory, host=None, port=None,
+            *, family=socket.AF_UNSPEC,
+            flags=socket.AI_PASSIVE, sock=None, backlog=100,
+            ssl=None, reuse_address=None, reuse_port=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         """A coroutine which creates a TCP server bound to host and port.
 
         The return value is a Server object which can be used to stop
@@ -294,16 +298,25 @@ class AbstractEventLoop:
         the same port as other existing endpoints are bound to, so long as
         they all set this flag when being created. This option is not
         supported on Windows.
+
+        ssl_handshake_timeout is the time in seconds that an SSL server
+        will wait for completion of the SSL handshake before aborting the
+        connection. Default is 10s, longer timeouts may increase vulnerability
+        to DoS attacks (see https://support.f5.com/csp/article/K13834)
         """
         raise NotImplementedError
 
-    async def create_unix_connection(self, protocol_factory, path=None, *,
-                                     ssl=None, sock=None,
-                                     server_hostname=None):
+    async def create_unix_connection(
+            self, protocol_factory, path=None, *,
+            ssl=None, sock=None,
+            server_hostname=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         raise NotImplementedError
 
-    async def create_unix_server(self, protocol_factory, path=None, *,
-                                 sock=None, backlog=100, ssl=None):
+    async def create_unix_server(
+            self, protocol_factory, path=None, *,
+            sock=None, backlog=100, ssl=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         """A coroutine which creates a UNIX Domain Socket server.
 
         The return value is a Server object, which can be used to stop
@@ -320,6 +333,9 @@ class AbstractEventLoop:
 
         ssl can be set to an SSLContext to enable SSL over the
         accepted connections.
+
+        ssl_handshake_timeout is the time in seconds that an SSL server
+        will wait for the SSL handshake to complete (defaults to 10s).
         """
         raise NotImplementedError
 
index 7044437bbb1b4122ddedc005afb6f5467fe2d84a..bc319b06ed672b918f8a342a3534c339165ec1fb 100644 (file)
@@ -389,11 +389,15 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
         return _ProactorSocketTransport(self, sock, protocol, waiter,
                                         extra, server)
 
-    def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
-                            *, server_side=False, server_hostname=None,
-                            extra=None, server=None):
-        ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
-                                            server_side, server_hostname)
+    def _make_ssl_transport(
+            self, rawsock, protocol, sslcontext, waiter=None,
+            *, server_side=False, server_hostname=None,
+            extra=None, server=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
+        ssl_protocol = sslproto.SSLProtocol(
+                self, protocol, sslcontext, waiter,
+                server_side, server_hostname,
+                ssl_handshake_timeout=ssl_handshake_timeout)
         _ProactorSocketTransport(self, rawsock, ssl_protocol,
                                  extra=extra, server=server)
         return ssl_protocol._app_transport
@@ -486,7 +490,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
         self._csock.send(b'\0')
 
     def _start_serving(self, protocol_factory, sock,
-                       sslcontext=None, server=None, backlog=100):
+                       sslcontext=None, server=None, backlog=100,
+                       ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
 
         def loop(f=None):
             try:
@@ -499,7 +504,8 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
                     if sslcontext is not None:
                         self._make_ssl_transport(
                             conn, protocol, sslcontext, server_side=True,
-                            extra={'peername': addr}, server=server)
+                            extra={'peername': addr}, server=server,
+                            ssl_handshake_timeout=ssl_handshake_timeout)
                     else:
                         self._make_socket_transport(
                             conn, protocol,
index 3b49b0cb928f8280ecfb0c41f205afd654567856..1e4bd83a1b1d36a6b9f12c7b239ec544b15b0d18 100644 (file)
@@ -70,11 +70,15 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
         return _SelectorSocketTransport(self, sock, protocol, waiter,
                                         extra, server)
 
-    def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None,
-                            *, server_side=False, server_hostname=None,
-                            extra=None, server=None):
-        ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter,
-                                            server_side, server_hostname)
+    def _make_ssl_transport(
+            self, rawsock, protocol, sslcontext, waiter=None,
+            *, server_side=False, server_hostname=None,
+            extra=None, server=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
+        ssl_protocol = sslproto.SSLProtocol(
+                self, protocol, sslcontext, waiter,
+                server_side, server_hostname,
+                ssl_handshake_timeout=ssl_handshake_timeout)
         _SelectorSocketTransport(self, rawsock, ssl_protocol,
                                  extra=extra, server=server)
         return ssl_protocol._app_transport
@@ -143,12 +147,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
                                  exc_info=True)
 
     def _start_serving(self, protocol_factory, sock,
-                       sslcontext=None, server=None, backlog=100):
+                       sslcontext=None, server=None, backlog=100,
+                       ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         self._add_reader(sock.fileno(), self._accept_connection,
-                         protocol_factory, sock, sslcontext, server, backlog)
+                         protocol_factory, sock, sslcontext, server, backlog,
+                         ssl_handshake_timeout)
 
-    def _accept_connection(self, protocol_factory, sock,
-                           sslcontext=None, server=None, backlog=100):
+    def _accept_connection(
+            self, protocol_factory, sock,
+            sslcontext=None, server=None, backlog=100,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         # This method is only called once for each event loop tick where the
         # listening socket has triggered an EVENT_READ. There may be multiple
         # connections waiting for an .accept() so it is called in a loop.
@@ -179,17 +187,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
                     self.call_later(constants.ACCEPT_RETRY_DELAY,
                                     self._start_serving,
                                     protocol_factory, sock, sslcontext, server,
-                                    backlog)
+                                    backlog, ssl_handshake_timeout)
                 else:
                     raise  # The event loop will catch, log and ignore it.
             else:
                 extra = {'peername': addr}
                 accept = self._accept_connection2(
-                    protocol_factory, conn, extra, sslcontext, server)
+                    protocol_factory, conn, extra, sslcontext, server,
+                    ssl_handshake_timeout)
                 self.create_task(accept)
 
-    async def _accept_connection2(self, protocol_factory, conn, extra,
-                                  sslcontext=None, server=None):
+    async def _accept_connection2(
+            self, protocol_factory, conn, extra,
+            sslcontext=None, server=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         protocol = None
         transport = None
         try:
@@ -198,7 +209,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
             if sslcontext:
                 transport = self._make_ssl_transport(
                     conn, protocol, sslcontext, waiter=waiter,
-                    server_side=True, extra=extra, server=server)
+                    server_side=True, extra=extra, server=server,
+                    ssl_handshake_timeout=ssl_handshake_timeout)
             else:
                 transport = self._make_socket_transport(
                     conn, protocol, waiter=waiter, extra=extra,
index 8da8570d66d4fb67913f2cff18fff839507beb26..8bcc6cc04334fa1ecd3e6d24feb3c812d231e241 100644 (file)
@@ -6,6 +6,7 @@ except ImportError:  # pragma: no cover
     ssl = None
 
 from . import base_events
+from . import constants
 from . import protocols
 from . import transports
 from .log import logger
@@ -400,7 +401,8 @@ class SSLProtocol(protocols.Protocol):
 
     def __init__(self, loop, app_protocol, sslcontext, waiter,
                  server_side=False, server_hostname=None,
-                 call_connection_made=True):
+                 call_connection_made=True,
+                 ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         if ssl is None:
             raise RuntimeError('stdlib ssl module not available')
 
@@ -434,6 +436,7 @@ class SSLProtocol(protocols.Protocol):
         # transport, ex: SelectorSocketTransport
         self._transport = None
         self._call_connection_made = call_connection_made
+        self._ssl_handshake_timeout = ssl_handshake_timeout
 
     def _wakeup_waiter(self, exc=None):
         if self._waiter is None:
@@ -561,9 +564,18 @@ class SSLProtocol(protocols.Protocol):
         # the SSL handshake
         self._write_backlog.append((b'', 1))
         self._loop.call_soon(self._process_write_backlog)
+        self._handshake_timeout_handle = \
+            self._loop.call_later(self._ssl_handshake_timeout,
+                                  self._check_handshake_timeout)
+
+    def _check_handshake_timeout(self):
+        if self._in_handshake is True:
+            logger.warning("%r stalled during handshake", self)
+            self._abort()
 
     def _on_handshake_complete(self, handshake_exc):
         self._in_handshake = False
+        self._handshake_timeout_handle.cancel()
 
         sslobj = self._sslpipe.ssl_object
         try:
index 2ab6b154b15517af4e7e78af3ba2dddc1f355f89..e2344582268db31d6223d8ba144bdc95dbe17cf7 100644 (file)
@@ -192,9 +192,11 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
     def _child_watcher_callback(self, pid, returncode, transp):
         self.call_soon_threadsafe(transp._process_exited, returncode)
 
-    async def create_unix_connection(self, protocol_factory, path=None, *,
-                                     ssl=None, sock=None,
-                                     server_hostname=None):
+    async def create_unix_connection(
+            self, protocol_factory, path=None, *,
+            ssl=None, sock=None,
+            server_hostname=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         assert server_hostname is None or isinstance(server_hostname, str)
         if ssl:
             if server_hostname is None:
@@ -228,11 +230,14 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
             sock.setblocking(False)
 
         transport, protocol = await self._create_connection_transport(
-            sock, protocol_factory, ssl, server_hostname)
+            sock, protocol_factory, ssl, server_hostname,
+            ssl_handshake_timeout=ssl_handshake_timeout)
         return transport, protocol
 
-    async def create_unix_server(self, protocol_factory, path=None, *,
-                                 sock=None, backlog=100, ssl=None):
+    async def create_unix_server(
+            self, protocol_factory, path=None, *,
+            sock=None, backlog=100, ssl=None,
+            ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT):
         if isinstance(ssl, bool):
             raise TypeError('ssl argument must be an SSLContext or None')
 
@@ -283,7 +288,8 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
         server = base_events.Server(self, [sock])
         sock.listen(backlog)
         sock.setblocking(False)
-        self._start_serving(protocol_factory, sock, ssl, server)
+        self._start_serving(protocol_factory, sock, ssl, server,
+                            ssl_handshake_timeout=ssl_handshake_timeout)
         return server
 
 
index 1d45cf86425ac184d9167744be8f89e0e227b1f8..488257b341f438451867e3d838ebbdb46fd1499f 100644 (file)
@@ -1301,34 +1301,45 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
 
         self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport
         ANY = mock.ANY
+        handshake_timeout = object()
         # First try the default server_hostname.
         self.loop._make_ssl_transport.reset_mock()
-        coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True)
+        coro = self.loop.create_connection(
+                MyProto, 'python.org', 80, ssl=True,
+                ssl_handshake_timeout=handshake_timeout)
         transport, _ = self.loop.run_until_complete(coro)
         transport.close()
         self.loop._make_ssl_transport.assert_called_with(
             ANY, ANY, ANY, ANY,
             server_side=False,
-            server_hostname='python.org')
+            server_hostname='python.org',
+            ssl_handshake_timeout=handshake_timeout)
         # Next try an explicit server_hostname.
         self.loop._make_ssl_transport.reset_mock()
-        coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
-                                           server_hostname='perl.com')
+        coro = self.loop.create_connection(
+                MyProto, 'python.org', 80, ssl=True,
+                server_hostname='perl.com',
+                ssl_handshake_timeout=handshake_timeout)
         transport, _ = self.loop.run_until_complete(coro)
         transport.close()
         self.loop._make_ssl_transport.assert_called_with(
             ANY, ANY, ANY, ANY,
             server_side=False,
-            server_hostname='perl.com')
+            server_hostname='perl.com',
+            ssl_handshake_timeout=handshake_timeout)
         # Finally try an explicit empty server_hostname.
         self.loop._make_ssl_transport.reset_mock()
-        coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True,
-                                           server_hostname='')
+        coro = self.loop.create_connection(
+                MyProto, 'python.org', 80, ssl=True,
+                server_hostname='',
+                ssl_handshake_timeout=handshake_timeout)
         transport, _ = self.loop.run_until_complete(coro)
         transport.close()
-        self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY,
-                                                         server_side=False,
-                                                         server_hostname='')
+        self.loop._make_ssl_transport.assert_called_with(
+                ANY, ANY, ANY, ANY,
+                server_side=False,
+                server_hostname='',
+                ssl_handshake_timeout=handshake_timeout)
 
     def test_create_connection_no_ssl_server_hostname_errors(self):
         # When not using ssl, server_hostname must be None.
@@ -1687,7 +1698,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
             constants.ACCEPT_RETRY_DELAY,
             # self.loop._start_serving
             mock.ANY,
-            MyProto, sock, None, None, mock.ANY)
+            MyProto, sock, None, None, mock.ANY, mock.ANY)
 
     def test_call_coroutine(self):
         @asyncio.coroutine
index 7650fe6bd46ca9aea4ee125216e04e38f42939d1..1c42a35128fac556e07cae6034e1ec7d5b7befee 100644 (file)
@@ -11,6 +11,7 @@ except ImportError:
 import asyncio
 from asyncio import log
 from asyncio import sslproto
+from asyncio import tasks
 from test.test_asyncio import utils as test_utils
 
 
@@ -25,7 +26,8 @@ class SslProtoHandshakeTests(test_utils.TestCase):
     def ssl_protocol(self, waiter=None):
         sslcontext = test_utils.dummy_ssl_context()
         app_proto = asyncio.Protocol()
-        proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter)
+        proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
+                                     ssl_handshake_timeout=0.1)
         self.assertIs(proto._app_transport.get_protocol(), app_proto)
         self.addCleanup(proto._app_transport.close)
         return proto
@@ -63,6 +65,16 @@ class SslProtoHandshakeTests(test_utils.TestCase):
         with test_utils.disable_logger():
             self.loop.run_until_complete(handshake_fut)
 
+    def test_handshake_timeout(self):
+        # bpo-29970: Check that a connection is aborted if handshake is not
+        # completed in timeout period, instead of remaining open indefinitely
+        ssl_proto = self.ssl_protocol()
+        transport = self.connection_made(ssl_proto)
+
+        with test_utils.disable_logger():
+            self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
+        self.assertTrue(transport.abort.called)
+
     def test_eof_received_waiter(self):
         waiter = asyncio.Future(loop=self.loop)
         ssl_proto = self.ssl_protocol(waiter)
index e5343899640fe63907b1f64604935c069c554f57..009b072d680aac624b9641298457ddf6fe9ad24d 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -63,6 +63,7 @@ Jeffrey Armstrong
 Jason Asbahr
 David Ascher
 Ammar Askar
+Neil Aspinall
 Chris AtLee
 Aymeric Augustin
 Cathy Avery
diff --git a/Misc/NEWS.d/next/Library/2017-12-13-19-02-38.bpo-29970.uxVOpk.rst b/Misc/NEWS.d/next/Library/2017-12-13-19-02-38.bpo-29970.uxVOpk.rst
new file mode 100644 (file)
index 0000000..d3d9ae9
--- /dev/null
@@ -0,0 +1 @@
+Abort asyncio SSLProtocol connection if handshake not complete within 10s