sock.setblocking(False)
+ transport, protocol = yield from self._create_connection_transport(
+ sock, protocol_factory, ssl, server_hostname)
+ return transport, protocol
+
+ @tasks.coroutine
+ def _create_connection_transport(self, sock, protocol_factory, ssl,
+ server_hostname):
protocol = protocol_factory()
waiter = futures.Future(loop=self)
if ssl:
"""
raise NotImplementedError
+ def create_unix_connection(self, protocol_factory, path, *,
+ ssl=None, sock=None,
+ server_hostname=None):
+ raise NotImplementedError
+
+ def create_unix_server(self, protocol_factory, path, *,
+ sock=None, backlog=100, ssl=None):
+ """A coroutine which creates a UNIX Domain Socket server.
+
+ The return valud is a Server object, which can be used to stop
+ the service.
+
+ path is a str, representing a file systsem path to bind the
+ server socket to.
+
+ sock can optionally be specified in order to use a preexisting
+ socket object.
+
+ backlog is the maximum number of queued connections passed to
+ listen() (defaults to 100).
+
+ ssl can be set to an SSLContext to enable SSL over the
+ accepted connections.
+ """
+ raise NotImplementedError
+
def create_datagram_endpoint(self, protocol_factory,
local_addr=None, remote_addr=None, *,
family=0, proto=0, flags=0):
"""Stream-related things."""
__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
- 'open_connection', 'start_server', 'IncompleteReadError',
+ 'open_connection', 'start_server',
+ 'open_unix_connection', 'start_unix_server',
+ 'IncompleteReadError',
]
+import socket
+
from . import events
from . import futures
from . import protocols
return (yield from loop.create_server(factory, host, port, **kwds))
+if hasattr(socket, 'AF_UNIX'):
+ # UNIX Domain Sockets are supported on this platform
+
+ @tasks.coroutine
+ def open_unix_connection(path=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Similar to `open_connection` but works with UNIX Domain Sockets."""
+ if loop is None:
+ loop = events.get_event_loop()
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, loop=loop)
+ transport, _ = yield from loop.create_unix_connection(
+ lambda: protocol, path, **kwds)
+ writer = StreamWriter(transport, protocol, reader, loop)
+ return reader, writer
+
+
+ @tasks.coroutine
+ def start_unix_server(client_connected_cb, path=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """Similar to `start_server` but works with UNIX Domain Sockets."""
+ if loop is None:
+ loop = events.get_event_loop()
+
+ def factory():
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader, client_connected_cb,
+ loop=loop)
+ return protocol
+
+ return (yield from loop.create_unix_server(factory, path, **kwds))
+
+
class FlowControlMixin(protocols.Protocol):
"""Reusable flow control logic for StreamWriter.drain().
import contextlib
import io
import os
+import socket
+import socketserver
import sys
+import tempfile
import threading
import time
import unittest
import unittest.mock
+
+from http.server import HTTPServer
from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
+
try:
import ssl
except ImportError: # pragma: no cover
loop.run_forever()
-@contextlib.contextmanager
-def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+class SilentWSGIRequestHandler(WSGIRequestHandler):
- class SilentWSGIRequestHandler(WSGIRequestHandler):
- def get_stderr(self):
- return io.StringIO()
+ def get_stderr(self):
+ return io.StringIO()
- def log_message(self, format, *args):
- pass
+ def log_message(self, format, *args):
+ pass
- class SilentWSGIServer(WSGIServer):
- def handle_error(self, request, client_address):
+
+class SilentWSGIServer(WSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+class SSLWSGIServerMixin:
+
+ def finish_request(self, request, client_address):
+ # The relative location of our test directory (which
+ # contains the ssl key and certificate files) differs
+ # between the stdlib and stand-alone asyncio.
+ # Prefer our own if we can find it.
+ here = os.path.join(os.path.dirname(__file__), '..', 'tests')
+ if not os.path.isdir(here):
+ here = os.path.join(os.path.dirname(os.__file__),
+ 'test', 'test_asyncio')
+ keyfile = os.path.join(here, 'ssl_key.pem')
+ certfile = os.path.join(here, 'ssl_cert.pem')
+ ssock = ssl.wrap_socket(request,
+ keyfile=keyfile,
+ certfile=certfile,
+ server_side=True)
+ try:
+ self.RequestHandlerClass(ssock, client_address, self)
+ ssock.close()
+ except OSError:
+ # maybe socket has been closed by peer
pass
- class SSLWSGIServer(SilentWSGIServer):
- def finish_request(self, request, client_address):
- # The relative location of our test directory (which
- # contains the ssl key and certificate files) differs
- # between the stdlib and stand-alone asyncio.
- # Prefer our own if we can find it.
- here = os.path.join(os.path.dirname(__file__), '..', 'tests')
- if not os.path.isdir(here):
- here = os.path.join(os.path.dirname(os.__file__),
- 'test', 'test_asyncio')
- keyfile = os.path.join(here, 'ssl_key.pem')
- certfile = os.path.join(here, 'ssl_cert.pem')
- ssock = ssl.wrap_socket(request,
- keyfile=keyfile,
- certfile=certfile,
- server_side=True)
- try:
- self.RequestHandlerClass(ssock, client_address, self)
- ssock.close()
- except OSError:
- # maybe socket has been closed by peer
- pass
+
+class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
+ pass
+
+
+def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def app(environ, start_response):
status = '200 OK'
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
- server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
- httpd = make_server(host, port, app,
- server_class, SilentWSGIRequestHandler)
+ server_class = server_ssl_cls if use_ssl else server_cls
+ httpd = server_class(address, SilentWSGIRequestHandler)
+ httpd.set_app(app)
httpd.address = httpd.server_address
server_thread = threading.Thread(target=httpd.serve_forever)
server_thread.start()
server_thread.join()
+if hasattr(socket, 'AF_UNIX'):
+
+ class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
+
+ def server_bind(self):
+ socketserver.UnixStreamServer.server_bind(self)
+ self.server_name = '127.0.0.1'
+ self.server_port = 80
+
+
+ class UnixWSGIServer(UnixHTTPServer, WSGIServer):
+
+ def server_bind(self):
+ UnixHTTPServer.server_bind(self)
+ self.setup_environ()
+
+ def get_request(self):
+ request, client_addr = super().get_request()
+ # Code in the stdlib expects that get_request
+ # will return a socket and a tuple (host, port).
+ # However, this isn't true for UNIX sockets,
+ # as the second return value will be a path;
+ # hence we return some fake data sufficient
+ # to get the tests going
+ return request, ('127.0.0.1', '')
+
+
+ class SilentUnixWSGIServer(UnixWSGIServer):
+
+ def handle_error(self, request, client_address):
+ pass
+
+
+ class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
+ pass
+
+
+ def gen_unix_socket_path():
+ with tempfile.NamedTemporaryFile() as file:
+ return file.name
+
+
+ @contextlib.contextmanager
+ def unix_socket_path():
+ path = gen_unix_socket_path()
+ try:
+ yield path
+ finally:
+ try:
+ os.unlink(path)
+ except OSError:
+ pass
+
+
+ @contextlib.contextmanager
+ def run_test_unix_server(*, use_ssl=False):
+ with unix_socket_path() as path:
+ yield from _run_test_server(address=path, use_ssl=use_ssl,
+ server_cls=SilentUnixWSGIServer,
+ server_ssl_cls=UnixSSLWSGIServer)
+
+
+@contextlib.contextmanager
+def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+ yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
+ server_cls=SilentWSGIServer,
+ server_ssl_cls=SSLWSGIServer)
+
+
def make_test_protocol(base):
dct = {}
for name in dir(base):
def _write_to_self(self):
pass
+
def MockCallback(**kwargs):
return unittest.mock.Mock(spec=['__call__'], **kwargs)
import threading
+from . import base_events
from . import base_subprocess
from . import constants
from . import events
class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
- """Unix event loop
+ """Unix event loop.
- Adds signal handling to SelectorEventLoop
+ Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
"""
def __init__(self, selector=None):
def _child_watcher_callback(self, pid, returncode, transp):
self.call_soon_threadsafe(transp._process_exited, returncode)
+ @tasks.coroutine
+ def create_unix_connection(self, protocol_factory, path, *,
+ ssl=None, sock=None,
+ server_hostname=None):
+ assert server_hostname is None or isinstance(server_hostname, str)
+ if ssl:
+ if server_hostname is None:
+ raise ValueError(
+ 'you have to pass server_hostname when using ssl')
+ else:
+ if server_hostname is not None:
+ raise ValueError('server_hostname is only meaningful with ssl')
+
+ if path is not None:
+ if sock is not None:
+ raise ValueError(
+ 'path and sock can not be specified at the same time')
+
+ try:
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
+ sock.setblocking(False)
+ yield from self.sock_connect(sock, path)
+ except OSError:
+ if sock is not None:
+ sock.close()
+ raise
+
+ else:
+ if sock is None:
+ raise ValueError('no path and sock were specified')
+ sock.setblocking(False)
+
+ transport, protocol = yield from self._create_connection_transport(
+ sock, protocol_factory, ssl, server_hostname)
+ return transport, protocol
+
+ @tasks.coroutine
+ def create_unix_server(self, protocol_factory, path=None, *,
+ sock=None, backlog=100, ssl=None):
+ if isinstance(ssl, bool):
+ raise TypeError('ssl argument must be an SSLContext or None')
+
+ if path is not None:
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+ try:
+ sock.bind(path)
+ except OSError as exc:
+ if exc.errno == errno.EADDRINUSE:
+ # Let's improve the error message by adding
+ # with what exact address it occurs.
+ msg = 'Address {!r} is already in use'.format(path)
+ raise OSError(errno.EADDRINUSE, msg) from None
+ else:
+ raise
+ else:
+ if sock is None:
+ raise ValueError(
+ 'path was not specified, and no sock specified')
+
+ if sock.family != socket.AF_UNIX:
+ raise ValueError(
+ 'A UNIX Domain Socket was expected, got {!r}'.format(sock))
+
+ server = base_events.Server(self, [sock])
+ sock.listen(backlog)
+ sock.setblocking(False)
+ self._start_serving(protocol_factory, sock, ssl, server)
+ return server
+
def _set_nonblocking(fd):
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
idx = -1
data = [10.0, 10.0, 10.3, 13.0]
- self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())]
+ self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())]
self.loop._run_once()
self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])
return fullname
raise FileNotFoundError(filename)
+
ONLYCERT = data_file('ssl_cert.pem')
ONLYKEY = data_file('ssl_key.pem')
SIGNED_CERTFILE = data_file('keycert3.pem')
SIGNING_CA = data_file('pycacert.pem')
-class MyProto(asyncio.Protocol):
+class MyBaseProto(asyncio.Protocol):
done = None
def __init__(self, loop=None):
self.transport = transport
assert self.state == 'INITIAL', self.state
self.state = 'CONNECTED'
- transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
def data_received(self, data):
assert self.state == 'CONNECTED', self.state
self.done.set_result(None)
+class MyProto(MyBaseProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+
class MyDatagramProto(asyncio.DatagramProtocol):
done = None
r.close()
self.assertGreaterEqual(len(data), 200)
+ def _basetest_sock_client_ops(self, httpd, sock):
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, httpd.address))
+ self.loop.run_until_complete(
+ self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ data = self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ # consume data
+ self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ sock.close()
+ self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+
def test_sock_client_ops(self):
with test_utils.run_test_server() as httpd:
sock = socket.socket()
- sock.setblocking(False)
- self.loop.run_until_complete(
- self.loop.sock_connect(sock, httpd.address))
- self.loop.run_until_complete(
- self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
- data = self.loop.run_until_complete(
- self.loop.sock_recv(sock, 1024))
- # consume data
- self.loop.run_until_complete(
- self.loop.sock_recv(sock, 1024))
- sock.close()
+ self._basetest_sock_client_ops(httpd, sock)
- self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_unix_sock_client_ops(self):
+ with test_utils.run_test_unix_server() as httpd:
+ sock = socket.socket(socket.AF_UNIX)
+ self._basetest_sock_client_ops(httpd, sock)
def test_sock_client_fail(self):
# Make sure that we will get an unused port
self.loop.run_forever()
self.assertEqual(caught, 1)
+ def _basetest_create_connection(self, connection_fut):
+ tr, pr = self.loop.run_until_complete(connection_fut)
+ self.assertIsInstance(tr, asyncio.Transport)
+ self.assertIsInstance(pr, asyncio.Protocol)
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
def test_create_connection(self):
with test_utils.run_test_server() as httpd:
- f = self.loop.create_connection(
+ conn_fut = self.loop.create_connection(
lambda: MyProto(loop=self.loop), *httpd.address)
- tr, pr = self.loop.run_until_complete(f)
- self.assertIsInstance(tr, asyncio.Transport)
- self.assertIsInstance(pr, asyncio.Protocol)
- self.loop.run_until_complete(pr.done)
- self.assertGreater(pr.nbytes, 0)
- tr.close()
+ self._basetest_create_connection(conn_fut)
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_connection(self):
+ with test_utils.run_test_unix_server() as httpd:
+ conn_fut = self.loop.create_unix_connection(
+ lambda: MyProto(loop=self.loop), httpd.address)
+ self._basetest_create_connection(conn_fut)
def test_create_connection_sock(self):
with test_utils.run_test_server() as httpd:
self.assertGreater(pr.nbytes, 0)
tr.close()
+ def _basetest_create_ssl_connection(self, connection_fut):
+ tr, pr = self.loop.run_until_complete(connection_fut)
+ self.assertIsInstance(tr, asyncio.Transport)
+ self.assertIsInstance(pr, asyncio.Protocol)
+ self.assertTrue('ssl' in tr.__class__.__name__.lower())
+ self.assertIsNotNone(tr.get_extra_info('sockname'))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_ssl_connection(self):
with test_utils.run_test_server(use_ssl=True) as httpd:
- f = self.loop.create_connection(
- lambda: MyProto(loop=self.loop), *httpd.address,
+ conn_fut = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address,
ssl=test_utils.dummy_ssl_context())
- tr, pr = self.loop.run_until_complete(f)
- self.assertIsInstance(tr, asyncio.Transport)
- self.assertIsInstance(pr, asyncio.Protocol)
- self.assertTrue('ssl' in tr.__class__.__name__.lower())
- self.assertIsNotNone(tr.get_extra_info('sockname'))
- self.loop.run_until_complete(pr.done)
- self.assertGreater(pr.nbytes, 0)
- tr.close()
+
+ self._basetest_create_ssl_connection(conn_fut)
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_ssl_unix_connection(self):
+ with test_utils.run_test_unix_server(use_ssl=True) as httpd:
+ conn_fut = self.loop.create_unix_connection(
+ lambda: MyProto(loop=self.loop),
+ httpd.address,
+ ssl=test_utils.dummy_ssl_context(),
+ server_hostname='127.0.0.1')
+
+ self._basetest_create_ssl_connection(conn_fut)
def test_create_connection_local_addr(self):
with test_utils.run_test_server() as httpd:
self.assertIn(str(httpd.address), cm.exception.strerror)
def test_create_server(self):
- proto = None
-
- def factory():
- nonlocal proto
- proto = MyProto()
- return proto
-
- f = self.loop.create_server(factory, '0.0.0.0', 0)
+ proto = MyProto()
+ f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
server = self.loop.run_until_complete(f)
self.assertEqual(len(server.sockets), 1)
sock = server.sockets[0]
# close server
server.close()
- def _make_ssl_server(self, factory, certfile, keyfile=None):
+ def _make_unix_server(self, factory, **kwargs):
+ path = test_utils.gen_unix_socket_path()
+ self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
+
+ f = self.loop.create_unix_server(factory, path, **kwargs)
+ server = self.loop.run_until_complete(f)
+
+ return server, path
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server(self):
+ proto = MyProto()
+ server, path = self._make_unix_server(lambda: proto)
+ self.assertEqual(len(server.sockets), 1)
+
+ client = socket.socket(socket.AF_UNIX)
+ client.connect(path)
+ client.sendall(b'xxx')
+ test_utils.run_briefly(self.loop)
+ test_utils.run_until(self.loop, lambda: proto is not None, 10)
+
+ self.assertIsInstance(proto, MyProto)
+ self.assertEqual('INITIAL', proto.state)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
+ timeout=10)
+ self.assertEqual(3, proto.nbytes)
+
+ # close connection
+ proto.transport.close()
+ test_utils.run_briefly(self.loop) # windows iocp
+
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # close server
+ server.close()
+
+ def _create_ssl_context(self, certfile, keyfile=None):
sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext.options |= ssl.OP_NO_SSLv2
sslcontext.load_cert_chain(certfile, keyfile)
+ return sslcontext
- f = self.loop.create_server(
- factory, '127.0.0.1', 0, ssl=sslcontext)
+ def _make_ssl_server(self, factory, certfile, keyfile=None):
+ sslcontext = self._create_ssl_context(certfile, keyfile)
+ f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
server = self.loop.run_until_complete(f)
+
sock = server.sockets[0]
host, port = sock.getsockname()
self.assertEqual(host, '127.0.0.1')
return server, host, port
+ def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
+ sslcontext = self._create_ssl_context(certfile, keyfile)
+ return self._make_unix_server(factory, ssl=sslcontext)
+
@unittest.skipIf(ssl is None, 'No ssl module')
def test_create_server_ssl(self):
- proto = None
-
- class ClientMyProto(MyProto):
- def connection_made(self, transport):
- self.transport = transport
- assert self.state == 'INITIAL', self.state
- self.state = 'CONNECTED'
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, ONLYCERT, ONLYKEY)
- def factory():
- nonlocal proto
- proto = MyProto(loop=self.loop)
- return proto
-
- server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY)
-
- f_c = self.loop.create_connection(ClientMyProto, host, port,
+ f_c = self.loop.create_connection(MyBaseProto, host, port,
ssl=test_utils.dummy_ssl_context())
client, pr = self.loop.run_until_complete(f_c)
server.close()
@unittest.skipIf(ssl is None, 'No ssl module')
- @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
- def test_create_server_ssl_verify_failed(self):
- proto = None
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server_ssl(self):
+ proto = MyProto(loop=self.loop)
+ server, path = self._make_ssl_unix_server(
+ lambda: proto, ONLYCERT, ONLYKEY)
- def factory():
- nonlocal proto
- proto = MyProto(loop=self.loop)
- return proto
+ f_c = self.loop.create_unix_connection(
+ MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
+ server_hostname='')
+
+ client, pr = self.loop.run_until_complete(f_c)
+
+ client.write(b'xxx')
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(proto, MyProto)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
+ timeout=10)
+ self.assertEqual(3, proto.nbytes)
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
- server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
+ # stop serving
+ server.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
+ def test_create_server_ssl_verify_failed(self):
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
- def test_create_server_ssl_match_failed(self):
- proto = None
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server_ssl_verify_failed(self):
+ proto = MyProto(loop=self.loop)
+ server, path = self._make_ssl_unix_server(
+ lambda: proto, SIGNED_CERTFILE)
- def factory():
- nonlocal proto
- proto = MyProto(loop=self.loop)
- return proto
+ sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext_client.options |= ssl.OP_NO_SSLv2
+ sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+ if hasattr(sslcontext_client, 'check_hostname'):
+ sslcontext_client.check_hostname = True
- server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
+ # no CA loaded
+ f_c = self.loop.create_unix_connection(MyProto, path,
+ ssl=sslcontext_client,
+ server_hostname='invalid')
+ with self.assertRaisesRegex(ssl.SSLError,
+ 'certificate verify failed '):
+ self.loop.run_until_complete(f_c)
+
+ # close connection
+ self.assertIsNone(proto.transport)
+ server.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
+ def test_create_server_ssl_match_failed(self):
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
- def test_create_server_ssl_verified(self):
- proto = None
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_create_unix_server_ssl_verified(self):
+ proto = MyProto(loop=self.loop)
+ server, path = self._make_ssl_unix_server(
+ lambda: proto, SIGNED_CERTFILE)
- def factory():
- nonlocal proto
- proto = MyProto(loop=self.loop)
- return proto
+ sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext_client.options |= ssl.OP_NO_SSLv2
+ sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+ sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
+ if hasattr(sslcontext_client, 'check_hostname'):
+ sslcontext_client.check_hostname = True
- server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
+ # Connection succeeds with correct CA and server hostname.
+ f_c = self.loop.create_unix_connection(MyProto, path,
+ ssl=sslcontext_client,
+ server_hostname='localhost')
+ client, pr = self.loop.run_until_complete(f_c)
+
+ # close connection
+ proto.transport.close()
+ client.close()
+ server.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
+ def test_create_server_ssl_verified(self):
+ proto = MyProto(loop=self.loop)
+ server, host, port = self._make_ssl_server(
+ lambda: proto, SIGNED_CERTFILE)
sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
sslcontext_client.options |= ssl.OP_NO_SSLv2
@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
def test_read_pipe(self):
- proto = None
-
- def factory():
- nonlocal proto
- proto = MyReadPipeProto(loop=self.loop)
- return proto
+ proto = MyReadPipeProto(loop=self.loop)
rpipe, wpipe = os.pipe()
pipeobj = io.open(rpipe, 'rb', 1024)
@asyncio.coroutine
def connect():
- t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
+ t, p = yield from self.loop.connect_read_pipe(
+ lambda: proto, pipeobj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
# Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9
@support.requires_freebsd_version(8)
def test_read_pty_output(self):
- proto = None
-
- def factory():
- nonlocal proto
- proto = MyReadPipeProto(loop=self.loop)
- return proto
+ proto = MyReadPipeProto(loop=self.loop)
master, slave = os.openpty()
master_read_obj = io.open(master, 'rb', 0)
@asyncio.coroutine
def connect():
- t, p = yield from self.loop.connect_read_pipe(factory,
+ t, p = yield from self.loop.connect_read_pipe(lambda: proto,
master_read_obj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
def test_write_pipe(self):
- proto = None
+ proto = MyWritePipeProto(loop=self.loop)
transport = None
- def factory():
- nonlocal proto
- proto = MyWritePipeProto(loop=self.loop)
- return proto
-
rpipe, wpipe = os.pipe()
pipeobj = io.open(wpipe, 'wb', 1024)
@asyncio.coroutine
def connect():
nonlocal transport
- t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
+ t, p = yield from self.loop.connect_write_pipe(
+ lambda: proto, pipeobj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
self.assertEqual('CONNECTED', proto.state)
@unittest.skipUnless(sys.platform != 'win32',
"Don't support pipes for Windows")
def test_write_pipe_disconnect_on_close(self):
- proto = None
+ proto = MyWritePipeProto(loop=self.loop)
transport = None
- def factory():
- nonlocal proto
- proto = MyWritePipeProto(loop=self.loop)
- return proto
-
rsock, wsock = test_utils.socketpair()
pipeobj = io.open(wsock.detach(), 'wb', 1024)
@asyncio.coroutine
def connect():
nonlocal transport
- t, p = yield from self.loop.connect_write_pipe(factory,
+ t, p = yield from self.loop.connect_write_pipe(lambda: proto,
pipeobj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
# older than 10.6 (Snow Leopard)
@support.requires_mac_ver(10, 6)
def test_write_pty(self):
- proto = None
+ proto = MyWritePipeProto(loop=self.loop)
transport = None
- def factory():
- nonlocal proto
- proto = MyWritePipeProto(loop=self.loop)
- return proto
-
master, slave = os.openpty()
slave_write_obj = io.open(slave, 'wb', 0)
@asyncio.coroutine
def connect():
nonlocal transport
- t, p = yield from self.loop.connect_write_pipe(factory,
+ t, p = yield from self.loop.connect_write_pipe(lambda: proto,
slave_write_obj)
self.assertIs(p, proto)
self.assertIs(t, proto.transport)
self.loop.remove_reader = unittest.mock.Mock()
self.loop.remove_writer = unittest.mock.Mock()
waiter = asyncio.Future(loop=self.loop)
- transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
+ transport = self.loop._make_ssl_transport(
+ m, asyncio.Protocol(), m, waiter)
self.assertIsInstance(transport, _SelectorSslTransport)
@unittest.mock.patch('asyncio.selector_events.ssl', None)
"""Tests for streams.py."""
+import functools
import gc
+import socket
import unittest
import unittest.mock
try:
stream = asyncio.StreamReader()
self.assertIs(stream._loop, m_events.get_event_loop.return_value)
+ def _basetest_open_connection(self, open_connection_fut):
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ writer.close()
+
def test_open_connection(self):
with test_utils.run_test_server() as httpd:
- f = asyncio.open_connection(*httpd.address, loop=self.loop)
- reader, writer = self.loop.run_until_complete(f)
- writer.write(b'GET / HTTP/1.0\r\n\r\n')
- f = reader.readline()
- data = self.loop.run_until_complete(f)
- self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
- f = reader.read()
- data = self.loop.run_until_complete(f)
- self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
-
- writer.close()
+ conn_fut = asyncio.open_connection(*httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection(conn_fut)
+
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_open_unix_connection(self):
+ with test_utils.run_test_unix_server() as httpd:
+ conn_fut = asyncio.open_unix_connection(httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection(conn_fut)
+
+ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
+ try:
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
+ finally:
+ asyncio.set_event_loop(None)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
@unittest.skipIf(ssl is None, 'No ssl module')
def test_open_connection_no_loop_ssl(self):
with test_utils.run_test_server(use_ssl=True) as httpd:
- try:
- asyncio.set_event_loop(self.loop)
- f = asyncio.open_connection(*httpd.address,
- ssl=test_utils.dummy_ssl_context())
- reader, writer = self.loop.run_until_complete(f)
- finally:
- asyncio.set_event_loop(None)
- writer.write(b'GET / HTTP/1.0\r\n\r\n')
- f = reader.read()
- data = self.loop.run_until_complete(f)
- self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ conn_fut = asyncio.open_connection(
+ *httpd.address,
+ ssl=test_utils.dummy_ssl_context(),
+ loop=self.loop)
- writer.close()
+ self._basetest_open_connection_no_loop_ssl(conn_fut)
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_open_unix_connection_no_loop_ssl(self):
+ with test_utils.run_test_unix_server(use_ssl=True) as httpd:
+ conn_fut = asyncio.open_unix_connection(
+ httpd.address,
+ ssl=test_utils.dummy_ssl_context(),
+ server_hostname='',
+ loop=self.loop)
+
+ self._basetest_open_connection_no_loop_ssl(conn_fut)
+
+ def _basetest_open_connection_error(self, open_connection_fut):
+ reader, writer = self.loop.run_until_complete(open_connection_fut)
+ writer._protocol.connection_lost(ZeroDivisionError())
+ f = reader.read()
+ with self.assertRaises(ZeroDivisionError):
+ self.loop.run_until_complete(f)
+ writer.close()
+ test_utils.run_briefly(self.loop)
def test_open_connection_error(self):
with test_utils.run_test_server() as httpd:
- f = asyncio.open_connection(*httpd.address, loop=self.loop)
- reader, writer = self.loop.run_until_complete(f)
- writer._protocol.connection_lost(ZeroDivisionError())
- f = reader.read()
- with self.assertRaises(ZeroDivisionError):
- self.loop.run_until_complete(f)
+ conn_fut = asyncio.open_connection(*httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection_error(conn_fut)
- writer.close()
- test_utils.run_briefly(self.loop)
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_open_unix_connection_error(self):
+ with test_utils.run_test_unix_server() as httpd:
+ conn_fut = asyncio.open_unix_connection(httpd.address,
+ loop=self.loop)
+ self._basetest_open_connection_error(conn_fut)
def test_feed_empty_data(self):
stream = asyncio.StreamReader(loop=self.loop)
client_writer.write(data)
def start(self):
+ sock = socket.socket()
+ sock.bind(('127.0.0.1', 0))
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client,
- '127.0.0.1', 12345,
+ sock=sock,
loop=self.loop))
+ return sock.getsockname()
def handle_client_callback(self, client_reader, client_writer):
task = asyncio.Task(client_reader.readline(), loop=self.loop)
task.add_done_callback(done)
def start_callback(self):
+ sock = socket.socket()
+ sock.bind(('127.0.0.1', 0))
+ addr = sock.getsockname()
+ sock.close()
self.server = self.loop.run_until_complete(
asyncio.start_server(self.handle_client_callback,
- '127.0.0.1', 12345,
+ host=addr[0], port=addr[1],
loop=self.loop))
+ return addr
def stop(self):
if self.server is not None:
self.server = None
@asyncio.coroutine
- def client():
+ def client(addr):
reader, writer = yield from asyncio.open_connection(
- '127.0.0.1', 12345, loop=self.loop)
+ *addr, loop=self.loop)
# send a line
writer.write(b"hello world!\n")
# read it back
# test the server variant with a coroutine as client handler
server = MyServer(self.loop)
- server.start()
- msg = self.loop.run_until_complete(asyncio.Task(client(),
+ addr = server.start()
+ msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
# test the server variant with a callback as client handler
server = MyServer(self.loop)
- server.start_callback()
- msg = self.loop.run_until_complete(asyncio.Task(client(),
+ addr = server.start_callback()
+ msg = self.loop.run_until_complete(asyncio.Task(client(addr),
loop=self.loop))
server.stop()
self.assertEqual(msg, b"hello world!\n")
+ @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+ def test_start_unix_server(self):
+
+ class MyServer:
+
+ def __init__(self, loop, path):
+ self.server = None
+ self.loop = loop
+ self.path = path
+
+ @asyncio.coroutine
+ def handle_client(self, client_reader, client_writer):
+ data = yield from client_reader.readline()
+ client_writer.write(data)
+
+ def start(self):
+ self.server = self.loop.run_until_complete(
+ asyncio.start_unix_server(self.handle_client,
+ path=self.path,
+ loop=self.loop))
+
+ def handle_client_callback(self, client_reader, client_writer):
+ task = asyncio.Task(client_reader.readline(), loop=self.loop)
+
+ def done(task):
+ client_writer.write(task.result())
+
+ task.add_done_callback(done)
+
+ def start_callback(self):
+ self.server = self.loop.run_until_complete(
+ asyncio.start_unix_server(self.handle_client_callback,
+ path=self.path,
+ loop=self.loop))
+
+ def stop(self):
+ if self.server is not None:
+ self.server.close()
+ self.loop.run_until_complete(self.server.wait_closed())
+ self.server = None
+
+ @asyncio.coroutine
+ def client(path):
+ reader, writer = yield from asyncio.open_unix_connection(
+ path, loop=self.loop)
+ # send a line
+ writer.write(b"hello world!\n")
+ # read it back
+ msgback = yield from reader.readline()
+ writer.close()
+ return msgback
+
+ # test the server variant with a coroutine as client handler
+ with test_utils.unix_socket_path() as path:
+ server = MyServer(self.loop, path)
+ server.start()
+ msg = self.loop.run_until_complete(asyncio.Task(client(path),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
+ # test the server variant with a callback as client handler
+ with test_utils.unix_socket_path() as path:
+ server = MyServer(self.loop, path)
+ server.start_callback()
+ msg = self.loop.run_until_complete(asyncio.Task(client(path),
+ loop=self.loop))
+ server.stop()
+ self.assertEqual(msg, b"hello world!\n")
+
if __name__ == '__main__':
unittest.main()
import os
import pprint
import signal
+import socket
import stat
import sys
+import tempfile
import threading
import unittest
import unittest.mock
@unittest.skipUnless(signal, 'Signals are not supported')
-class SelectorEventLoopTests(unittest.TestCase):
+class SelectorEventLoopSignalTests(unittest.TestCase):
def setUp(self):
self.loop = asyncio.SelectorEventLoop()
m_signal.set_wakeup_fd.assert_called_once_with(-1)
+@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
+ 'UNIX Sockets are not supported')
+class SelectorEventLoopUnixSocketTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = asyncio.SelectorEventLoop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_create_unix_server_existing_path_sock(self):
+ with test_utils.unix_socket_path() as path:
+ sock = socket.socket(socket.AF_UNIX)
+ sock.bind(path)
+
+ coro = self.loop.create_unix_server(lambda: None, path)
+ with self.assertRaisesRegexp(OSError,
+ 'Address.*is already in use'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_existing_path_nonsock(self):
+ with tempfile.NamedTemporaryFile() as file:
+ coro = self.loop.create_unix_server(lambda: None, file.name)
+ with self.assertRaisesRegexp(OSError,
+ 'Address.*is already in use'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_ssl_bool(self):
+ coro = self.loop.create_unix_server(lambda: None, path='spam',
+ ssl=True)
+ with self.assertRaisesRegex(TypeError,
+ 'ssl argument must be an SSLContext'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_nopath_nosock(self):
+ coro = self.loop.create_unix_server(lambda: None, path=None)
+ with self.assertRaisesRegex(ValueError,
+ 'path was not specified, and no sock'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_server_path_inetsock(self):
+ coro = self.loop.create_unix_server(lambda: None, path=None,
+ sock=socket.socket())
+ with self.assertRaisesRegex(ValueError,
+ 'A UNIX Domain Socket was expected'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_connection_path_sock(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, '/dev/null', sock=object())
+ with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_connection_nopath_nosock(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, None)
+ with self.assertRaisesRegex(ValueError,
+ 'no path and sock were specified'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_connection_nossl_serverhost(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, '/dev/null', server_hostname='spam')
+ with self.assertRaisesRegex(ValueError,
+ 'server_hostname is only meaningful'):
+ self.loop.run_until_complete(coro)
+
+ def test_create_unix_connection_ssl_noserverhost(self):
+ coro = self.loop.create_unix_connection(
+ lambda: None, '/dev/null', ssl=True)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'you have to pass server_hostname when using ssl'):
+
+ self.loop.run_until_complete(coro)
+
+
class UnixReadPipeTransportTests(unittest.TestCase):
def setUp(self):