]> granicus.if.org Git - python/commitdiff
Issue #28652: Make loop methods reject socket kinds they do not support.
authorYury Selivanov <yury@magic.io>
Wed, 9 Nov 2016 20:47:00 +0000 (15:47 -0500)
committerYury Selivanov <yury@magic.io>
Wed, 9 Nov 2016 20:47:00 +0000 (15:47 -0500)
Lib/asyncio/base_events.py
Lib/asyncio/unix_events.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_unix_events.py
Misc/NEWS

index 6488f23d3c89595af99efdd2167b49011b260d82..aa7836713a6fd77c272fdd6402907762e71a485a 100644 (file)
@@ -84,12 +84,26 @@ def _set_reuseport(sock):
                              'SO_REUSEPORT defined but not implemented.')
 
 
-# Linux's sock.type is a bitmask that can include extra info about socket.
-_SOCKET_TYPE_MASK = 0
-if hasattr(socket, 'SOCK_NONBLOCK'):
-    _SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK
-if hasattr(socket, 'SOCK_CLOEXEC'):
-    _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
+def _is_stream_socket(sock):
+    # Linux's socket.type is a bitmask that can include extra info
+    # about socket, therefore we can't do simple
+    # `sock_type == socket.SOCK_STREAM`.
+    return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM
+
+
+def _is_dgram_socket(sock):
+    # Linux's socket.type is a bitmask that can include extra info
+    # about socket, therefore we can't do simple
+    # `sock_type == socket.SOCK_DGRAM`.
+    return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM
+
+
+def _is_ip_socket(sock):
+    if sock.family == socket.AF_INET:
+        return True
+    if hasattr(socket, 'AF_INET6') and sock.family == socket.AF_INET6:
+        return True
+    return False
 
 
 def _ipaddr_info(host, port, family, type, proto):
@@ -102,8 +116,12 @@ def _ipaddr_info(host, port, family, type, proto):
             host is None:
         return None
 
-    type &= ~_SOCKET_TYPE_MASK
     if type == socket.SOCK_STREAM:
+        # Linux only:
+        #    getaddrinfo() can raise when socket.type is a bit mask.
+        #    So if socket.type is a bit mask of SOCK_STREAM, and say
+        #    SOCK_NONBLOCK, we simply return None, which will trigger
+        #    a call to getaddrinfo() letting it process this request.
         proto = socket.IPPROTO_TCP
     elif type == socket.SOCK_DGRAM:
         proto = socket.IPPROTO_UDP
@@ -124,7 +142,9 @@ def _ipaddr_info(host, port, family, type, proto):
             return None
 
     if family == socket.AF_UNSPEC:
-        afs = [socket.AF_INET, socket.AF_INET6]
+        afs = [socket.AF_INET]
+        if hasattr(socket, 'AF_INET6'):
+            afs.append(socket.AF_INET6)
     else:
         afs = [family]
 
@@ -771,9 +791,13 @@ class BaseEventLoop(events.AbstractEventLoop):
                     raise OSError('Multiple exceptions: {}'.format(
                         ', '.join(str(exc) for exc in exceptions)))
 
-        elif sock is None:
-            raise ValueError(
-                'host and port was not specified and no sock specified')
+        else:
+            if sock is None:
+                raise ValueError(
+                    'host and port was not specified and no sock specified')
+            if not _is_stream_socket(sock) or not _is_ip_socket(sock):
+                raise ValueError(
+                    'A TCP Stream Socket was expected, got {!r}'.format(sock))
 
         transport, protocol = yield from self._create_connection_transport(
             sock, protocol_factory, ssl, server_hostname)
@@ -817,6 +841,9 @@ class BaseEventLoop(events.AbstractEventLoop):
                                  allow_broadcast=None, sock=None):
         """Create datagram connection."""
         if sock is not None:
+            if not _is_dgram_socket(sock):
+                raise ValueError(
+                    'A UDP Socket was expected, got {!r}'.format(sock))
             if (local_addr or remote_addr or
                     family or proto or flags or
                     reuse_address or reuse_port or allow_broadcast):
@@ -1027,6 +1054,9 @@ class BaseEventLoop(events.AbstractEventLoop):
         else:
             if sock is None:
                 raise ValueError('Neither host/port nor sock were specified')
+            if not _is_stream_socket(sock) or not _is_ip_socket(sock):
+                raise ValueError(
+                    'A TCP Stream Socket was expected, got {!r}'.format(sock))
             sockets = [sock]
 
         server = Server(self, sockets)
@@ -1048,6 +1078,10 @@ class BaseEventLoop(events.AbstractEventLoop):
         This method is a coroutine.  When completed, the coroutine
         returns a (transport, protocol) pair.
         """
+        if not _is_stream_socket(sock):
+            raise ValueError(
+                'A Stream Socket was expected, got {!r}'.format(sock))
+
         transport, protocol = yield from self._create_connection_transport(
             sock, protocol_factory, ssl, '', server_side=True)
         if self._debug:
index 65b61db66ac74e53619bbf8b45890ce09017cad8..788a5a09abf52ec8da74b6873b7456c2edeaef80 100644 (file)
@@ -235,7 +235,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
             if sock is None:
                 raise ValueError('no path and sock were specified')
             if (sock.family != socket.AF_UNIX or
-                    sock.type != socket.SOCK_STREAM):
+                    not base_events._is_stream_socket(sock)):
                 raise ValueError(
                     'A UNIX Domain Stream Socket was expected, got {!r}'
                     .format(sock))
@@ -289,7 +289,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
                     'path was not specified, and no sock specified')
 
             if (sock.family != socket.AF_UNIX or
-                    sock.type != socket.SOCK_STREAM):
+                    not base_events._is_stream_socket(sock)):
                 raise ValueError(
                     'A UNIX Domain Stream Socket was expected, got {!r}'
                     .format(sock))
index cdbd58798d67a910b9f145b6c20c1799b66b4929..2a93923f8cd51338dc78382c38bb0044dbbbf09f 100644 (file)
@@ -116,6 +116,13 @@ class BaseEventTests(test_utils.TestCase):
         self.assertIsNone(
             base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
 
+        if hasattr(socket, 'SOCK_NONBLOCK'):
+            self.assertEqual(
+                None,
+                base_events._ipaddr_info(
+                    '1.2.3.4', 1, INET, STREAM | socket.SOCK_NONBLOCK, TCP))
+
+
     def test_port_parameter_types(self):
         # Test obscure kinds of arguments for "port".
         INET = socket.AF_INET
@@ -1040,6 +1047,43 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
             MyProto, 'example.com', 80, sock=object())
         self.assertRaises(ValueError, self.loop.run_until_complete, coro)
 
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
+    def test_create_connection_wrong_sock(self):
+        sock = socket.socket(socket.AF_UNIX)
+        with sock:
+            coro = self.loop.create_connection(MyProto, sock=sock)
+            with self.assertRaisesRegex(ValueError,
+                                        'A TCP Stream Socket was expected'):
+                self.loop.run_until_complete(coro)
+
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'no Unix sockets')
+    def test_create_server_wrong_sock(self):
+        sock = socket.socket(socket.AF_UNIX)
+        with sock:
+            coro = self.loop.create_server(MyProto, sock=sock)
+            with self.assertRaisesRegex(ValueError,
+                                        'A TCP Stream Socket was expected'):
+                self.loop.run_until_complete(coro)
+
+    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
+                         'no socket.SOCK_NONBLOCK (linux only)')
+    def test_create_server_stream_bittype(self):
+        sock = socket.socket(
+            socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
+        with sock:
+            coro = self.loop.create_server(lambda: None, sock=sock)
+            srv = self.loop.run_until_complete(coro)
+            srv.close()
+            self.loop.run_until_complete(srv.wait_closed())
+
+    def test_create_datagram_endpoint_wrong_sock(self):
+        sock = socket.socket(socket.AF_INET)
+        with sock:
+            coro = self.loop.create_datagram_endpoint(MyProto, sock=sock)
+            with self.assertRaisesRegex(ValueError,
+                                        'A UDP Socket was expected'):
+                self.loop.run_until_complete(coro)
+
     def test_create_connection_no_host_port_sock(self):
         coro = self.loop.create_connection(MyProto)
         self.assertRaises(ValueError, self.loop.run_until_complete, coro)
@@ -1487,36 +1531,39 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.assertEqual('CLOSED', protocol.state)
 
     def test_create_datagram_endpoint_sock_sockopts(self):
+        class FakeSock:
+            type = socket.SOCK_DGRAM
+
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object())
+            MyDatagramProto, local_addr=('127.0.0.1', 0), sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object())
+            MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, family=1, sock=object())
+            MyDatagramProto, family=1, sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, proto=1, sock=object())
+            MyDatagramProto, proto=1, sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, flags=1, sock=object())
+            MyDatagramProto, flags=1, sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, reuse_address=True, sock=object())
+            MyDatagramProto, reuse_address=True, sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, reuse_port=True, sock=object())
+            MyDatagramProto, reuse_port=True, sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
         fut = self.loop.create_datagram_endpoint(
-            MyDatagramProto, allow_broadcast=True, sock=object())
+            MyDatagramProto, allow_broadcast=True, sock=FakeSock())
         self.assertRaises(ValueError, self.loop.run_until_complete, fut)
 
     def test_create_datagram_endpoint_sockopts(self):
index 5b32332ff4b8ee67d4646acfae93900aa6fd7a8d..28d92a9f4e3eac21e2aed7993f0e60dbd7ff1e89 100644 (file)
@@ -791,9 +791,9 @@ class EventLoopTestsMixin:
         conn, _ = lsock.accept()
         proto = MyProto(loop=loop)
         proto.loop = loop
-        f = loop.create_task(
+        loop.run_until_complete(
             loop.connect_accepted_socket(
-                (lambda : proto), conn, ssl=server_ssl))
+                (lambda: proto), conn, ssl=server_ssl))
         loop.run_forever()
         proto.transport.close()
         lsock.close()
@@ -1377,6 +1377,11 @@ class EventLoopTestsMixin:
         server.transport.close()
 
     def test_create_datagram_endpoint_sock(self):
+        if (sys.platform == 'win32' and
+                isinstance(self.loop, proactor_events.BaseProactorEventLoop)):
+            raise unittest.SkipTest(
+                'UDP is not supported with proactor event loops')
+
         sock = None
         local_address = ('127.0.0.1', 0)
         infos = self.loop.run_until_complete(
@@ -1394,7 +1399,7 @@ class EventLoopTestsMixin:
         else:
             assert False, 'Can not create socket.'
 
-        f = self.loop.create_connection(
+        f = self.loop.create_datagram_endpoint(
             lambda: MyDatagramProto(loop=self.loop), sock=sock)
         tr, pr = self.loop.run_until_complete(f)
         self.assertIsInstance(tr, asyncio.Transport)
index 83a035edee48e04a4122b7bc57de8c79b963781f..89c6eed602bcc58fcb5f7837aec3d090702c91c6 100644 (file)
@@ -280,6 +280,33 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
                                         'A UNIX Domain Stream.*was expected'):
                 self.loop.run_until_complete(coro)
 
+    def test_create_unix_server_path_dgram(self):
+        sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+        with sock:
+            coro = self.loop.create_unix_server(lambda: None, path=None,
+                                                sock=sock)
+            with self.assertRaisesRegex(ValueError,
+                                        'A UNIX Domain Stream.*was expected'):
+                self.loop.run_until_complete(coro)
+
+    @unittest.skipUnless(hasattr(socket, 'SOCK_NONBLOCK'),
+                         'no socket.SOCK_NONBLOCK (linux only)')
+    def test_create_unix_server_path_stream_bittype(self):
+        sock = socket.socket(
+            socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
+        with tempfile.NamedTemporaryFile() as file:
+            fn = file.name
+        try:
+            with sock:
+                sock.bind(fn)
+                coro = self.loop.create_unix_server(lambda: None, path=None,
+                                                    sock=sock)
+                srv = self.loop.run_until_complete(coro)
+                srv.close()
+                self.loop.run_until_complete(srv.wait_closed())
+        finally:
+            os.unlink(fn)
+
     def test_create_unix_connection_path_inetsock(self):
         sock = socket.socket()
         with sock:
index 10ec9e0b3898f121740fb175503dda7b0941abcb..f9e2f7225b711bd4535db310821b0b4271213b68 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -455,6 +455,8 @@ Library
 - Issue #28639: Fix inspect.isawaitable to always return bool
   Patch by Justin Mayfield.
 
+- Issue #28652: Make loop methods reject socket kinds they do not support.
+
 IDLE
 ----