data = data[n:]
self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
+ @coroutine
def sock_connect(self, sock, address):
"""Connect to a remote socket at address.
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
- fut = self.create_future()
- if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
- self._sock_connect(fut, sock, address)
- else:
+ if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
resolved = base_events._ensure_resolved(
address, family=sock.family, proto=sock.proto, loop=self)
- resolved.add_done_callback(
- lambda resolved: self._on_resolved(fut, sock, resolved))
-
- return fut
-
- def _on_resolved(self, fut, sock, resolved):
- try:
+ if not resolved.done():
+ yield from resolved
_, _, _, _, address = resolved.result()[0]
- except Exception as exc:
- fut.set_exception(exc)
- else:
- self._sock_connect(fut, sock, address)
+
+ fut = self.create_future()
+ self._sock_connect(fut, sock, address)
+ return (yield from fut)
def _sock_connect(self, fut, sock, address):
fd = sock.fileno()
# connection runs in background. We have to wait until the socket
# becomes writable to be notified when the connection succeed or
# fails.
- fut.add_done_callback(functools.partial(self._sock_connect_done,
- fd))
+ fut.add_done_callback(
+ functools.partial(self._sock_connect_done, fd))
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
except Exception as exc:
fut.set_exception(exc)
import errno
import socket
+import threading
+import time
import unittest
from unittest import mock
try:
(10, self.loop._sock_sendall, f, True, sock, b'data'),
self.loop.add_writer.call_args[0])
- def test_sock_connect(self):
- sock = test_utils.mock_nonblocking_socket()
- self.loop._sock_connect = mock.Mock()
-
- f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
- self.assertIsInstance(f, asyncio.Future)
- self.loop._run_once()
- future_in, sock_in, address_in = self.loop._sock_connect.call_args[0]
- self.assertEqual(future_in, f)
- self.assertEqual(sock_in, sock)
- self.assertEqual(address_in, ('127.0.0.1', 8080))
-
def test_sock_connect_timeout(self):
# asyncio issue #205: sock_connect() must unregister the socket on
# timeout error
sock.connect.side_effect = BlockingIOError
# first call to sock_connect() registers the socket
- fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
+ fut = self.loop.create_task(
+ self.loop.sock_connect(sock, ('127.0.0.1', 80)))
self.loop._run_once()
self.assertTrue(sock.connect.called)
self.assertTrue(self.loop.add_writer.called)
- self.assertEqual(len(fut._callbacks), 1)
# on timeout, the socket must be unregistered
sock.connect.reset_mock()
- fut.set_exception(asyncio.TimeoutError)
- with self.assertRaises(asyncio.TimeoutError):
+ fut.cancel()
+ with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(fut)
self.assertTrue(self.loop.remove_writer.called)
- def test_sock_connect_resolve_using_socket_params(self):
+ @mock.patch('socket.getaddrinfo')
+ def test_sock_connect_resolve_using_socket_params(self, m_gai):
addr = ('need-resolution.com', 8080)
sock = test_utils.mock_nonblocking_socket()
- self.loop.getaddrinfo = mock.Mock()
- self.loop.sock_connect(sock, addr)
- while not self.loop.getaddrinfo.called:
+ m_gai.side_effect = (None, None, None, None, ('127.0.0.1', 0))
+ m_gai._is_coroutine = False
+ con = self.loop.create_task(self.loop.sock_connect(sock, addr))
+ while not m_gai.called:
self.loop._run_once()
- self.loop.getaddrinfo.assert_called_with(
- *addr, type=sock.type, family=sock.family, proto=sock.proto,
- flags=0)
+ m_gai.assert_called_with(
+ addr[0], addr[1], sock.family, sock.type, sock.proto, 0)
+
+ con.cancel()
+ with self.assertRaises(asyncio.CancelledError):
+ self.loop.run_until_complete(con)
def test__sock_connect(self):
f = asyncio.Future(loop=self.loop)
exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
+class SelectorLoopFunctionalTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ @asyncio.coroutine
+ def recv_all(self, sock, nbytes):
+ buf = b''
+ while len(buf) < nbytes:
+ buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
+ return buf
+
+ def test_sock_connect_sock_write_race(self):
+ TIMEOUT = 3.0
+ PAYLOAD = b'DATA' * 1024 * 1024
+
+ class Server(threading.Thread):
+ def __init__(self, *args, srv_sock, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.srv_sock = srv_sock
+
+ def run(self):
+ with self.srv_sock:
+ srv_sock.listen(100)
+
+ sock, addr = self.srv_sock.accept()
+ sock.settimeout(TIMEOUT)
+
+ with sock:
+ sock.sendall(b'helo')
+
+ buf = bytearray()
+ while len(buf) < len(PAYLOAD):
+ pack = sock.recv(1024 * 65)
+ if not pack:
+ break
+ buf.extend(pack)
+
+ @asyncio.coroutine
+ def client(addr):
+ sock = socket.socket()
+ with sock:
+ sock.setblocking(False)
+
+ started = time.monotonic()
+ while True:
+ if time.monotonic() - started > TIMEOUT:
+ self.fail('unable to connect to the socket')
+ return
+ try:
+ yield from self.loop.sock_connect(sock, addr)
+ except OSError:
+ yield from asyncio.sleep(0.05, loop=self.loop)
+ else:
+ break
+
+ # Give 'Server' thread a chance to accept and send b'helo'
+ time.sleep(0.1)
+
+ data = yield from self.recv_all(sock, 4)
+ self.assertEqual(data, b'helo')
+ yield from self.loop.sock_sendall(sock, PAYLOAD)
+
+ srv_sock = socket.socket()
+ srv_sock.settimeout(TIMEOUT)
+ srv_sock.bind(('127.0.0.1', 0))
+ srv_addr = srv_sock.getsockname()
+
+ srv = Server(srv_sock=srv_sock, daemon=True)
+ srv.start()
+
+ try:
+ self.loop.run_until_complete(
+ asyncio.wait_for(client(srv_addr), loop=self.loop,
+ timeout=TIMEOUT))
+ finally:
+ srv.join()
+
+
if __name__ == '__main__':
unittest.main()