]> granicus.if.org Git - python/commitdiff
asyncio: Skip getaddrinfo if host is already resolved.
authorYury Selivanov <yselivanov@sprymix.com>
Thu, 17 Dec 2015 00:31:17 +0000 (19:31 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Thu, 17 Dec 2015 00:31:17 +0000 (19:31 -0500)
getaddrinfo takes an exclusive lock on some platforms, causing clients to queue
up waiting for the lock if many names are being resolved concurrently. Users
may want to handle name resolution in their own code, for the sake of caching,
using an alternate resolver, or to measure DNS duration separately from
connection duration. Skip getaddrinfo if the "host" passed into
create_connection is already resolved.

See https://github.com/python/asyncio/pull/302 for details.

Patch by A. Jesse Jiryu Davis.

Lib/asyncio/base_events.py
Lib/asyncio/proactor_events.py
Lib/asyncio/selector_events.py
Lib/asyncio/test_utils.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_proactor_events.py

index c5ffad408078f36cedd6fb09985c06bde0b27d61..4505732f9ac02bdf3919a678b2347ff1ee353cae 100644 (file)
@@ -16,8 +16,10 @@ to modify the meaning of the API call itself.
 
 import collections
 import concurrent.futures
+import functools
 import heapq
 import inspect
+import ipaddress
 import itertools
 import logging
 import os
@@ -70,49 +72,83 @@ def _format_pipe(fd):
         return repr(fd)
 
 
+# Linux's sock.type is a bitmask that can include extra info about socket.
+_SOCKET_TYPE_MASK = 0
+if hasattr(socket, 'SOCK_NONBLOCK'):
+    _SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK
+if hasattr(socket, 'SOCK_CLOEXEC'):
+    _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC
+
+
+@functools.lru_cache(maxsize=1024)
+def _ipaddr_info(host, port, family, type, proto):
+    # Try to skip getaddrinfo if "host" is already an IP. Since getaddrinfo
+    # blocks on an exclusive lock on some platforms, users might handle name
+    # resolution in their own code and pass in resolved IPs.
+    if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or host is None:
+        return None
+
+    type &= ~_SOCKET_TYPE_MASK
+    if type == socket.SOCK_STREAM:
+        proto = socket.IPPROTO_TCP
+    elif type == socket.SOCK_DGRAM:
+        proto = socket.IPPROTO_UDP
+    else:
+        return None
+
+    if hasattr(socket, 'inet_pton'):
+        if family == socket.AF_UNSPEC:
+            afs = [socket.AF_INET, socket.AF_INET6]
+        else:
+            afs = [family]
+
+        for af in afs:
+            # Linux's inet_pton doesn't accept an IPv6 zone index after host,
+            # like '::1%lo0', so strip it. If we happen to make an invalid
+            # address look valid, we fail later in sock.connect or sock.bind.
+            try:
+                if af == socket.AF_INET6:
+                    socket.inet_pton(af, host.partition('%')[0])
+                else:
+                    socket.inet_pton(af, host)
+                return af, type, proto, '', (host, port)
+            except OSError:
+                pass
+
+        # "host" is not an IP address.
+        return None
+
+    # No inet_pton. (On Windows it's only available since Python 3.4.)
+    # Even though getaddrinfo with AI_NUMERICHOST would be non-blocking, it
+    # still requires a lock on some platforms, and waiting for that lock could
+    # block the event loop. Use ipaddress instead, it's just text parsing.
+    try:
+        addr = ipaddress.IPv4Address(host)
+    except ValueError:
+        try:
+            addr = ipaddress.IPv6Address(host.partition('%')[0])
+        except ValueError:
+            return None
+
+    af = socket.AF_INET if addr.version == 4 else socket.AF_INET6
+    if family not in (socket.AF_UNSPEC, af):
+        # "host" is wrong IP version for "family".
+        return None
+
+    return af, type, proto, '', (host, port)
+
+
 def _check_resolved_address(sock, address):
     # Ensure that the address is already resolved to avoid the trap of hanging
     # the entire event loop when the address requires doing a DNS lookup.
-    #
-    # getaddrinfo() is slow (around 10 us per call): this function should only
-    # be called in debug mode
-    family = sock.family
-
-    if family == socket.AF_INET:
-        host, port = address
-    elif family == socket.AF_INET6:
-        host, port = address[:2]
-    else:
+
+    if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
         return
 
-    # On Windows, socket.inet_pton() is only available since Python 3.4
-    if hasattr(socket, 'inet_pton'):
-        # getaddrinfo() is slow and has known issue: prefer inet_pton()
-        # if available
-        try:
-            socket.inet_pton(family, host)
-        except OSError as exc:
-            raise ValueError("address must be resolved (IP address), "
-                             "got host %r: %s"
-                             % (host, exc))
-    else:
-        # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is
-        # already resolved.
-        type_mask = 0
-        if hasattr(socket, 'SOCK_NONBLOCK'):
-            type_mask |= socket.SOCK_NONBLOCK
-        if hasattr(socket, 'SOCK_CLOEXEC'):
-            type_mask |= socket.SOCK_CLOEXEC
-        try:
-            socket.getaddrinfo(host, port,
-                               family=family,
-                               type=(sock.type & ~type_mask),
-                               proto=sock.proto,
-                               flags=socket.AI_NUMERICHOST)
-        except socket.gaierror as err:
-            raise ValueError("address must be resolved (IP address), "
-                             "got host %r: %s"
-                             % (host, err))
+    host, port = address[:2]
+    if _ipaddr_info(host, port, sock.family, sock.type, sock.proto) is None:
+        raise ValueError("address must be resolved (IP address),"
+                         " got host %r" % host)
 
 
 def _run_until_complete_cb(fut):
@@ -535,7 +571,12 @@ class BaseEventLoop(events.AbstractEventLoop):
 
     def getaddrinfo(self, host, port, *,
                     family=0, type=0, proto=0, flags=0):
-        if self._debug:
+        info = _ipaddr_info(host, port, family, type, proto)
+        if info is not None:
+            fut = futures.Future(loop=self)
+            fut.set_result([info])
+            return fut
+        elif self._debug:
             return self.run_in_executor(None, self._getaddrinfo_debug,
                                         host, port, family, type, proto, flags)
         else:
index 7eac41eec028129d90668f3ebf845cb5ea8bf179..14c0659ddee466a0b1bcee4e09335eb39d61e20e 100644 (file)
@@ -441,8 +441,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop):
 
     def sock_connect(self, sock, address):
         try:
-            if self._debug:
-                base_events._check_resolved_address(sock, address)
+            base_events._check_resolved_address(sock, address)
         except ValueError as err:
             fut = futures.Future(loop=self)
             fut.set_exception(err)
index a05f81cd9de6557755bb55bae664a8baac69df87..5b26631d80d33574268f73e9f00239402d3b6a23 100644 (file)
@@ -397,8 +397,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
             raise ValueError("the socket must be non-blocking")
         fut = futures.Future(loop=self)
         try:
-            if self._debug:
-                base_events._check_resolved_address(sock, address)
+            base_events._check_resolved_address(sock, address)
         except ValueError as err:
             fut.set_exception(err)
         else:
index 8170533188c370436646530bf825e8a99a07cc89..396e6aed5673c4c6323ae7f8e091c0ff9ab0e3a9 100644 (file)
@@ -446,9 +446,14 @@ def disable_logger():
     finally:
         logger.setLevel(old_level)
 
-def mock_nonblocking_socket():
+
+def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
+                            family=socket.AF_INET):
     """Create a mock of a non-blocking socket."""
-    sock = mock.Mock(socket.socket)
+    sock = mock.MagicMock(socket.socket)
+    sock.proto = proto
+    sock.type = type
+    sock.family = family
     sock.gettimeout.return_value = 0.0
     return sock
 
index 072fe2f93e7f11d240f3368f5de5a27e798dc00c..608398cc67d7b8a6e04b5cdfc5b4ef96ca5edc05 100644 (file)
@@ -32,6 +32,120 @@ MOCK_ANY = mock.ANY
 PY34 = sys.version_info >= (3, 4)
 
 
+def mock_socket_module():
+    m_socket = mock.MagicMock(spec=socket)
+    for name in (
+        'AF_INET', 'AF_INET6', 'AF_UNSPEC', 'IPPROTO_TCP', 'IPPROTO_UDP',
+        'SOCK_STREAM', 'SOCK_DGRAM', 'SOL_SOCKET', 'SO_REUSEADDR', 'inet_pton'
+    ):
+        if hasattr(socket, name):
+            setattr(m_socket, name, getattr(socket, name))
+        else:
+            delattr(m_socket, name)
+
+    m_socket.socket = mock.MagicMock()
+    m_socket.socket.return_value = test_utils.mock_nonblocking_socket()
+
+    return m_socket
+
+
+def patch_socket(f):
+    return mock.patch('asyncio.base_events.socket',
+                      new_callable=mock_socket_module)(f)
+
+
+class BaseEventTests(test_utils.TestCase):
+
+    def setUp(self):
+        super().setUp()
+        base_events._ipaddr_info.cache_clear()
+
+    def tearDown(self):
+        base_events._ipaddr_info.cache_clear()
+        super().tearDown()
+
+    def test_ipaddr_info(self):
+        UNSPEC = socket.AF_UNSPEC
+        INET = socket.AF_INET
+        INET6 = socket.AF_INET6
+        STREAM = socket.SOCK_STREAM
+        DGRAM = socket.SOCK_DGRAM
+        TCP = socket.IPPROTO_TCP
+        UDP = socket.IPPROTO_UDP
+
+        self.assertEqual(
+            (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
+            base_events._ipaddr_info('1.2.3.4', 1, INET, STREAM, TCP))
+
+        self.assertEqual(
+            (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
+            base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP))
+
+        self.assertEqual(
+            (INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
+            base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, UDP))
+
+        # Socket type STREAM implies TCP protocol.
+        self.assertEqual(
+            (INET, STREAM, TCP, '', ('1.2.3.4', 1)),
+            base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, 0))
+
+        # Socket type DGRAM implies UDP protocol.
+        self.assertEqual(
+            (INET, DGRAM, UDP, '', ('1.2.3.4', 1)),
+            base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, DGRAM, 0))
+
+        # No socket type.
+        self.assertIsNone(
+            base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, 0, 0))
+
+        # IPv4 address with family IPv6.
+        self.assertIsNone(
+            base_events._ipaddr_info('1.2.3.4', 1, INET6, STREAM, TCP))
+
+        self.assertEqual(
+            (INET6, STREAM, TCP, '', ('::3', 1)),
+            base_events._ipaddr_info('::3', 1, INET6, STREAM, TCP))
+
+        self.assertEqual(
+            (INET6, STREAM, TCP, '', ('::3', 1)),
+            base_events._ipaddr_info('::3', 1, UNSPEC, STREAM, TCP))
+
+        # IPv6 address with family IPv4.
+        self.assertIsNone(
+            base_events._ipaddr_info('::3', 1, INET, STREAM, TCP))
+
+        # IPv6 address with zone index.
+        self.assertEqual(
+            (INET6, STREAM, TCP, '', ('::3%lo0', 1)),
+            base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
+
+    @patch_socket
+    def test_ipaddr_info_no_inet_pton(self, m_socket):
+        del m_socket.inet_pton
+        self.test_ipaddr_info()
+
+    def test_check_resolved_address(self):
+        sock = socket.socket(socket.AF_INET)
+        base_events._check_resolved_address(sock, ('1.2.3.4', 1))
+
+        sock = socket.socket(socket.AF_INET6)
+        base_events._check_resolved_address(sock, ('::3', 1))
+        base_events._check_resolved_address(sock, ('::3%lo0', 1))
+        self.assertRaises(ValueError,
+                          base_events._check_resolved_address, sock, ('foo', 1))
+
+    def test_check_resolved_sock_type(self):
+        # Ensure we ignore extra flags in sock.type.
+        if hasattr(socket, 'SOCK_NONBLOCK'):
+            sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
+            base_events._check_resolved_address(sock, ('1.2.3.4', 1))
+
+        if hasattr(socket, 'SOCK_CLOEXEC'):
+            sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_CLOEXEC)
+            base_events._check_resolved_address(sock, ('1.2.3.4', 1))
+
+
 class BaseEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
@@ -875,7 +989,12 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.loop = asyncio.new_event_loop()
         self.set_event_loop(self.loop)
 
-    @mock.patch('asyncio.base_events.socket')
+    def tearDown(self):
+        # Clear mocked constants like AF_INET from the cache.
+        base_events._ipaddr_info.cache_clear()
+        super().tearDown()
+
+    @patch_socket
     def test_create_connection_multiple_errors(self, m_socket):
 
         class MyProto(asyncio.Protocol):
@@ -908,7 +1027,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
 
         self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_connection_timeout(self, m_socket):
         # Ensure that the socket is closed on timeout
         sock = mock.Mock()
@@ -986,7 +1105,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         with self.assertRaises(OSError):
             self.loop.run_until_complete(coro)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_connection_multiple_errors_local_addr(self, m_socket):
 
         def bind(addr):
@@ -1018,6 +1137,46 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
         self.assertTrue(m_socket.socket.return_value.close.called)
 
+    def _test_create_connection_ip_addr(self, m_socket, allow_inet_pton):
+        # Test the fallback code, even if this system has inet_pton.
+        if not allow_inet_pton:
+            del m_socket.inet_pton
+
+        def getaddrinfo(*args, **kw):
+            self.fail('should not have called getaddrinfo')
+
+        m_socket.getaddrinfo = getaddrinfo
+        sock = m_socket.socket.return_value
+
+        self.loop.add_reader = mock.Mock()
+        self.loop.add_reader._is_coroutine = False
+        self.loop.add_writer = mock.Mock()
+        self.loop.add_writer._is_coroutine = False
+
+        coro = self.loop.create_connection(MyProto, '1.2.3.4', 80)
+        self.loop.run_until_complete(coro)
+        sock.connect.assert_called_with(('1.2.3.4', 80))
+        m_socket.socket.assert_called_with(family=m_socket.AF_INET,
+                                           proto=m_socket.IPPROTO_TCP,
+                                           type=m_socket.SOCK_STREAM)
+
+        sock.family = socket.AF_INET6
+        coro = self.loop.create_connection(MyProto, '::2', 80)
+
+        self.loop.run_until_complete(coro)
+        sock.connect.assert_called_with(('::2', 80))
+        m_socket.socket.assert_called_with(family=m_socket.AF_INET6,
+                                           proto=m_socket.IPPROTO_TCP,
+                                           type=m_socket.SOCK_STREAM)
+
+    @patch_socket
+    def test_create_connection_ip_addr(self, m_socket):
+        self._test_create_connection_ip_addr(m_socket, True)
+
+    @patch_socket
+    def test_create_connection_no_inet_pton(self, m_socket):
+        self._test_create_connection_ip_addr(m_socket, False)
+
     def test_create_connection_no_local_addr(self):
         @asyncio.coroutine
         def getaddrinfo(host, *args, **kw):
@@ -1153,11 +1312,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         f = self.loop.create_server(MyProto, '0.0.0.0', 0)
         self.assertRaises(OSError, self.loop.run_until_complete, f)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_server_nosoreuseport(self, m_socket):
         m_socket.getaddrinfo = socket.getaddrinfo
-        m_socket.SOCK_STREAM = socket.SOCK_STREAM
-        m_socket.SOL_SOCKET = socket.SOL_SOCKET
         del m_socket.SO_REUSEPORT
         m_socket.socket.return_value = mock.Mock()
 
@@ -1166,7 +1323,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
 
         self.assertRaises(ValueError, self.loop.run_until_complete, f)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_server_cant_bind(self, m_socket):
 
         class Err(OSError):
@@ -1182,7 +1339,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.assertRaises(OSError, self.loop.run_until_complete, fut)
         self.assertTrue(m_sock.close.called)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
         m_socket.getaddrinfo.return_value = []
         m_socket.getaddrinfo._is_coroutine = False
@@ -1211,7 +1368,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.assertRaises(
             OSError, self.loop.run_until_complete, coro)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket 
     def test_create_datagram_endpoint_socket_err(self, m_socket):
         m_socket.getaddrinfo = socket.getaddrinfo
         m_socket.socket.side_effect = OSError
@@ -1234,7 +1391,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.assertRaises(
             ValueError, self.loop.run_until_complete, coro)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_datagram_endpoint_setblk_err(self, m_socket):
         m_socket.socket.return_value.setblocking.side_effect = OSError
 
@@ -1250,12 +1407,11 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
             asyncio.DatagramProtocol)
         self.assertRaises(ValueError, self.loop.run_until_complete, coro)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_datagram_endpoint_cant_bind(self, m_socket):
         class Err(OSError):
             pass
 
-        m_socket.AF_INET6 = socket.AF_INET6
         m_socket.getaddrinfo = socket.getaddrinfo
         m_sock = m_socket.socket.return_value = mock.Mock()
         m_sock.bind.side_effect = Err
@@ -1369,11 +1525,8 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
         self.loop.run_until_complete(protocol.done)
         self.assertEqual('CLOSED', protocol.state)
 
-    @mock.patch('asyncio.base_events.socket')
+    @patch_socket
     def test_create_datagram_endpoint_nosoreuseport(self, m_socket):
-        m_socket.getaddrinfo = socket.getaddrinfo
-        m_socket.SOCK_DGRAM = socket.SOCK_DGRAM
-        m_socket.SOL_SOCKET = socket.SOL_SOCKET
         del m_socket.SO_REUSEPORT
         m_socket.socket.return_value = mock.Mock()
 
@@ -1385,6 +1538,29 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
 
         self.assertRaises(ValueError, self.loop.run_until_complete, coro)
 
+    @patch_socket
+    def test_create_datagram_endpoint_ip_addr(self, m_socket):
+        def getaddrinfo(*args, **kw):
+            self.fail('should not have called getaddrinfo')
+
+        m_socket.getaddrinfo = getaddrinfo
+        m_socket.socket.return_value.bind = bind = mock.Mock()
+        self.loop.add_reader = mock.Mock()
+        self.loop.add_reader._is_coroutine = False
+
+        reuseport_supported = hasattr(socket, 'SO_REUSEPORT')
+        coro = self.loop.create_datagram_endpoint(
+            lambda: MyDatagramProto(loop=self.loop),
+            local_addr=('1.2.3.4', 0),
+            reuse_address=False,
+            reuse_port=reuseport_supported)
+
+        self.loop.run_until_complete(coro)
+        bind.assert_called_with(('1.2.3.4', 0))
+        m_socket.socket.assert_called_with(family=m_socket.AF_INET,
+                                           proto=m_socket.IPPROTO_UDP,
+                                           type=m_socket.SOCK_DGRAM)
+
     def test_accept_connection_retry(self):
         sock = mock.Mock()
         sock.accept.side_effect = BlockingIOError()
index 141fde74e6bd1f910c4bfa20523de507b9807ae6..f1746043e770a612bd9bff6c5100c46a3ce445b0 100644 (file)
@@ -1573,10 +1573,6 @@ class EventLoopTestsMixin:
              'selector': self.loop._selector.__class__.__name__})
 
     def test_sock_connect_address(self):
-        # In debug mode, sock_connect() must ensure that the address is already
-        # resolved (call _check_resolved_address())
-        self.loop.set_debug(True)
-
         addresses = [(socket.AF_INET, ('www.python.org', 80))]
         if support.IPV6_ENABLED:
             addresses.extend((
index 5a0f0881e4c65c2f6c0545982bb49c85aee812f2..5a92b1e34a583ffd6ee06b0cf11b09931ad7e025 100644 (file)
@@ -436,7 +436,7 @@ class ProactorSocketTransportTests(test_utils.TestCase):
 class BaseProactorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
-        self.sock = mock.Mock(socket.socket)
+        self.sock = test_utils.mock_nonblocking_socket()
         self.proactor = mock.Mock()
 
         self.ssock, self.csock = mock.Mock(), mock.Mock()
@@ -491,8 +491,8 @@ class BaseProactorEventLoopTests(test_utils.TestCase):
         self.proactor.send.assert_called_with(self.sock, b'data')
 
     def test_sock_connect(self):
-        self.loop.sock_connect(self.sock, 123)
-        self.proactor.connect.assert_called_with(self.sock, 123)
+        self.loop.sock_connect(self.sock, ('1.2.3.4', 123))
+        self.proactor.connect.assert_called_with(self.sock, ('1.2.3.4', 123))
 
     def test_sock_accept(self):
         self.loop.sock_accept(self.sock)