]> granicus.if.org Git - python/commitdiff
Issue #28176: Fix callbacks race in asyncio.SelectorLoop.sock_connect.
authorYury Selivanov <yury@magic.io>
Thu, 15 Sep 2016 21:56:36 +0000 (17:56 -0400)
committerYury Selivanov <yury@magic.io>
Thu, 15 Sep 2016 21:56:36 +0000 (17:56 -0400)
Lib/asyncio/selector_events.py
Lib/test/test_asyncio/test_selector_events.py
Misc/NEWS

index c18885ebf223c76f9bc01f0e4d555168b5e6aae5..2f02d7676b62a72c2b8df6558d363ffe7ef46fad 100644 (file)
@@ -400,6 +400,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
                 data = data[n:]
             self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
 
+    @coroutine
     def sock_connect(self, sock, address):
         """Connect to a remote socket at address.
 
@@ -408,24 +409,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
         if self._debug and sock.gettimeout() != 0:
             raise ValueError("the socket must be non-blocking")
 
-        fut = self.create_future()
-        if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
-            self._sock_connect(fut, sock, address)
-        else:
+        if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
             resolved = base_events._ensure_resolved(
                 address, family=sock.family, proto=sock.proto, loop=self)
-            resolved.add_done_callback(
-                lambda resolved: self._on_resolved(fut, sock, resolved))
-
-        return fut
-
-    def _on_resolved(self, fut, sock, resolved):
-        try:
+            if not resolved.done():
+                yield from resolved
             _, _, _, _, address = resolved.result()[0]
-        except Exception as exc:
-            fut.set_exception(exc)
-        else:
-            self._sock_connect(fut, sock, address)
+
+        fut = self.create_future()
+        self._sock_connect(fut, sock, address)
+        return (yield from fut)
 
     def _sock_connect(self, fut, sock, address):
         fd = sock.fileno()
@@ -436,8 +429,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
             # connection runs in background. We have to wait until the socket
             # becomes writable to be notified when the connection succeed or
             # fails.
-            fut.add_done_callback(functools.partial(self._sock_connect_done,
-                                                    fd))
+            fut.add_done_callback(
+                functools.partial(self._sock_connect_done, fd))
             self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
         except Exception as exc:
             fut.set_exception(exc)
index 73bc3f3281f46bd311e361fd06accb38d2b04238..0c26a87dcdfc532579a31dc655375e738960579f 100644 (file)
@@ -2,6 +2,8 @@
 
 import errno
 import socket
+import threading
+import time
 import unittest
 from unittest import mock
 try:
@@ -337,18 +339,6 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
             (10, self.loop._sock_sendall, f, True, sock, b'data'),
             self.loop.add_writer.call_args[0])
 
-    def test_sock_connect(self):
-        sock = test_utils.mock_nonblocking_socket()
-        self.loop._sock_connect = mock.Mock()
-
-        f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
-        self.assertIsInstance(f, asyncio.Future)
-        self.loop._run_once()
-        future_in, sock_in, address_in = self.loop._sock_connect.call_args[0]
-        self.assertEqual(future_in, f)
-        self.assertEqual(sock_in, sock)
-        self.assertEqual(address_in, ('127.0.0.1', 8080))
-
     def test_sock_connect_timeout(self):
         # asyncio issue #205: sock_connect() must unregister the socket on
         # timeout error
@@ -360,29 +350,34 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
         sock.connect.side_effect = BlockingIOError
 
         # first call to sock_connect() registers the socket
-        fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
+        fut = self.loop.create_task(
+            self.loop.sock_connect(sock, ('127.0.0.1', 80)))
         self.loop._run_once()
         self.assertTrue(sock.connect.called)
         self.assertTrue(self.loop.add_writer.called)
-        self.assertEqual(len(fut._callbacks), 1)
 
         # on timeout, the socket must be unregistered
         sock.connect.reset_mock()
-        fut.set_exception(asyncio.TimeoutError)
-        with self.assertRaises(asyncio.TimeoutError):
+        fut.cancel()
+        with self.assertRaises(asyncio.CancelledError):
             self.loop.run_until_complete(fut)
         self.assertTrue(self.loop.remove_writer.called)
 
-    def test_sock_connect_resolve_using_socket_params(self):
+    @mock.patch('socket.getaddrinfo')
+    def test_sock_connect_resolve_using_socket_params(self, m_gai):
         addr = ('need-resolution.com', 8080)
         sock = test_utils.mock_nonblocking_socket()
-        self.loop.getaddrinfo = mock.Mock()
-        self.loop.sock_connect(sock, addr)
-        while not self.loop.getaddrinfo.called:
+        m_gai.side_effect = (None, None, None, None, ('127.0.0.1', 0))
+        m_gai._is_coroutine = False
+        con = self.loop.create_task(self.loop.sock_connect(sock, addr))
+        while not m_gai.called:
             self.loop._run_once()
-        self.loop.getaddrinfo.assert_called_with(
-            *addr, type=sock.type, family=sock.family, proto=sock.proto,
-            flags=0)
+        m_gai.assert_called_with(
+            addr[0], addr[1], sock.family, sock.type, sock.proto, 0)
+
+        con.cancel()
+        with self.assertRaises(asyncio.CancelledError):
+            self.loop.run_until_complete(con)
 
     def test__sock_connect(self):
         f = asyncio.Future(loop=self.loop)
@@ -1792,5 +1787,88 @@ class SelectorDatagramTransportTests(test_utils.TestCase):
             exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
 
 
+class SelectorLoopFunctionalTests(unittest.TestCase):
+
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(None)
+
+    def tearDown(self):
+        self.loop.close()
+
+    @asyncio.coroutine
+    def recv_all(self, sock, nbytes):
+        buf = b''
+        while len(buf) < nbytes:
+            buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
+        return buf
+
+    def test_sock_connect_sock_write_race(self):
+        TIMEOUT = 3.0
+        PAYLOAD = b'DATA' * 1024 * 1024
+
+        class Server(threading.Thread):
+            def __init__(self, *args, srv_sock, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.srv_sock = srv_sock
+
+            def run(self):
+                with self.srv_sock:
+                    srv_sock.listen(100)
+
+                    sock, addr = self.srv_sock.accept()
+                    sock.settimeout(TIMEOUT)
+
+                    with sock:
+                        sock.sendall(b'helo')
+
+                        buf = bytearray()
+                        while len(buf) < len(PAYLOAD):
+                            pack = sock.recv(1024 * 65)
+                            if not pack:
+                                break
+                            buf.extend(pack)
+
+        @asyncio.coroutine
+        def client(addr):
+            sock = socket.socket()
+            with sock:
+                sock.setblocking(False)
+
+                started = time.monotonic()
+                while True:
+                    if time.monotonic() - started > TIMEOUT:
+                        self.fail('unable to connect to the socket')
+                        return
+                    try:
+                        yield from self.loop.sock_connect(sock, addr)
+                    except OSError:
+                        yield from asyncio.sleep(0.05, loop=self.loop)
+                    else:
+                        break
+
+                # Give 'Server' thread a chance to accept and send b'helo'
+                time.sleep(0.1)
+
+                data = yield from self.recv_all(sock, 4)
+                self.assertEqual(data, b'helo')
+                yield from self.loop.sock_sendall(sock, PAYLOAD)
+
+        srv_sock = socket.socket()
+        srv_sock.settimeout(TIMEOUT)
+        srv_sock.bind(('127.0.0.1', 0))
+        srv_addr = srv_sock.getsockname()
+
+        srv = Server(srv_sock=srv_sock, daemon=True)
+        srv.start()
+
+        try:
+            self.loop.run_until_complete(
+                asyncio.wait_for(client(srv_addr), loop=self.loop,
+                                 timeout=TIMEOUT))
+        finally:
+            srv.join()
+
+
 if __name__ == '__main__':
     unittest.main()
index be959745c7770a920358ed14ef7469aedd854b17..8f5ee3b8856a259c13e653320f791cbfdfc56263 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -275,6 +275,8 @@ Library
 - Issue #26909: Fix slow pipes IO in asyncio.
   Patch by INADA Naoki.
 
+- Issue #28176: Fix callbacks race in asyncio.SelectorLoop.sock_connect.
+
 IDLE
 ----