]> granicus.if.org Git - python/commitdiff
bpo-33530: Implement Happy Eyeballs in asyncio, v2 (GH-7237)
authortwisteroid ambassador <twisteroidambassador@users.noreply.github.com>
Sun, 5 May 2019 11:14:35 +0000 (19:14 +0800)
committerMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Sun, 5 May 2019 11:14:35 +0000 (04:14 -0700)
Added two keyword arguments, `delay` and `interleave`, to
`BaseEventLoop.create_connection`. Happy eyeballs is activated if
`delay` is specified.

We now have documentation for the new arguments. `staggered_race()` is in its own module, but not exported to the main asyncio package.

https://bugs.python.org/issue33530

Doc/library/asyncio-eventloop.rst
Lib/asyncio/base_events.py
Lib/asyncio/events.py
Lib/asyncio/staggered.py [new file with mode: 0644]
Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst [new file with mode: 0644]

index e2b312453921221099c0f863540768e43fd860be..06f673be7902c274423f5371b01cabf3a6be118e 100644 (file)
@@ -397,9 +397,27 @@ Opening network connections
      If given, these should all be integers from the corresponding
      :mod:`socket` module constants.
 
+   * *happy_eyeballs_delay*, if given, enables Happy Eyeballs for this
+     connection. It should
+     be a floating-point number representing the amount of time in seconds
+     to wait for a connection attempt to complete, before starting the next
+     attempt in parallel. This is the "Connection Attempt Delay" as defined
+     in :rfc:`8305`. A sensible default value recommended by the RFC is ``0.25``
+     (250 milliseconds).
+
+   * *interleave* controls address reordering when a host name resolves to
+     multiple IP addresses.
+     If ``0`` or unspecified, no reordering is done, and addresses are
+     tried in the order returned by :meth:`getaddrinfo`. If a positive integer
+     is specified, the addresses are interleaved by address family, and the
+     given integer is interpreted as "First Address Family Count" as defined
+     in :rfc:`8305`. The default is ``0`` if *happy_eyeballs_delay* is not
+     specified, and ``1`` if it is.
+
    * *sock*, if given, should be an existing, already connected
      :class:`socket.socket` object to be used by the transport.
-     If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags*
+     If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags*,
+     *happy_eyeballs_delay*, *interleave*
      and *local_addr* should be specified.
 
    * *local_addr*, if given, is a ``(local_host, local_port)`` tuple used
@@ -410,6 +428,10 @@ Opening network connections
      to wait for the TLS handshake to complete before aborting the connection.
      ``60.0`` seconds if ``None`` (default).
 
+   .. versionadded:: 3.8
+
+      The *happy_eyeballs_delay* and *interleave* parameters.
+
    .. versionadded:: 3.7
 
       The *ssl_handshake_timeout* parameter.
index 9b4b846131de10e5299e231aecdab475cf2a8672..c58906f8b4897f2f3a3984879425657c48f7909d 100644 (file)
@@ -16,6 +16,7 @@ to modify the meaning of the API call itself.
 import collections
 import collections.abc
 import concurrent.futures
+import functools
 import heapq
 import itertools
 import os
@@ -41,6 +42,7 @@ from . import exceptions
 from . import futures
 from . import protocols
 from . import sslproto
+from . import staggered
 from . import tasks
 from . import transports
 from .log import logger
@@ -159,6 +161,28 @@ def _ipaddr_info(host, port, family, type, proto):
     return None
 
 
+def _interleave_addrinfos(addrinfos, first_address_family_count=1):
+    """Interleave list of addrinfo tuples by family."""
+    # Group addresses by family
+    addrinfos_by_family = collections.OrderedDict()
+    for addr in addrinfos:
+        family = addr[0]
+        if family not in addrinfos_by_family:
+            addrinfos_by_family[family] = []
+        addrinfos_by_family[family].append(addr)
+    addrinfos_lists = list(addrinfos_by_family.values())
+
+    reordered = []
+    if first_address_family_count > 1:
+        reordered.extend(addrinfos_lists[0][:first_address_family_count - 1])
+        del addrinfos_lists[0][:first_address_family_count - 1]
+    reordered.extend(
+        a for a in itertools.chain.from_iterable(
+            itertools.zip_longest(*addrinfos_lists)
+        ) if a is not None)
+    return reordered
+
+
 def _run_until_complete_cb(fut):
     if not fut.cancelled():
         exc = fut.exception()
@@ -871,12 +895,49 @@ class BaseEventLoop(events.AbstractEventLoop):
                 "offset must be a non-negative integer (got {!r})".format(
                     offset))
 
+    async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None):
+        """Create, bind and connect one socket."""
+        my_exceptions = []
+        exceptions.append(my_exceptions)
+        family, type_, proto, _, address = addr_info
+        sock = None
+        try:
+            sock = socket.socket(family=family, type=type_, proto=proto)
+            sock.setblocking(False)
+            if local_addr_infos is not None:
+                for _, _, _, _, laddr in local_addr_infos:
+                    try:
+                        sock.bind(laddr)
+                        break
+                    except OSError as exc:
+                        msg = (
+                            f'error while attempting to bind on '
+                            f'address {laddr!r}: '
+                            f'{exc.strerror.lower()}'
+                        )
+                        exc = OSError(exc.errno, msg)
+                        my_exceptions.append(exc)
+                else:  # all bind attempts failed
+                    raise my_exceptions.pop()
+            await self.sock_connect(sock, address)
+            return sock
+        except OSError as exc:
+            my_exceptions.append(exc)
+            if sock is not None:
+                sock.close()
+            raise
+        except:
+            if sock is not None:
+                sock.close()
+            raise
+
     async def create_connection(
             self, protocol_factory, host=None, port=None,
             *, ssl=None, family=0,
             proto=0, flags=0, sock=None,
             local_addr=None, server_hostname=None,
-            ssl_handshake_timeout=None):
+            ssl_handshake_timeout=None,
+            happy_eyeballs_delay=None, interleave=None):
         """Connect to a TCP server.
 
         Create a streaming transport connection to a given Internet host and
@@ -911,6 +972,10 @@ class BaseEventLoop(events.AbstractEventLoop):
             raise ValueError(
                 'ssl_handshake_timeout is only meaningful with ssl')
 
+        if happy_eyeballs_delay is not None and interleave is None:
+            # If using happy eyeballs, default to interleave addresses by family
+            interleave = 1
+
         if host is not None or port is not None:
             if sock is not None:
                 raise ValueError(
@@ -929,43 +994,31 @@ class BaseEventLoop(events.AbstractEventLoop):
                     flags=flags, loop=self)
                 if not laddr_infos:
                     raise OSError('getaddrinfo() returned empty list')
+            else:
+                laddr_infos = None
+
+            if interleave:
+                infos = _interleave_addrinfos(infos, interleave)
 
             exceptions = []
-            for family, type, proto, cname, address in infos:
-                try:
-                    sock = socket.socket(family=family, type=type, proto=proto)
-                    sock.setblocking(False)
-                    if local_addr is not None:
-                        for _, _, _, _, laddr in laddr_infos:
-                            try:
-                                sock.bind(laddr)
-                                break
-                            except OSError as exc:
-                                msg = (
-                                    f'error while attempting to bind on '
-                                    f'address {laddr!r}: '
-                                    f'{exc.strerror.lower()}'
-                                )
-                                exc = OSError(exc.errno, msg)
-                                exceptions.append(exc)
-                        else:
-                            sock.close()
-                            sock = None
-                            continue
-                    if self._debug:
-                        logger.debug("connect %r to %r", sock, address)
-                    await self.sock_connect(sock, address)
-                except OSError as exc:
-                    if sock is not None:
-                        sock.close()
-                    exceptions.append(exc)
-                except:
-                    if sock is not None:
-                        sock.close()
-                    raise
-                else:
-                    break
-            else:
+            if happy_eyeballs_delay is None:
+                # not using happy eyeballs
+                for addrinfo in infos:
+                    try:
+                        sock = await self._connect_sock(
+                            exceptions, addrinfo, laddr_infos)
+                        break
+                    except OSError:
+                        continue
+            else:  # using happy eyeballs
+                sock, _, _ = await staggered.staggered_race(
+                    (functools.partial(self._connect_sock,
+                                       exceptions, addrinfo, laddr_infos)
+                     for addrinfo in infos),
+                    happy_eyeballs_delay, loop=self)
+
+            if sock is None:
+                exceptions = [exc for sub in exceptions for exc in sub]
                 if len(exceptions) == 1:
                     raise exceptions[0]
                 else:
index 163b868afeee36aa873e3539015023b63721003e..9a923514db099399adbf0864be619eed5f856379 100644 (file)
@@ -298,7 +298,8 @@ class AbstractEventLoop:
             *, ssl=None, family=0, proto=0,
             flags=0, sock=None, local_addr=None,
             server_hostname=None,
-            ssl_handshake_timeout=None):
+            ssl_handshake_timeout=None,
+            happy_eyeballs_delay=None, interleave=None):
         raise NotImplementedError
 
     async def create_server(
diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py
new file mode 100644 (file)
index 0000000..feec681
--- /dev/null
@@ -0,0 +1,147 @@
+"""Support for running coroutines in parallel with staggered start times."""
+
+__all__ = 'staggered_race',
+
+import contextlib
+import typing
+
+from . import events
+from . import futures
+from . import locks
+from . import tasks
+
+
+async def staggered_race(
+        coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
+        delay: typing.Optional[float],
+        *,
+        loop: events.AbstractEventLoop = None,
+) -> typing.Tuple[
+    typing.Any,
+    typing.Optional[int],
+    typing.List[typing.Optional[Exception]]
+]:
+    """Run coroutines with staggered start times and take the first to finish.
+
+    This method takes an iterable of coroutine functions. The first one is
+    started immediately. From then on, whenever the immediately preceding one
+    fails (raises an exception), or when *delay* seconds has passed, the next
+    coroutine is started. This continues until one of the coroutines complete
+    successfully, in which case all others are cancelled, or until all
+    coroutines fail.
+
+    The coroutines provided should be well-behaved in the following way:
+
+    * They should only ``return`` if completed successfully.
+
+    * They should always raise an exception if they did not complete
+      successfully. In particular, if they handle cancellation, they should
+      probably reraise, like this::
+
+        try:
+            # do work
+        except asyncio.CancelledError:
+            # undo partially completed work
+            raise
+
+    Args:
+        coro_fns: an iterable of coroutine functions, i.e. callables that
+            return a coroutine object when called. Use ``functools.partial`` or
+            lambdas to pass arguments.
+
+        delay: amount of time, in seconds, between starting coroutines. If
+            ``None``, the coroutines will run sequentially.
+
+        loop: the event loop to use.
+
+    Returns:
+        tuple *(winner_result, winner_index, exceptions)* where
+
+        - *winner_result*: the result of the winning coroutine, or ``None``
+          if no coroutines won.
+
+        - *winner_index*: the index of the winning coroutine in
+          ``coro_fns``, or ``None`` if no coroutines won. If the winning
+          coroutine may return None on success, *winner_index* can be used
+          to definitively determine whether any coroutine won.
+
+        - *exceptions*: list of exceptions returned by the coroutines.
+          ``len(exceptions)`` is equal to the number of coroutines actually
+          started, and the order is the same as in ``coro_fns``. The winning
+          coroutine's entry is ``None``.
+
+    """
+    # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
+    loop = loop or events.get_running_loop()
+    enum_coro_fns = enumerate(coro_fns)
+    winner_result = None
+    winner_index = None
+    exceptions = []
+    running_tasks = []
+
+    async def run_one_coro(
+            previous_failed: typing.Optional[locks.Event]) -> None:
+        # Wait for the previous task to finish, or for delay seconds
+        if previous_failed is not None:
+            with contextlib.suppress(futures.TimeoutError):
+                # Use asyncio.wait_for() instead of asyncio.wait() here, so
+                # that if we get cancelled at this point, Event.wait() is also
+                # cancelled, otherwise there will be a "Task destroyed but it is
+                # pending" later.
+                await tasks.wait_for(previous_failed.wait(), delay)
+        # Get the next coroutine to run
+        try:
+            this_index, coro_fn = next(enum_coro_fns)
+        except StopIteration:
+            return
+        # Start task that will run the next coroutine
+        this_failed = locks.Event()
+        next_task = loop.create_task(run_one_coro(this_failed))
+        running_tasks.append(next_task)
+        assert len(running_tasks) == this_index + 2
+        # Prepare place to put this coroutine's exceptions if not won
+        exceptions.append(None)
+        assert len(exceptions) == this_index + 1
+
+        try:
+            result = await coro_fn()
+        except Exception as e:
+            exceptions[this_index] = e
+            this_failed.set()  # Kickstart the next coroutine
+        else:
+            # Store winner's results
+            nonlocal winner_index, winner_result
+            assert winner_index is None
+            winner_index = this_index
+            winner_result = result
+            # Cancel all other tasks. We take care to not cancel the current
+            # task as well. If we do so, then since there is no `await` after
+            # here and CancelledError are usually thrown at one, we will
+            # encounter a curious corner case where the current task will end
+            # up as done() == True, cancelled() == False, exception() ==
+            # asyncio.CancelledError. This behavior is specified in
+            # https://bugs.python.org/issue30048
+            for i, t in enumerate(running_tasks):
+                if i != this_index:
+                    t.cancel()
+
+    first_task = loop.create_task(run_one_coro(None))
+    running_tasks.append(first_task)
+    try:
+        # Wait for a growing list of tasks to all finish: poor man's version of
+        # curio's TaskGroup or trio's nursery
+        done_count = 0
+        while done_count != len(running_tasks):
+            done, _ = await tasks.wait(running_tasks)
+            done_count = len(done)
+            # If run_one_coro raises an unhandled exception, it's probably a
+            # programming error, and I want to see it.
+            if __debug__:
+                for d in done:
+                    if d.done() and not d.cancelled() and d.exception():
+                        raise d.exception()
+        return winner_result, winner_index, exceptions
+    finally:
+        # Make sure no tasks are left running if we leave this function
+        for t in running_tasks:
+            t.cancel()
diff --git a/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst b/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst
new file mode 100644 (file)
index 0000000..747219b
--- /dev/null
@@ -0,0 +1,3 @@
+Implemented Happy Eyeballs in `asyncio.create_connection()`. Added two new
+arguments, *happy_eyeballs_delay* and *interleave*,
+to specify Happy Eyeballs behavior.