]> granicus.if.org Git - python/commitdiff
asyncio: Better-looking errors when ssl module cannot be imported. In part by Arnaud...
authorGuido van Rossum <guido@dropbox.com>
Fri, 1 Nov 2013 21:22:30 +0000 (14:22 -0700)
committerGuido van Rossum <guido@dropbox.com>
Fri, 1 Nov 2013 21:22:30 +0000 (14:22 -0700)
Lib/asyncio/base_events.py
Lib/asyncio/selector_events.py
Lib/test/test_asyncio/test_selector_events.py

index a73b3d3977121cdec2263daf3494c3bfc4cab6b0..f2d117bdbd2ac9ea284f2557102cf72cfd98843d 100644 (file)
@@ -466,6 +466,8 @@ class BaseEventLoop(events.AbstractEventLoop):
                       ssl=None,
                       reuse_address=None):
         """XXX"""
+        if isinstance(ssl, bool):
+            raise TypeError('ssl argument must be an SSLContext or None')
         if host is not None or port is not None:
             if sock is not None:
                 raise ValueError(
index c5fc5eb7c55666b27414562f690d1f24aaa7df04..3bad19808da93ac231a10aebc4c4f803385484e5 100644 (file)
@@ -90,12 +90,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
         except (BlockingIOError, InterruptedError):
             pass
 
-    def _start_serving(self, protocol_factory, sock, ssl=None, server=None):
+    def _start_serving(self, protocol_factory, sock,
+                       sslcontext=None, server=None):
         self.add_reader(sock.fileno(), self._accept_connection,
-                        protocol_factory, sock, ssl, server)
+                        protocol_factory, sock, sslcontext, server)
 
-    def _accept_connection(self, protocol_factory, sock, ssl=None,
-                           server=None):
+    def _accept_connection(self, protocol_factory, sock,
+                           sslcontext=None, server=None):
         try:
             conn, addr = sock.accept()
             conn.setblocking(False)
@@ -113,13 +114,13 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
                 self.remove_reader(sock.fileno())
                 self.call_later(constants.ACCEPT_RETRY_DELAY,
                                 self._start_serving,
-                                protocol_factory, sock, ssl, server)
+                                protocol_factory, sock, sslcontext, server)
             else:
                 raise  # The event loop will catch, log and ignore it.
         else:
-            if ssl:
+            if sslcontext:
                 self._make_ssl_transport(
-                    conn, protocol_factory(), ssl, None,
+                    conn, protocol_factory(), sslcontext, None,
                     server_side=True, extra={'peername': addr}, server=server)
             else:
                 self._make_socket_transport(
@@ -558,17 +559,23 @@ class _SelectorSslTransport(_SelectorTransport):
     def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
                  server_side=False, server_hostname=None,
                  extra=None, server=None):
+        if ssl is None:
+            raise RuntimeError('stdlib ssl module not available')
+
         if server_side:
-            assert isinstance(
-                sslcontext, ssl.SSLContext), 'Must pass an SSLContext'
+            if not sslcontext:
+                raise ValueError('Server side ssl needs a valid SSLContext')
         else:
-            # Client-side may pass ssl=True to use a default context.
-            # The default is the same as used by urllib.
-            if sslcontext is None:
+            if not sslcontext:
+                # Client side may pass ssl=True to use a default
+                # context; in that case the sslcontext passed is None.
+                # The default is the same as used by urllib with
+                # cadefault=True.
                 sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
                 sslcontext.options |= ssl.OP_NO_SSLv2
                 sslcontext.set_default_verify_paths()
                 sslcontext.verify_mode = ssl.CERT_REQUIRED
+
         wrap_kwargs = {
             'server_side': server_side,
             'do_handshake_on_connect': False,
index 3b8238d557d6993b345bb1b22631d4dfb5bcaa01..04a7d0c54d0fa78545b06c95ef89351586b9e04c 100644 (file)
@@ -43,6 +43,7 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
         self.assertIsInstance(
             self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
 
+    @unittest.skipIf(ssl is None, 'No ssl module')
     def test_make_ssl_transport(self):
         m = unittest.mock.Mock()
         self.loop.add_reader = unittest.mock.Mock()
@@ -52,6 +53,16 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
         self.assertIsInstance(
             self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
 
+    @unittest.mock.patch('asyncio.selector_events.ssl', None)
+    def test_make_ssl_transport_without_ssl_error(self):
+        m = unittest.mock.Mock()
+        self.loop.add_reader = unittest.mock.Mock()
+        self.loop.add_writer = unittest.mock.Mock()
+        self.loop.remove_reader = unittest.mock.Mock()
+        self.loop.remove_writer = unittest.mock.Mock()
+        with self.assertRaises(RuntimeError):
+            self.loop._make_ssl_transport(m, m, m, m)
+
     def test_close(self):
         ssock = self.loop._ssock
         ssock.fileno.return_value = 7
@@ -1277,6 +1288,15 @@ class SelectorSslTransportTests(unittest.TestCase):
             server_hostname='localhost')
 
 
+class SelectorSslWithoutSslTransportTests(unittest.TestCase):
+
+    @unittest.mock.patch('asyncio.selector_events.ssl', None)
+    def test_ssl_transport_requires_ssl_module(self):
+        Mock = unittest.mock.Mock
+        with self.assertRaises(RuntimeError):
+            transport = _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
+
+
 class SelectorDatagramTransportTests(unittest.TestCase):
 
     def setUp(self):