]> granicus.if.org Git - python/commitdiff
Issue #27136: Fix DNS static resolution; don't use it in getaddrinfo
authorYury Selivanov <yury@magic.io>
Wed, 8 Jun 2016 16:33:31 +0000 (12:33 -0400)
committerYury Selivanov <yury@magic.io>
Wed, 8 Jun 2016 16:33:31 +0000 (12:33 -0400)
Patch by A. Jesse Jiryu Davis

Lib/asyncio/base_events.py
Lib/asyncio/proactor_events.py
Lib/asyncio/selector_events.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_selector_events.py

index 2b2c18536d7b0f9cb464c03516c62c23fe9bbf2d..172a463ef807716f993409547f0beb2ab648e5a2 100644 (file)
@@ -16,10 +16,8 @@ to modify the meaning of the API call itself.
 
 import collections
 import concurrent.futures
-import functools
 import heapq
 import inspect
-import ipaddress
 import itertools
 import logging
 import os
@@ -86,12 +84,14 @@ if hasattr(socket, 'SOCK_CLOEXEC'):
     _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
 
 
-@functools.lru_cache(maxsize=1024, typed=True)
 def _ipaddr_info(host, port, family, type, proto):
-    # Try to skip getaddrinfo if "host" is already an IP. Since getaddrinfo
-    # blocks on an exclusive lock on some platforms, users might handle name
-    # resolution in their own code and pass in resolved IPs.
-    if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or host is None:
+    # Try to skip getaddrinfo if "host" is already an IP. Users might have
+    # handled name resolution in their own code and pass in resolved IPs.
+    if not hasattr(socket, 'inet_pton'):
+        return
+
+    if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or \
+            host is None:
         return None
 
     type &= ~_SOCKET_TYPE_MASK
@@ -123,59 +123,42 @@ def _ipaddr_info(host, port, family, type, proto):
                 # Might be a service name like "http".
                 port = socket.getservbyname(port)
 
-    if hasattr(socket, 'inet_pton'):
-        if family == socket.AF_UNSPEC:
-            afs = [socket.AF_INET, socket.AF_INET6]
-        else:
-            afs = [family]
-
-        for af in afs:
-            # Linux's inet_pton doesn't accept an IPv6 zone index after host,
-            # like '::1%lo0', so strip it. If we happen to make an invalid
-            # address look valid, we fail later in sock.connect or sock.bind.
-            try:
-                if af == socket.AF_INET6:
-                    socket.inet_pton(af, host.partition('%')[0])
-                else:
-                    socket.inet_pton(af, host)
-                return af, type, proto, '', (host, port)
-            except OSError:
-                pass
+    if family == socket.AF_UNSPEC:
+        afs = [socket.AF_INET, socket.AF_INET6]
+    else:
+        afs = [family]
 
-        # "host" is not an IP address.
+    if isinstance(host, bytes):
+        host = host.decode('idna')
+    if '%' in host:
+        # Linux's inet_pton doesn't accept an IPv6 zone index after host,
+        # like '::1%lo0'.
         return None
 
-    # No inet_pton. (On Windows it's only available since Python 3.4.)
-    # Even though getaddrinfo with AI_NUMERICHOST would be non-blocking, it
-    # still requires a lock on some platforms, and waiting for that lock could
-    # block the event loop. Use ipaddress instead, it's just text parsing.
-    try:
-        addr = ipaddress.IPv4Address(host)
-    except ValueError:
+    for af in afs:
         try:
-            addr = ipaddress.IPv6Address(host.partition('%')[0])
-        except ValueError:
-            return None
+            socket.inet_pton(af, host)
+            # The host has already been resolved.
+            return af, type, proto, '', (host, port)
+        except OSError:
+            pass
 
-    af = socket.AF_INET if addr.version == 4 else socket.AF_INET6
-    if family not in (socket.AF_UNSPEC, af):
-        # "host" is wrong IP version for "family".
-        return None
-
-    return af, type, proto, '', (host, port)
+    # "host" is not an IP address.
+    return None
 
 
-def _check_resolved_address(sock, address):
-    # Ensure that the address is already resolved to avoid the trap of hanging
-    # the entire event loop when the address requires doing a DNS lookup.
-
-    if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
-        return
-
+def _ensure_resolved(address, *, family=0, type=socket.SOCK_STREAM, proto=0,
+                     flags=0, loop):
     host, port = address[:2]
-    if _ipaddr_info(host, port, sock.family, sock.type, sock.proto) is None:
-        raise ValueError("address must be resolved (IP address),"
-                         " got host %r" % host)
+    info = _ipaddr_info(host, port, family, type, proto)
+    if info is not None:
+        # "host" is already a resolved IP.
+        fut = loop.create_future()
+        fut.set_result([info])
+        return fut
+    else:
+        return loop.getaddrinfo(host, port, family=family, type=type,
+                                proto=proto, flags=flags)
 
 
 def _run_until_complete_cb(fut):
@@ -602,12 +585,7 @@ class BaseEventLoop(events.AbstractEventLoop):
 
     def getaddrinfo(self, host, port, *,
                     family=0, type=0, proto=0, flags=0):
-        info = _ipaddr_info(host, port, family, type, proto)
-        if info is not None:
-            fut = self.create_future()
-            fut.set_result([info])
-            return fut
-        elif self._debug:
+        if self._debug:
             return self.run_in_executor(None, self._getaddrinfo_debug,
                                         host, port, family, type, proto, flags)
         else:
@@ -656,14 +634,14 @@ class BaseEventLoop(events.AbstractEventLoop):
                 raise ValueError(
                     'host/port and sock can not be specified at the same time')
 
-            f1 = self.getaddrinfo(
-                host, port, family=family,
-                type=socket.SOCK_STREAM, proto=proto, flags=flags)
+            f1 = _ensure_resolved((host, port), family=family,
+                                  type=socket.SOCK_STREAM, proto=proto,
+                                  flags=flags, loop=self)
             fs = [f1]
             if local_addr is not None:
-                f2 = self.getaddrinfo(
-                    *local_addr, family=family,
-                    type=socket.SOCK_STREAM, proto=proto, flags=flags)
+                f2 = _ensure_resolved(local_addr, family=family,
+                                      type=socket.SOCK_STREAM, proto=proto,
+                                      flags=flags, loop=self)
                 fs.append(f2)
             else:
                 f2 = None
@@ -798,9 +776,9 @@ class BaseEventLoop(events.AbstractEventLoop):
                         assert isinstance(addr, tuple) and len(addr) == 2, (
                             '2-tuple is expected')
 
-                        infos = yield from self.getaddrinfo(
-                            *addr, family=family, type=socket.SOCK_DGRAM,
-                            proto=proto, flags=flags)
+                        infos = yield from _ensure_resolved(
+                            addr, family=family, type=socket.SOCK_DGRAM,
+                            proto=proto, flags=flags, loop=self)
                         if not infos:
                             raise OSError('getaddrinfo() returned empty list')
 
@@ -888,9 +866,9 @@ class BaseEventLoop(events.AbstractEventLoop):
 
     @coroutine
     def _create_server_getaddrinfo(self, host, port, family, flags):
-        infos = yield from self.getaddrinfo(host, port, family=family,
+        infos = yield from _ensure_resolved((host, port), family=family,
                                             type=socket.SOCK_STREAM,
-                                            flags=flags)
+                                            flags=flags, loop=self)
         if not infos:
             raise OSError('getaddrinfo({!r}) returned empty list'.format(host))
         return infos
index eb92458adae511b221bbc6b842b0284e85c934c4..3ac314c0cc667d4b8ed82141e1fbf5a482173f73 100644 (file)
@@ -440,14 +440,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
         return self._proactor.send(sock, data)
 
     def sock_connect(self, sock, address):
-        try:
-            base_events._check_resolved_address(sock, address)
-        except ValueError as err:
-            fut = self.create_future()
-            fut.set_exception(err)
-            return fut
-        else:
-            return self._proactor.connect(sock, address)
+        return self._proactor.connect(sock, address)
 
     def sock_accept(self, sock):
         return self._proactor.accept(sock)
index b34fee34df8f13bce2f97d85ca253cc3f0301041..fb7ab2108efe8f0bced7a42e6cf1f8925a4d0b48 100644 (file)
@@ -385,24 +385,28 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
     def sock_connect(self, sock, address):
         """Connect to a remote socket at address.
 
-        The address must be already resolved to avoid the trap of hanging the
-        entire event loop when the address requires doing a DNS lookup. For
-        example, it must be an IP address, not a hostname, for AF_INET and
-        AF_INET6 address families. Use getaddrinfo() to resolve the hostname
-        asynchronously.
-
         This method is a coroutine.
         """
         if self._debug and sock.gettimeout() != 0:
             raise ValueError("the socket must be non-blocking")
+
         fut = self.create_future()
+        if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
+            self._sock_connect(fut, sock, address)
+        else:
+            resolved = base_events._ensure_resolved(address, loop=self)
+            resolved.add_done_callback(
+                lambda resolved: self._on_resolved(fut, sock, resolved))
+
+        return fut
+
+    def _on_resolved(self, fut, sock, resolved):
         try:
-            base_events._check_resolved_address(sock, address)
-        except ValueError as err:
-            fut.set_exception(err)
+            _, _, _, _, address = resolved.result()[0]
+        except Exception as exc:
+            fut.set_exception(exc)
         else:
             self._sock_connect(fut, sock, address)
-        return fut
 
     def _sock_connect(self, fut, sock, address):
         fd = sock.fileno()
index e800ec4340011e96c4f11f28b45a97fe95780c17..0807dfbf4c04bd84155efce2ca1bec73a4db74d9 100644 (file)
@@ -45,6 +45,7 @@ def mock_socket_module():
 
     m_socket.socket = mock.MagicMock()
     m_socket.socket.return_value = test_utils.mock_nonblocking_socket()
+    m_socket.getaddrinfo._is_coroutine = False
 
     return m_socket
 
@@ -56,14 +57,6 @@ def patch_socket(f):
 
 class BaseEventTests(test_utils.TestCase):
 
-    def setUp(self):
-        super().setUp()
-        base_events._ipaddr_info.cache_clear()
-
-    def tearDown(self):
-        base_events._ipaddr_info.cache_clear()
-        super().tearDown()
-
     def test_ipaddr_info(self):
         UNSPEC = socket.AF_UNSPEC
         INET = socket.AF_INET
@@ -77,6 +70,10 @@ class BaseEventTests(test_utils.TestCase):
             (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
             base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP))
 
+        self.assertEqual(
+            (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
+            base_events._ipaddr_info(b'1.2.3.4', 1, INET, STREAM, TCP))
+
         self.assertEqual(
             (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
             base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP))
@@ -116,8 +113,7 @@ class BaseEventTests(test_utils.TestCase):
             base_events._ipaddr_info('::3', 1, INET, STREAM, TCP))
 
         # IPv6 address with zone index.
-        self.assertEqual(
-            (INET6, STREAM, TCP, '', ('::3%lo0', 1)),
+        self.assertIsNone(
             base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
 
     def test_port_parameter_types(self):
@@ -169,31 +165,10 @@ class BaseEventTests(test_utils.TestCase):
     @patch_socket
     def test_ipaddr_info_no_inet_pton(self, m_socket):
         del m_socket.inet_pton
-        self.test_ipaddr_info()
-
-    def test_check_resolved_address(self):
-        sock = socket.socket(socket.AF_INET)
-        with sock:
-            base_events._check_resolved_address(sock, ('1.2.3.4', 1))
-
-        sock = socket.socket(socket.AF_INET6)
-        with sock:
-            base_events._check_resolved_address(sock, ('::3', 1))
-            base_events._check_resolved_address(sock, ('::3%lo0', 1))
-            with self.assertRaises(ValueError):
-                base_events._check_resolved_address(sock, ('foo', 1))
-
-    def test_check_resolved_sock_type(self):
-        # Ensure we ignore extra flags in sock.type.
-        if hasattr(socket, 'SOCK_NONBLOCK'):
-            sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
-            with sock:
-                base_events._check_resolved_address(sock, ('1.2.3.4', 1))
-
-        if hasattr(socket, 'SOCK_CLOEXEC'):
-            sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
-            with sock:
-                base_events._check_resolved_address(sock, ('1.2.3.4', 1))
+        self.assertIsNone(base_events._ipaddr_info('1.2.3.4', 1,
+                                                   socket.AF_INET,
+                                                   socket.SOCK_STREAM,
+                                                   socket.IPPROTO_TCP))
 
 
 class BaseEventLoopTests(test_utils.TestCase):
@@ -1042,11 +1017,6 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.loop = asyncio.new_event_loop()
         self.set_event_loop(self.loop)
 
-    def tearDown(self):
-        # Clear mocked constants like AF_INET from the cache.
-        base_events._ipaddr_info.cache_clear()
-        super().tearDown()
-
     @patch_socket
     def test_create_connection_multiple_errors(self, m_socket):
 
@@ -1195,10 +1165,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         if not allow_inet_pton:
             del m_socket.inet_pton
 
-        def getaddrinfo(*args, **kw):
-            self.fail('should not have called getaddrinfo')
-
-        m_socket.getaddrinfo = getaddrinfo
+        m_socket.getaddrinfo = socket.getaddrinfo
         sock = m_socket.socket.return_value
 
         self.loop.add_reader = mock.Mock()
@@ -1210,9 +1177,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         t, p = self.loop.run_until_complete(coro)
         try:
             sock.connect.assert_called_with(('1.2.3.4', 80))
-            m_socket.socket.assert_called_with(family=m_socket.AF_INET,
-                                               proto=m_socket.IPPROTO_TCP,
-                                               type=m_socket.SOCK_STREAM)
+            _, kwargs = m_socket.socket.call_args
+            self.assertEqual(kwargs['family'], m_socket.AF_INET)
+            self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM)
         finally:
             t.close()
             test_utils.run_briefly(self.loop)  # allow transport to close
@@ -1221,10 +1188,15 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         coro = self.loop.create_connection(asyncio.Protocol, '::2', 80)
         t, p = self.loop.run_until_complete(coro)
         try:
-            sock.connect.assert_called_with(('::2', 80))
-            m_socket.socket.assert_called_with(family=m_socket.AF_INET6,
-                                               proto=m_socket.IPPROTO_TCP,
-                                               type=m_socket.SOCK_STREAM)
+            # Without inet_pton we use getaddrinfo, which transforms ('::2', 80)
+            # to ('::0.0.0.2', 80, 0, 0). The last 0s are flow info, scope id.
+            [address] = sock.connect.call_args[0]
+            host, port = address[:2]
+            self.assertRegex(host, r'::(0\.)*2')
+            self.assertEqual(port, 80)
+            _, kwargs = m_socket.socket.call_args
+            self.assertEqual(kwargs['family'], m_socket.AF_INET6)
+            self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM)
         finally:
             t.close()
             test_utils.run_briefly(self.loop)  # allow transport to close
@@ -1256,6 +1228,21 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.assertRaises(
             OSError, self.loop.run_until_complete, coro)
 
+    @patch_socket
+    def test_create_connection_bluetooth(self, m_socket):
+        # See http://bugs.python.org/issue27136, fallback to getaddrinfo when
+        # we can't recognize an address is resolved, e.g. a Bluetooth address.
+        addr = ('00:01:02:03:04:05', 1)
+
+        def getaddrinfo(host, port, *args, **kw):
+            assert (host, port) == addr
+            return [(999, 1, 999, '', (addr, 1))]
+
+        m_socket.getaddrinfo = getaddrinfo
+        sock = m_socket.socket()
+        coro = self.loop.sock_connect(sock, addr)
+        self.loop.run_until_complete(coro)
+
     def test_create_connection_ssl_server_hostname_default(self):
         self.loop.getaddrinfo = mock.Mock()
 
@@ -1369,7 +1356,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         getaddrinfo = self.loop.getaddrinfo = mock.Mock()
         getaddrinfo.return_value = []
 
-        f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+        f = self.loop.create_server(MyProto, 'python.org', 0)
         self.assertRaises(OSError, self.loop.run_until_complete, f)
 
     @patch_socket
index d52213ceb2696592a52433f6836cecefdcacc75e..d0777758a7d99fdb674d5e7ed5bb31d452f21ca4 100644 (file)
@@ -1610,25 +1610,6 @@ class EventLoopTestsMixin:
             {'clock_resolution': self.loop._clock_resolution,
              'selector': self.loop._selector.__class__.__name__})
 
-    def test_sock_connect_address(self):
-        addresses = [(socket.AF_INET, ('www.python.org', 80))]
-        if support.IPV6_ENABLED:
-            addresses.extend((
-                (socket.AF_INET6, ('www.python.org', 80)),
-                (socket.AF_INET6, ('www.python.org', 80, 0, 0)),
-            ))
-
-        for family, address in addresses:
-            for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM):
-                sock = socket.socket(family, sock_type)
-                with sock:
-                    sock.setblocking(False)
-                    connect = self.loop.sock_connect(sock, address)
-                    with self.assertRaises(ValueError) as cm:
-                        self.loop.run_until_complete(connect)
-                    self.assertIn('address must be resolved',
-                                  str(cm.exception))
-
     def test_remove_fds_after_closing(self):
         loop = self.create_event_loop()
         callback = lambda: None
index 77e72e570512029c8f807dd4f68fe9364c5d3276..8ad55358b16a07bbe96f3586b86619a2df6f8714 100644 (file)
@@ -343,9 +343,11 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
 
         f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
         self.assertIsInstance(f, asyncio.Future)
-        self.assertEqual(
-            (f, sock, ('127.0.0.1', 8080)),
-            self.loop._sock_connect.call_args[0])
+        self.loop._run_once()
+        future_in, sock_in, address_in = self.loop._sock_connect.call_args[0]
+        self.assertEqual(future_in, f)
+        self.assertEqual(sock_in, sock)
+        self.assertEqual(address_in, ('127.0.0.1', 8080))
 
     def test_sock_connect_timeout(self):
         # asyncio issue #205: sock_connect() must unregister the socket on
@@ -359,6 +361,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
 
         # first call to sock_connect() registers the socket
         fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
+        self.loop._run_once()
         self.assertTrue(sock.connect.called)
         self.assertTrue(self.loop.add_writer.called)
         self.assertEqual(len(fut._callbacks), 1)
@@ -376,7 +379,10 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
         sock = mock.Mock()
         sock.fileno.return_value = 10
 
-        self.loop._sock_connect(f, sock, ('127.0.0.1', 8080))
+        resolved = self.loop.create_future()
+        resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM,
+                              socket.IPPROTO_TCP, '', ('127.0.0.1', 8080))])
+        self.loop._sock_connect(f, sock, resolved)
         self.assertTrue(f.done())
         self.assertIsNone(f.result())
         self.assertTrue(sock.connect.called)
@@ -402,9 +408,13 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
         sock.connect.side_effect = BlockingIOError
         sock.getsockopt.return_value = 0
         address = ('127.0.0.1', 8080)
+        resolved = self.loop.create_future()
+        resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM,
+                              socket.IPPROTO_TCP, '', address)])
 
         f = asyncio.Future(loop=self.loop)
-        self.loop._sock_connect(f, sock, address)
+        self.loop._sock_connect(f, sock, resolved)
+        self.loop._run_once()
         self.assertTrue(self.loop.add_writer.called)
         self.assertEqual(10, self.loop.add_writer.call_args[0][0])