]> granicus.if.org Git - python/commitdiff
asyncio: Add support for UNIX Domain Sockets.
authorYury Selivanov <yselivanov@sprymix.com>
Tue, 18 Feb 2014 17:15:06 +0000 (12:15 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Tue, 18 Feb 2014 17:15:06 +0000 (12:15 -0500)
Lib/asyncio/base_events.py
Lib/asyncio/events.py
Lib/asyncio/streams.py
Lib/asyncio/test_utils.py
Lib/asyncio/unix_events.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_streams.py
Lib/test/test_asyncio/test_unix_events.py

index 3bbf6b54661a3231a70c3641f14feba14b957abd..b74e9369414b75e4db75221c31b9c2787b1afe64 100644 (file)
@@ -407,6 +407,13 @@ class BaseEventLoop(events.AbstractEventLoop):
 
         sock.setblocking(False)
 
+        transport, protocol = yield from self._create_connection_transport(
+            sock, protocol_factory, ssl, server_hostname)
+        return transport, protocol
+
+    @tasks.coroutine
+    def _create_connection_transport(self, sock, protocol_factory, ssl,
+                                     server_hostname):
         protocol = protocol_factory()
         waiter = futures.Future(loop=self)
         if ssl:
index dd9e3fb42992df2e5fdf2157cc5532abb94de74f..7841ad9ba52a7228f82007e086645f1ea69e479a 100644 (file)
@@ -220,6 +220,32 @@ class AbstractEventLoop:
         """
         raise NotImplementedError
 
+    def create_unix_connection(self, protocol_factory, path, *,
+                               ssl=None, sock=None,
+                               server_hostname=None):
+        raise NotImplementedError
+
+    def create_unix_server(self, protocol_factory, path, *,
+                           sock=None, backlog=100, ssl=None):
+        """A coroutine which creates a UNIX Domain Socket server.
+
+        The return valud is a Server object, which can be used to stop
+        the service.
+
+        path is a str, representing a file systsem path to bind the
+        server socket to.
+
+        sock can optionally be specified in order to use a preexisting
+        socket object.
+
+        backlog is the maximum number of queued connections passed to
+        listen() (defaults to 100).
+
+        ssl can be set to an SSLContext to enable SSL over the
+        accepted connections.
+        """
+        raise NotImplementedError
+
     def create_datagram_endpoint(self, protocol_factory,
                                  local_addr=None, remote_addr=None, *,
                                  family=0, proto=0, flags=0):
index 8fc21474e90c8bd0fc129b9f11e0fbbe64d16bac..698c5c6b184201903897b791287ddb9b4a0d4758 100644 (file)
@@ -1,9 +1,13 @@
 """Stream-related things."""
 
 __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
-           'open_connection', 'start_server', 'IncompleteReadError',
+           'open_connection', 'start_server',
+           'open_unix_connection', 'start_unix_server',
+           'IncompleteReadError',
            ]
 
+import socket
+
 from . import events
 from . import futures
 from . import protocols
@@ -93,6 +97,39 @@ def start_server(client_connected_cb, host=None, port=None, *,
     return (yield from loop.create_server(factory, host, port, **kwds))
 
 
+if hasattr(socket, 'AF_UNIX'):
+    # UNIX Domain Sockets are supported on this platform
+
+    @tasks.coroutine
+    def open_unix_connection(path=None, *,
+                             loop=None, limit=_DEFAULT_LIMIT, **kwds):
+        """Similar to `open_connection` but works with UNIX Domain Sockets."""
+        if loop is None:
+            loop = events.get_event_loop()
+        reader = StreamReader(limit=limit, loop=loop)
+        protocol = StreamReaderProtocol(reader, loop=loop)
+        transport, _ = yield from loop.create_unix_connection(
+            lambda: protocol, path, **kwds)
+        writer = StreamWriter(transport, protocol, reader, loop)
+        return reader, writer
+
+
+    @tasks.coroutine
+    def start_unix_server(client_connected_cb, path=None, *,
+                          loop=None, limit=_DEFAULT_LIMIT, **kwds):
+        """Similar to `start_server` but works with UNIX Domain Sockets."""
+        if loop is None:
+            loop = events.get_event_loop()
+
+        def factory():
+            reader = StreamReader(limit=limit, loop=loop)
+            protocol = StreamReaderProtocol(reader, client_connected_cb,
+                                            loop=loop)
+            return protocol
+
+        return (yield from loop.create_unix_server(factory, path, **kwds))
+
+
 class FlowControlMixin(protocols.Protocol):
     """Reusable flow control logic for StreamWriter.drain().
 
index deab7c33122f066fd52d5a756c1894c9fef6d6d2..de2916bfc743b205d801ecf9b892b31698823d2f 100644 (file)
@@ -4,12 +4,18 @@ import collections
 import contextlib
 import io
 import os
+import socket
+import socketserver
 import sys
+import tempfile
 import threading
 import time
 import unittest
 import unittest.mock
+
+from http.server import HTTPServer
 from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
+
 try:
     import ssl
 except ImportError:  # pragma: no cover
@@ -70,42 +76,51 @@ def run_once(loop):
     loop.run_forever()
 
 
-@contextlib.contextmanager
-def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+class SilentWSGIRequestHandler(WSGIRequestHandler):
 
-    class SilentWSGIRequestHandler(WSGIRequestHandler):
-        def get_stderr(self):
-            return io.StringIO()
+    def get_stderr(self):
+        return io.StringIO()
 
-        def log_message(self, format, *args):
-            pass
+    def log_message(self, format, *args):
+        pass
 
-    class SilentWSGIServer(WSGIServer):
-        def handle_error(self, request, client_address):
+
+class SilentWSGIServer(WSGIServer):
+
+    def handle_error(self, request, client_address):
+        pass
+
+
+class SSLWSGIServerMixin:
+
+    def finish_request(self, request, client_address):
+        # The relative location of our test directory (which
+        # contains the ssl key and certificate files) differs
+        # between the stdlib and stand-alone asyncio.
+        # Prefer our own if we can find it.
+        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
+        if not os.path.isdir(here):
+            here = os.path.join(os.path.dirname(os.__file__),
+                                'test', 'test_asyncio')
+        keyfile = os.path.join(here, 'ssl_key.pem')
+        certfile = os.path.join(here, 'ssl_cert.pem')
+        ssock = ssl.wrap_socket(request,
+                                keyfile=keyfile,
+                                certfile=certfile,
+                                server_side=True)
+        try:
+            self.RequestHandlerClass(ssock, client_address, self)
+            ssock.close()
+        except OSError:
+            # maybe socket has been closed by peer
             pass
 
-    class SSLWSGIServer(SilentWSGIServer):
-        def finish_request(self, request, client_address):
-            # The relative location of our test directory (which
-            # contains the ssl key and certificate files) differs
-            # between the stdlib and stand-alone asyncio.
-            # Prefer our own if we can find it.
-            here = os.path.join(os.path.dirname(__file__), '..', 'tests')
-            if not os.path.isdir(here):
-                here = os.path.join(os.path.dirname(os.__file__),
-                                    'test', 'test_asyncio')
-            keyfile = os.path.join(here, 'ssl_key.pem')
-            certfile = os.path.join(here, 'ssl_cert.pem')
-            ssock = ssl.wrap_socket(request,
-                                    keyfile=keyfile,
-                                    certfile=certfile,
-                                    server_side=True)
-            try:
-                self.RequestHandlerClass(ssock, client_address, self)
-                ssock.close()
-            except OSError:
-                # maybe socket has been closed by peer
-                pass
+
+class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
+    pass
+
+
+def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
 
     def app(environ, start_response):
         status = '200 OK'
@@ -115,9 +130,9 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
 
     # Run the test WSGI server in a separate thread in order not to
     # interfere with event handling in the main thread
-    server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
-    httpd = make_server(host, port, app,
-                        server_class, SilentWSGIRequestHandler)
+    server_class = server_ssl_cls if use_ssl else server_cls
+    httpd = server_class(address, SilentWSGIRequestHandler)
+    httpd.set_app(app)
     httpd.address = httpd.server_address
     server_thread = threading.Thread(target=httpd.serve_forever)
     server_thread.start()
@@ -129,6 +144,75 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
         server_thread.join()
 
 
+if hasattr(socket, 'AF_UNIX'):
+
+    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
+
+        def server_bind(self):
+            socketserver.UnixStreamServer.server_bind(self)
+            self.server_name = '127.0.0.1'
+            self.server_port = 80
+
+
+    class UnixWSGIServer(UnixHTTPServer, WSGIServer):
+
+        def server_bind(self):
+            UnixHTTPServer.server_bind(self)
+            self.setup_environ()
+
+        def get_request(self):
+            request, client_addr = super().get_request()
+            # Code in the stdlib expects that get_request
+            # will return a socket and a tuple (host, port).
+            # However, this isn't true for UNIX sockets,
+            # as the second return value will be a path;
+            # hence we return some fake data sufficient
+            # to get the tests going
+            return request, ('127.0.0.1', '')
+
+
+    class SilentUnixWSGIServer(UnixWSGIServer):
+
+        def handle_error(self, request, client_address):
+            pass
+
+
+    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
+        pass
+
+
+    def gen_unix_socket_path():
+        with tempfile.NamedTemporaryFile() as file:
+            return file.name
+
+
+    @contextlib.contextmanager
+    def unix_socket_path():
+        path = gen_unix_socket_path()
+        try:
+            yield path
+        finally:
+            try:
+                os.unlink(path)
+            except OSError:
+                pass
+
+
+    @contextlib.contextmanager
+    def run_test_unix_server(*, use_ssl=False):
+        with unix_socket_path() as path:
+            yield from _run_test_server(address=path, use_ssl=use_ssl,
+                                        server_cls=SilentUnixWSGIServer,
+                                        server_ssl_cls=UnixSSLWSGIServer)
+
+
+@contextlib.contextmanager
+def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+    yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
+                                server_cls=SilentWSGIServer,
+                                server_ssl_cls=SSLWSGIServer)
+
+
 def make_test_protocol(base):
     dct = {}
     for name in dir(base):
@@ -275,5 +359,6 @@ class TestLoop(base_events.BaseEventLoop):
     def _write_to_self(self):
         pass
 
+
 def MockCallback(**kwargs):
     return unittest.mock.Mock(spec=['__call__'], **kwargs)
index ea79d33b33a9f2c47b51dcce95a110dca5f14e0a..e0d75077518278e7261d3a1c7f48799564c1295e 100644 (file)
@@ -11,6 +11,7 @@ import sys
 import threading
 
 
+from . import base_events
 from . import base_subprocess
 from . import constants
 from . import events
@@ -31,9 +32,9 @@ if sys.platform == 'win32':  # pragma: no cover
 
 
 class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
-    """Unix event loop
+    """Unix event loop.
 
-    Adds signal handling to SelectorEventLoop
+    Adds signal handling and UNIX Domain Socket support to SelectorEventLoop.
     """
 
     def __init__(self, selector=None):
@@ -164,6 +165,76 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
     def _child_watcher_callback(self, pid, returncode, transp):
         self.call_soon_threadsafe(transp._process_exited, returncode)
 
+    @tasks.coroutine
+    def create_unix_connection(self, protocol_factory, path, *,
+                               ssl=None, sock=None,
+                               server_hostname=None):
+        assert server_hostname is None or isinstance(server_hostname, str)
+        if ssl:
+            if server_hostname is None:
+                raise ValueError(
+                    'you have to pass server_hostname when using ssl')
+        else:
+            if server_hostname is not None:
+                raise ValueError('server_hostname is only meaningful with ssl')
+
+        if path is not None:
+            if sock is not None:
+                raise ValueError(
+                    'path and sock can not be specified at the same time')
+
+            try:
+                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
+                sock.setblocking(False)
+                yield from self.sock_connect(sock, path)
+            except OSError:
+                if sock is not None:
+                    sock.close()
+                raise
+
+        else:
+            if sock is None:
+                raise ValueError('no path and sock were specified')
+            sock.setblocking(False)
+
+        transport, protocol = yield from self._create_connection_transport(
+            sock, protocol_factory, ssl, server_hostname)
+        return transport, protocol
+
+    @tasks.coroutine
+    def create_unix_server(self, protocol_factory, path=None, *,
+                           sock=None, backlog=100, ssl=None):
+        if isinstance(ssl, bool):
+            raise TypeError('ssl argument must be an SSLContext or None')
+
+        if path is not None:
+            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+
+            try:
+                sock.bind(path)
+            except OSError as exc:
+                if exc.errno == errno.EADDRINUSE:
+                    # Let's improve the error message by adding
+                    # with what exact address it occurs.
+                    msg = 'Address {!r} is already in use'.format(path)
+                    raise OSError(errno.EADDRINUSE, msg) from None
+                else:
+                    raise
+        else:
+            if sock is None:
+                raise ValueError(
+                    'path was not specified, and no sock specified')
+
+            if sock.family != socket.AF_UNIX:
+                raise ValueError(
+                    'A UNIX Domain Socket was expected, got {!r}'.format(sock))
+
+        server = base_events.Server(self, [sock])
+        sock.listen(backlog)
+        sock.setblocking(False)
+        self._start_serving(protocol_factory, sock, ssl, server)
+        return server
+
 
 def _set_nonblocking(fd):
     flags = fcntl.fcntl(fd, fcntl.F_GETFL)
index 94e2d59df1691b551079c1fb960b5ef4cd73cd7d..9fa984156b0d315f50a93c886334d8c36826967b 100644 (file)
@@ -212,7 +212,7 @@ class BaseEventLoopTests(unittest.TestCase):
 
         idx = -1
         data = [10.0, 10.0, 10.3, 13.0]
-        self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())]
+        self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())]
         self.loop._run_once()
         self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0])
 
index 8c32a6e7c2be90b52c3abb899a0d272fad4aeda7..c9d04c044c871dd98dffb0d424ab64173086711d 100644 (file)
@@ -39,13 +39,14 @@ def data_file(filename):
         return fullname
     raise FileNotFoundError(filename)
 
+
 ONLYCERT = data_file('ssl_cert.pem')
 ONLYKEY = data_file('ssl_key.pem')
 SIGNED_CERTFILE = data_file('keycert3.pem')
 SIGNING_CA = data_file('pycacert.pem')
 
 
-class MyProto(asyncio.Protocol):
+class MyBaseProto(asyncio.Protocol):
     done = None
 
     def __init__(self, loop=None):
@@ -59,7 +60,6 @@ class MyProto(asyncio.Protocol):
         self.transport = transport
         assert self.state == 'INITIAL', self.state
         self.state = 'CONNECTED'
-        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
 
     def data_received(self, data):
         assert self.state == 'CONNECTED', self.state
@@ -76,6 +76,12 @@ class MyProto(asyncio.Protocol):
             self.done.set_result(None)
 
 
+class MyProto(MyBaseProto):
+    def connection_made(self, transport):
+        super().connection_made(transport)
+        transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+
 class MyDatagramProto(asyncio.DatagramProtocol):
     done = None
 
@@ -357,22 +363,30 @@ class EventLoopTestsMixin:
         r.close()
         self.assertGreaterEqual(len(data), 200)
 
+    def _basetest_sock_client_ops(self, httpd, sock):
+        sock.setblocking(False)
+        self.loop.run_until_complete(
+            self.loop.sock_connect(sock, httpd.address))
+        self.loop.run_until_complete(
+            self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+        data = self.loop.run_until_complete(
+            self.loop.sock_recv(sock, 1024))
+        # consume data
+        self.loop.run_until_complete(
+            self.loop.sock_recv(sock, 1024))
+        sock.close()
+        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+
     def test_sock_client_ops(self):
         with test_utils.run_test_server() as httpd:
             sock = socket.socket()
-            sock.setblocking(False)
-            self.loop.run_until_complete(
-                self.loop.sock_connect(sock, httpd.address))
-            self.loop.run_until_complete(
-                self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
-            data = self.loop.run_until_complete(
-                self.loop.sock_recv(sock, 1024))
-            # consume data
-            self.loop.run_until_complete(
-                self.loop.sock_recv(sock, 1024))
-            sock.close()
+            self._basetest_sock_client_ops(httpd, sock)
 
-        self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_unix_sock_client_ops(self):
+        with test_utils.run_test_unix_server() as httpd:
+            sock = socket.socket(socket.AF_UNIX)
+            self._basetest_sock_client_ops(httpd, sock)
 
     def test_sock_client_fail(self):
         # Make sure that we will get an unused port
@@ -485,16 +499,26 @@ class EventLoopTestsMixin:
         self.loop.run_forever()
         self.assertEqual(caught, 1)
 
+    def _basetest_create_connection(self, connection_fut):
+        tr, pr = self.loop.run_until_complete(connection_fut)
+        self.assertIsInstance(tr, asyncio.Transport)
+        self.assertIsInstance(pr, asyncio.Protocol)
+        self.loop.run_until_complete(pr.done)
+        self.assertGreater(pr.nbytes, 0)
+        tr.close()
+
     def test_create_connection(self):
         with test_utils.run_test_server() as httpd:
-            f = self.loop.create_connection(
+            conn_fut = self.loop.create_connection(
                 lambda: MyProto(loop=self.loop), *httpd.address)
-            tr, pr = self.loop.run_until_complete(f)
-            self.assertIsInstance(tr, asyncio.Transport)
-            self.assertIsInstance(pr, asyncio.Protocol)
-            self.loop.run_until_complete(pr.done)
-            self.assertGreater(pr.nbytes, 0)
-            tr.close()
+            self._basetest_create_connection(conn_fut)
+
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_create_unix_connection(self):
+        with test_utils.run_test_unix_server() as httpd:
+            conn_fut = self.loop.create_unix_connection(
+                lambda: MyProto(loop=self.loop), httpd.address)
+            self._basetest_create_connection(conn_fut)
 
     def test_create_connection_sock(self):
         with test_utils.run_test_server() as httpd:
@@ -524,20 +548,37 @@ class EventLoopTestsMixin:
             self.assertGreater(pr.nbytes, 0)
             tr.close()
 
+    def _basetest_create_ssl_connection(self, connection_fut):
+        tr, pr = self.loop.run_until_complete(connection_fut)
+        self.assertIsInstance(tr, asyncio.Transport)
+        self.assertIsInstance(pr, asyncio.Protocol)
+        self.assertTrue('ssl' in tr.__class__.__name__.lower())
+        self.assertIsNotNone(tr.get_extra_info('sockname'))
+        self.loop.run_until_complete(pr.done)
+        self.assertGreater(pr.nbytes, 0)
+        tr.close()
+
     @unittest.skipIf(ssl is None, 'No ssl module')
     def test_create_ssl_connection(self):
         with test_utils.run_test_server(use_ssl=True) as httpd:
-            f = self.loop.create_connection(
-                lambda: MyProto(loop=self.loop), *httpd.address,
+            conn_fut = self.loop.create_connection(
+                lambda: MyProto(loop=self.loop),
+                *httpd.address,
                 ssl=test_utils.dummy_ssl_context())
-            tr, pr = self.loop.run_until_complete(f)
-            self.assertIsInstance(tr, asyncio.Transport)
-            self.assertIsInstance(pr, asyncio.Protocol)
-            self.assertTrue('ssl' in tr.__class__.__name__.lower())
-            self.assertIsNotNone(tr.get_extra_info('sockname'))
-            self.loop.run_until_complete(pr.done)
-            self.assertGreater(pr.nbytes, 0)
-            tr.close()
+
+            self._basetest_create_ssl_connection(conn_fut)
+
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_create_ssl_unix_connection(self):
+        with test_utils.run_test_unix_server(use_ssl=True) as httpd:
+            conn_fut = self.loop.create_unix_connection(
+                lambda: MyProto(loop=self.loop),
+                httpd.address,
+                ssl=test_utils.dummy_ssl_context(),
+                server_hostname='127.0.0.1')
+
+            self._basetest_create_ssl_connection(conn_fut)
 
     def test_create_connection_local_addr(self):
         with test_utils.run_test_server() as httpd:
@@ -561,14 +602,8 @@ class EventLoopTestsMixin:
             self.assertIn(str(httpd.address), cm.exception.strerror)
 
     def test_create_server(self):
-        proto = None
-
-        def factory():
-            nonlocal proto
-            proto = MyProto()
-            return proto
-
-        f = self.loop.create_server(factory, '0.0.0.0', 0)
+        proto = MyProto()
+        f = self.loop.create_server(lambda: proto, '0.0.0.0', 0)
         server = self.loop.run_until_complete(f)
         self.assertEqual(len(server.sockets), 1)
         sock = server.sockets[0]
@@ -605,38 +640,76 @@ class EventLoopTestsMixin:
         # close server
         server.close()
 
-    def _make_ssl_server(self, factory, certfile, keyfile=None):
+    def _make_unix_server(self, factory, **kwargs):
+        path = test_utils.gen_unix_socket_path()
+        self.addCleanup(lambda: os.path.exists(path) and os.unlink(path))
+
+        f = self.loop.create_unix_server(factory, path, **kwargs)
+        server = self.loop.run_until_complete(f)
+
+        return server, path
+
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_create_unix_server(self):
+        proto = MyProto()
+        server, path = self._make_unix_server(lambda: proto)
+        self.assertEqual(len(server.sockets), 1)
+
+        client = socket.socket(socket.AF_UNIX)
+        client.connect(path)
+        client.sendall(b'xxx')
+        test_utils.run_briefly(self.loop)
+        test_utils.run_until(self.loop, lambda: proto is not None, 10)
+
+        self.assertIsInstance(proto, MyProto)
+        self.assertEqual('INITIAL', proto.state)
+        test_utils.run_briefly(self.loop)
+        self.assertEqual('CONNECTED', proto.state)
+        test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
+                             timeout=10)
+        self.assertEqual(3, proto.nbytes)
+
+        # close connection
+        proto.transport.close()
+        test_utils.run_briefly(self.loop)  # windows iocp
+
+        self.assertEqual('CLOSED', proto.state)
+
+        # the client socket must be closed after to avoid ECONNRESET upon
+        # recv()/send() on the serving socket
+        client.close()
+
+        # close server
+        server.close()
+
+    def _create_ssl_context(self, certfile, keyfile=None):
         sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
         sslcontext.options |= ssl.OP_NO_SSLv2
         sslcontext.load_cert_chain(certfile, keyfile)
+        return sslcontext
 
-        f = self.loop.create_server(
-            factory, '127.0.0.1', 0, ssl=sslcontext)
+    def _make_ssl_server(self, factory, certfile, keyfile=None):
+        sslcontext = self._create_ssl_context(certfile, keyfile)
 
+        f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext)
         server = self.loop.run_until_complete(f)
+
         sock = server.sockets[0]
         host, port = sock.getsockname()
         self.assertEqual(host, '127.0.0.1')
         return server, host, port
 
+    def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
+        sslcontext = self._create_ssl_context(certfile, keyfile)
+        return self._make_unix_server(factory, ssl=sslcontext)
+
     @unittest.skipIf(ssl is None, 'No ssl module')
     def test_create_server_ssl(self):
-        proto = None
-
-        class ClientMyProto(MyProto):
-            def connection_made(self, transport):
-                self.transport = transport
-                assert self.state == 'INITIAL', self.state
-                self.state = 'CONNECTED'
+        proto = MyProto(loop=self.loop)
+        server, host, port = self._make_ssl_server(
+            lambda: proto, ONLYCERT, ONLYKEY)
 
-        def factory():
-            nonlocal proto
-            proto = MyProto(loop=self.loop)
-            return proto
-
-        server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY)
-
-        f_c = self.loop.create_connection(ClientMyProto, host, port,
+        f_c = self.loop.create_connection(MyBaseProto, host, port,
                                           ssl=test_utils.dummy_ssl_context())
         client, pr = self.loop.run_until_complete(f_c)
 
@@ -667,16 +740,45 @@ class EventLoopTestsMixin:
         server.close()
 
     @unittest.skipIf(ssl is None, 'No ssl module')
-    @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
-    def test_create_server_ssl_verify_failed(self):
-        proto = None
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_create_unix_server_ssl(self):
+        proto = MyProto(loop=self.loop)
+        server, path = self._make_ssl_unix_server(
+            lambda: proto, ONLYCERT, ONLYKEY)
 
-        def factory():
-            nonlocal proto
-            proto = MyProto(loop=self.loop)
-            return proto
+        f_c = self.loop.create_unix_connection(
+            MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
+            server_hostname='')
+
+        client, pr = self.loop.run_until_complete(f_c)
+
+        client.write(b'xxx')
+        test_utils.run_briefly(self.loop)
+        self.assertIsInstance(proto, MyProto)
+        test_utils.run_briefly(self.loop)
+        self.assertEqual('CONNECTED', proto.state)
+        test_utils.run_until(self.loop, lambda: proto.nbytes > 0,
+                             timeout=10)
+        self.assertEqual(3, proto.nbytes)
+
+        # close connection
+        proto.transport.close()
+        self.loop.run_until_complete(proto.done)
+        self.assertEqual('CLOSED', proto.state)
+
+        # the client socket must be closed after to avoid ECONNRESET upon
+        # recv()/send() on the serving socket
+        client.close()
 
-        server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
+        # stop serving
+        server.close()
+
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
+    def test_create_server_ssl_verify_failed(self):
+        proto = MyProto(loop=self.loop)
+        server, host, port = self._make_ssl_server(
+            lambda: proto, SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
@@ -697,15 +799,36 @@ class EventLoopTestsMixin:
 
     @unittest.skipIf(ssl is None, 'No ssl module')
     @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
-    def test_create_server_ssl_match_failed(self):
-        proto = None
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_create_unix_server_ssl_verify_failed(self):
+        proto = MyProto(loop=self.loop)
+        server, path = self._make_ssl_unix_server(
+            lambda: proto, SIGNED_CERTFILE)
 
-        def factory():
-            nonlocal proto
-            proto = MyProto(loop=self.loop)
-            return proto
+        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        sslcontext_client.options |= ssl.OP_NO_SSLv2
+        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+        if hasattr(sslcontext_client, 'check_hostname'):
+            sslcontext_client.check_hostname = True
 
-        server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
+        # no CA loaded
+        f_c = self.loop.create_unix_connection(MyProto, path,
+                                               ssl=sslcontext_client,
+                                               server_hostname='invalid')
+        with self.assertRaisesRegex(ssl.SSLError,
+                                    'certificate verify failed '):
+            self.loop.run_until_complete(f_c)
+
+        # close connection
+        self.assertIsNone(proto.transport)
+        server.close()
+
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
+    def test_create_server_ssl_match_failed(self):
+        proto = MyProto(loop=self.loop)
+        server, host, port = self._make_ssl_server(
+            lambda: proto, SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
@@ -729,15 +852,36 @@ class EventLoopTestsMixin:
 
     @unittest.skipIf(ssl is None, 'No ssl module')
     @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
-    def test_create_server_ssl_verified(self):
-        proto = None
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_create_unix_server_ssl_verified(self):
+        proto = MyProto(loop=self.loop)
+        server, path = self._make_ssl_unix_server(
+            lambda: proto, SIGNED_CERTFILE)
 
-        def factory():
-            nonlocal proto
-            proto = MyProto(loop=self.loop)
-            return proto
+        sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        sslcontext_client.options |= ssl.OP_NO_SSLv2
+        sslcontext_client.verify_mode = ssl.CERT_REQUIRED
+        sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
+        if hasattr(sslcontext_client, 'check_hostname'):
+            sslcontext_client.check_hostname = True
 
-        server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE)
+        # Connection succeeds with correct CA and server hostname.
+        f_c = self.loop.create_unix_connection(MyProto, path,
+                                               ssl=sslcontext_client,
+                                               server_hostname='localhost')
+        client, pr = self.loop.run_until_complete(f_c)
+
+        # close connection
+        proto.transport.close()
+        client.close()
+        server.close()
+
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module')
+    def test_create_server_ssl_verified(self):
+        proto = MyProto(loop=self.loop)
+        server, host, port = self._make_ssl_server(
+            lambda: proto, SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
@@ -915,19 +1059,15 @@ class EventLoopTestsMixin:
     @unittest.skipUnless(sys.platform != 'win32',
                          "Don't support pipes for Windows")
     def test_read_pipe(self):
-        proto = None
-
-        def factory():
-            nonlocal proto
-            proto = MyReadPipeProto(loop=self.loop)
-            return proto
+        proto = MyReadPipeProto(loop=self.loop)
 
         rpipe, wpipe = os.pipe()
         pipeobj = io.open(rpipe, 'rb', 1024)
 
         @asyncio.coroutine
         def connect():
-            t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
+            t, p = yield from self.loop.connect_read_pipe(
+                lambda: proto, pipeobj)
             self.assertIs(p, proto)
             self.assertIs(t, proto.transport)
             self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
@@ -959,19 +1099,14 @@ class EventLoopTestsMixin:
     # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9
     @support.requires_freebsd_version(8)
     def test_read_pty_output(self):
-        proto = None
-
-        def factory():
-            nonlocal proto
-            proto = MyReadPipeProto(loop=self.loop)
-            return proto
+        proto = MyReadPipeProto(loop=self.loop)
 
         master, slave = os.openpty()
         master_read_obj = io.open(master, 'rb', 0)
 
         @asyncio.coroutine
         def connect():
-            t, p = yield from self.loop.connect_read_pipe(factory,
+            t, p = yield from self.loop.connect_read_pipe(lambda: proto,
                                                           master_read_obj)
             self.assertIs(p, proto)
             self.assertIs(t, proto.transport)
@@ -999,21 +1134,17 @@ class EventLoopTestsMixin:
     @unittest.skipUnless(sys.platform != 'win32',
                          "Don't support pipes for Windows")
     def test_write_pipe(self):
-        proto = None
+        proto = MyWritePipeProto(loop=self.loop)
         transport = None
 
-        def factory():
-            nonlocal proto
-            proto = MyWritePipeProto(loop=self.loop)
-            return proto
-
         rpipe, wpipe = os.pipe()
         pipeobj = io.open(wpipe, 'wb', 1024)
 
         @asyncio.coroutine
         def connect():
             nonlocal transport
-            t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
+            t, p = yield from self.loop.connect_write_pipe(
+                        lambda: proto, pipeobj)
             self.assertIs(p, proto)
             self.assertIs(t, proto.transport)
             self.assertEqual('CONNECTED', proto.state)
@@ -1045,21 +1176,16 @@ class EventLoopTestsMixin:
     @unittest.skipUnless(sys.platform != 'win32',
                          "Don't support pipes for Windows")
     def test_write_pipe_disconnect_on_close(self):
-        proto = None
+        proto = MyWritePipeProto(loop=self.loop)
         transport = None
 
-        def factory():
-            nonlocal proto
-            proto = MyWritePipeProto(loop=self.loop)
-            return proto
-
         rsock, wsock = test_utils.socketpair()
         pipeobj = io.open(wsock.detach(), 'wb', 1024)
 
         @asyncio.coroutine
         def connect():
             nonlocal transport
-            t, p = yield from self.loop.connect_write_pipe(factory,
+            t, p = yield from self.loop.connect_write_pipe(lambda: proto,
                                                            pipeobj)
             self.assertIs(p, proto)
             self.assertIs(t, proto.transport)
@@ -1084,21 +1210,16 @@ class EventLoopTestsMixin:
     # older than 10.6 (Snow Leopard)
     @support.requires_mac_ver(10, 6)
     def test_write_pty(self):
-        proto = None
+        proto = MyWritePipeProto(loop=self.loop)
         transport = None
 
-        def factory():
-            nonlocal proto
-            proto = MyWritePipeProto(loop=self.loop)
-            return proto
-
         master, slave = os.openpty()
         slave_write_obj = io.open(slave, 'wb', 0)
 
         @asyncio.coroutine
         def connect():
             nonlocal transport
-            t, p = yield from self.loop.connect_write_pipe(factory,
+            t, p = yield from self.loop.connect_write_pipe(lambda: proto,
                                                            slave_write_obj)
             self.assertIs(p, proto)
             self.assertIs(t, proto.transport)
index 855a8954e868291131f347e102e645fed8782f6d..7741e191b56ff41e2063a7423d8e1c7b58907730 100644 (file)
@@ -55,7 +55,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
         self.loop.remove_reader = unittest.mock.Mock()
         self.loop.remove_writer = unittest.mock.Mock()
         waiter = asyncio.Future(loop=self.loop)
-        transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
+        transport = self.loop._make_ssl_transport(
+            m, asyncio.Protocol(), m, waiter)
         self.assertIsInstance(transport, _SelectorSslTransport)
 
     @unittest.mock.patch('asyncio.selector_events.ssl', None)
index ee3fb450291aa31ac7cd7d6d6a98bf8fe04c8174..31e26a641ad38e0ecdbe4db5b9f369af915a6c1c 100644 (file)
@@ -1,6 +1,8 @@
 """Tests for streams.py."""
 
+import functools
 import gc
+import socket
 import unittest
 import unittest.mock
 try:
@@ -32,48 +34,85 @@ class StreamReaderTests(unittest.TestCase):
         stream = asyncio.StreamReader()
         self.assertIs(stream._loop, m_events.get_event_loop.return_value)
 
+    def _basetest_open_connection(self, open_connection_fut):
+        reader, writer = self.loop.run_until_complete(open_connection_fut)
+        writer.write(b'GET / HTTP/1.0\r\n\r\n')
+        f = reader.readline()
+        data = self.loop.run_until_complete(f)
+        self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+        f = reader.read()
+        data = self.loop.run_until_complete(f)
+        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+        writer.close()
+
     def test_open_connection(self):
         with test_utils.run_test_server() as httpd:
-            f = asyncio.open_connection(*httpd.address, loop=self.loop)
-            reader, writer = self.loop.run_until_complete(f)
-            writer.write(b'GET / HTTP/1.0\r\n\r\n')
-            f = reader.readline()
-            data = self.loop.run_until_complete(f)
-            self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
-            f = reader.read()
-            data = self.loop.run_until_complete(f)
-            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
-
-            writer.close()
+            conn_fut = asyncio.open_connection(*httpd.address,
+                                               loop=self.loop)
+            self._basetest_open_connection(conn_fut)
+
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_open_unix_connection(self):
+        with test_utils.run_test_unix_server() as httpd:
+            conn_fut = asyncio.open_unix_connection(httpd.address,
+                                                    loop=self.loop)
+            self._basetest_open_connection(conn_fut)
+
+    def _basetest_open_connection_no_loop_ssl(self, open_connection_fut):
+        try:
+            reader, writer = self.loop.run_until_complete(open_connection_fut)
+        finally:
+            asyncio.set_event_loop(None)
+        writer.write(b'GET / HTTP/1.0\r\n\r\n')
+        f = reader.read()
+        data = self.loop.run_until_complete(f)
+        self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+        writer.close()
 
     @unittest.skipIf(ssl is None, 'No ssl module')
     def test_open_connection_no_loop_ssl(self):
         with test_utils.run_test_server(use_ssl=True) as httpd:
-            try:
-                asyncio.set_event_loop(self.loop)
-                f = asyncio.open_connection(*httpd.address,
-                                            ssl=test_utils.dummy_ssl_context())
-                reader, writer = self.loop.run_until_complete(f)
-            finally:
-                asyncio.set_event_loop(None)
-            writer.write(b'GET / HTTP/1.0\r\n\r\n')
-            f = reader.read()
-            data = self.loop.run_until_complete(f)
-            self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+            conn_fut = asyncio.open_connection(
+                *httpd.address,
+                ssl=test_utils.dummy_ssl_context(),
+                loop=self.loop)
 
-            writer.close()
+            self._basetest_open_connection_no_loop_ssl(conn_fut)
+
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_open_unix_connection_no_loop_ssl(self):
+        with test_utils.run_test_unix_server(use_ssl=True) as httpd:
+            conn_fut = asyncio.open_unix_connection(
+                httpd.address,
+                ssl=test_utils.dummy_ssl_context(),
+                server_hostname='',
+                loop=self.loop)
+
+            self._basetest_open_connection_no_loop_ssl(conn_fut)
+
+    def _basetest_open_connection_error(self, open_connection_fut):
+        reader, writer = self.loop.run_until_complete(open_connection_fut)
+        writer._protocol.connection_lost(ZeroDivisionError())
+        f = reader.read()
+        with self.assertRaises(ZeroDivisionError):
+            self.loop.run_until_complete(f)
+        writer.close()
+        test_utils.run_briefly(self.loop)
 
     def test_open_connection_error(self):
         with test_utils.run_test_server() as httpd:
-            f = asyncio.open_connection(*httpd.address, loop=self.loop)
-            reader, writer = self.loop.run_until_complete(f)
-            writer._protocol.connection_lost(ZeroDivisionError())
-            f = reader.read()
-            with self.assertRaises(ZeroDivisionError):
-                self.loop.run_until_complete(f)
+            conn_fut = asyncio.open_connection(*httpd.address,
+                                               loop=self.loop)
+            self._basetest_open_connection_error(conn_fut)
 
-            writer.close()
-            test_utils.run_briefly(self.loop)
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_open_unix_connection_error(self):
+        with test_utils.run_test_unix_server() as httpd:
+            conn_fut = asyncio.open_unix_connection(httpd.address,
+                                                    loop=self.loop)
+            self._basetest_open_connection_error(conn_fut)
 
     def test_feed_empty_data(self):
         stream = asyncio.StreamReader(loop=self.loop)
@@ -415,10 +454,13 @@ class StreamReaderTests(unittest.TestCase):
                 client_writer.write(data)
 
             def start(self):
+                sock = socket.socket()
+                sock.bind(('127.0.0.1', 0))
                 self.server = self.loop.run_until_complete(
                     asyncio.start_server(self.handle_client,
-                                         '127.0.0.1', 12345,
+                                         sock=sock,
                                          loop=self.loop))
+                return sock.getsockname()
 
             def handle_client_callback(self, client_reader, client_writer):
                 task = asyncio.Task(client_reader.readline(), loop=self.loop)
@@ -429,10 +471,15 @@ class StreamReaderTests(unittest.TestCase):
                 task.add_done_callback(done)
 
             def start_callback(self):
+                sock = socket.socket()
+                sock.bind(('127.0.0.1', 0))
+                addr = sock.getsockname()
+                sock.close()
                 self.server = self.loop.run_until_complete(
                     asyncio.start_server(self.handle_client_callback,
-                                         '127.0.0.1', 12345,
+                                         host=addr[0], port=addr[1],
                                          loop=self.loop))
+                return addr
 
             def stop(self):
                 if self.server is not None:
@@ -441,9 +488,9 @@ class StreamReaderTests(unittest.TestCase):
                     self.server = None
 
         @asyncio.coroutine
-        def client():
+        def client(addr):
             reader, writer = yield from asyncio.open_connection(
-                '127.0.0.1', 12345, loop=self.loop)
+                *addr, loop=self.loop)
             # send a line
             writer.write(b"hello world!\n")
             # read it back
@@ -453,20 +500,90 @@ class StreamReaderTests(unittest.TestCase):
 
         # test the server variant with a coroutine as client handler
         server = MyServer(self.loop)
-        server.start()
-        msg = self.loop.run_until_complete(asyncio.Task(client(),
+        addr = server.start()
+        msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                         loop=self.loop))
         server.stop()
         self.assertEqual(msg, b"hello world!\n")
 
         # test the server variant with a callback as client handler
         server = MyServer(self.loop)
-        server.start_callback()
-        msg = self.loop.run_until_complete(asyncio.Task(client(),
+        addr = server.start_callback()
+        msg = self.loop.run_until_complete(asyncio.Task(client(addr),
                                                         loop=self.loop))
         server.stop()
         self.assertEqual(msg, b"hello world!\n")
 
+    @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets')
+    def test_start_unix_server(self):
+
+        class MyServer:
+
+            def __init__(self, loop, path):
+                self.server = None
+                self.loop = loop
+                self.path = path
+
+            @asyncio.coroutine
+            def handle_client(self, client_reader, client_writer):
+                data = yield from client_reader.readline()
+                client_writer.write(data)
+
+            def start(self):
+                self.server = self.loop.run_until_complete(
+                    asyncio.start_unix_server(self.handle_client,
+                                              path=self.path,
+                                              loop=self.loop))
+
+            def handle_client_callback(self, client_reader, client_writer):
+                task = asyncio.Task(client_reader.readline(), loop=self.loop)
+
+                def done(task):
+                    client_writer.write(task.result())
+
+                task.add_done_callback(done)
+
+            def start_callback(self):
+                self.server = self.loop.run_until_complete(
+                    asyncio.start_unix_server(self.handle_client_callback,
+                                              path=self.path,
+                                              loop=self.loop))
+
+            def stop(self):
+                if self.server is not None:
+                    self.server.close()
+                    self.loop.run_until_complete(self.server.wait_closed())
+                    self.server = None
+
+        @asyncio.coroutine
+        def client(path):
+            reader, writer = yield from asyncio.open_unix_connection(
+                path, loop=self.loop)
+            # send a line
+            writer.write(b"hello world!\n")
+            # read it back
+            msgback = yield from reader.readline()
+            writer.close()
+            return msgback
+
+        # test the server variant with a coroutine as client handler
+        with test_utils.unix_socket_path() as path:
+            server = MyServer(self.loop, path)
+            server.start()
+            msg = self.loop.run_until_complete(asyncio.Task(client(path),
+                                                            loop=self.loop))
+            server.stop()
+            self.assertEqual(msg, b"hello world!\n")
+
+        # test the server variant with a callback as client handler
+        with test_utils.unix_socket_path() as path:
+            server = MyServer(self.loop, path)
+            server.start_callback()
+            msg = self.loop.run_until_complete(asyncio.Task(client(path),
+                                                            loop=self.loop))
+            server.stop()
+            self.assertEqual(msg, b"hello world!\n")
+
 
 if __name__ == '__main__':
     unittest.main()
index 9461ec8b87ded040f5143839525ba391e220023a..2fa1db454b8859423d778017864bada5802ae683 100644 (file)
@@ -7,8 +7,10 @@ import io
 import os
 import pprint
 import signal
+import socket
 import stat
 import sys
+import tempfile
 import threading
 import unittest
 import unittest.mock
@@ -24,7 +26,7 @@ from asyncio import unix_events
 
 
 @unittest.skipUnless(signal, 'Signals are not supported')
-class SelectorEventLoopTests(unittest.TestCase):
+class SelectorEventLoopSignalTests(unittest.TestCase):
 
     def setUp(self):
         self.loop = asyncio.SelectorEventLoop()
@@ -200,6 +202,84 @@ class SelectorEventLoopTests(unittest.TestCase):
         m_signal.set_wakeup_fd.assert_called_once_with(-1)
 
 
+@unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
+                     'UNIX Sockets are not supported')
+class SelectorEventLoopUnixSocketTests(unittest.TestCase):
+
+    def setUp(self):
+        self.loop = asyncio.SelectorEventLoop()
+        asyncio.set_event_loop(None)
+
+    def tearDown(self):
+        self.loop.close()
+
+    def test_create_unix_server_existing_path_sock(self):
+        with test_utils.unix_socket_path() as path:
+            sock = socket.socket(socket.AF_UNIX)
+            sock.bind(path)
+
+            coro = self.loop.create_unix_server(lambda: None, path)
+            with self.assertRaisesRegexp(OSError,
+                                         'Address.*is already in use'):
+                self.loop.run_until_complete(coro)
+
+    def test_create_unix_server_existing_path_nonsock(self):
+        with tempfile.NamedTemporaryFile() as file:
+            coro = self.loop.create_unix_server(lambda: None, file.name)
+            with self.assertRaisesRegexp(OSError,
+                                         'Address.*is already in use'):
+                self.loop.run_until_complete(coro)
+
+    def test_create_unix_server_ssl_bool(self):
+        coro = self.loop.create_unix_server(lambda: None, path='spam',
+                                            ssl=True)
+        with self.assertRaisesRegex(TypeError,
+                                    'ssl argument must be an SSLContext'):
+            self.loop.run_until_complete(coro)
+
+    def test_create_unix_server_nopath_nosock(self):
+        coro = self.loop.create_unix_server(lambda: None, path=None)
+        with self.assertRaisesRegex(ValueError,
+                                    'path was not specified, and no sock'):
+            self.loop.run_until_complete(coro)
+
+    def test_create_unix_server_path_inetsock(self):
+        coro = self.loop.create_unix_server(lambda: None, path=None,
+                                            sock=socket.socket())
+        with self.assertRaisesRegex(ValueError,
+                                    'A UNIX Domain Socket was expected'):
+            self.loop.run_until_complete(coro)
+
+    def test_create_unix_connection_path_sock(self):
+        coro = self.loop.create_unix_connection(
+            lambda: None, '/dev/null', sock=object())
+        with self.assertRaisesRegex(ValueError, 'path and sock can not be'):
+            self.loop.run_until_complete(coro)
+
+    def test_create_unix_connection_nopath_nosock(self):
+        coro = self.loop.create_unix_connection(
+            lambda: None, None)
+        with self.assertRaisesRegex(ValueError,
+                                    'no path and sock were specified'):
+            self.loop.run_until_complete(coro)
+
+    def test_create_unix_connection_nossl_serverhost(self):
+        coro = self.loop.create_unix_connection(
+            lambda: None, '/dev/null', server_hostname='spam')
+        with self.assertRaisesRegex(ValueError,
+                                    'server_hostname is only meaningful'):
+            self.loop.run_until_complete(coro)
+
+    def test_create_unix_connection_ssl_noserverhost(self):
+        coro = self.loop.create_unix_connection(
+            lambda: None, '/dev/null', ssl=True)
+
+        with self.assertRaisesRegexp(
+            ValueError, 'you have to pass server_hostname when using ssl'):
+
+            self.loop.run_until_complete(coro)
+
+
 class UnixReadPipeTransportTests(unittest.TestCase):
 
     def setUp(self):