asyncio: Synchronize with Tulip
authorVictor Stinner <victor.stinner@gmail.com>
Wed, 5 Mar 2014 23:52:53 +0000 (00:52 +0100)
committerVictor Stinner <victor.stinner@gmail.com>
Wed, 5 Mar 2014 23:52:53 +0000 (00:52 +0100)
* Issue #159: Fix windows_utils.socketpair()

  - Use "127.0.0.1" (IPv4) or "::1" (IPv6) host instead of "localhost", because
    "localhost" may be a different IP address
  - Reject also invalid arguments: only AF_INET/AF_INET6 with SOCK_STREAM (and
    proto=0) are supported

* Reject add/remove reader/writer when event loop is closed.
* Fix ResourceWarning warnings

Lib/asyncio/selector_events.py
Lib/asyncio/windows_utils.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_windows_utils.py

index 70d8a9588fd1cc1cb69d92fbf90c6ecdd253a543..367c5fbe3f5fe7ffc7eab73e39be3c34ce82434d 100644 (file)
@@ -136,6 +136,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
     def add_reader(self, fd, callback, *args):
         """Add a reader callback."""
+        if self._selector is None:
+            raise RuntimeError('Event loop is closed')
         handle = events.Handle(callback, args, self)
         try:
             key = self._selector.get_key(fd)
@@ -151,6 +153,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
     def remove_reader(self, fd):
         """Remove a reader callback."""
+        if self._selector is None:
+            return False
         try:
             key = self._selector.get_key(fd)
         except KeyError:
@@ -171,6 +175,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
     def add_writer(self, fd, callback, *args):
         """Add a writer callback.."""
+        if self._selector is None:
+            raise RuntimeError('Event loop is closed')
         handle = events.Handle(callback, args, self)
         try:
             key = self._selector.get_key(fd)
@@ -186,6 +192,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
 
     def remove_writer(self, fd):
         """Remove a writer callback."""
+        if self._selector is None:
+            return False
         try:
             key = self._selector.get_key(fd)
         except KeyError:
index aa1c0648067352b64e9623e207dbc5054c069bae..2a196cc76b42e75f3d7da1dd3ef5ebcba3f0b8d0 100644 (file)
@@ -36,12 +36,25 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
 
     Origin: https://gist.github.com/4325783, by Geert Jansen.  Public domain.
     """
+    if family == socket.AF_INET:
+        host = '127.0.0.1'
+    elif family == socket.AF_INET6:
+        host = '::1'
+    else:
+        raise ValueError("Ony AF_INET and AF_INET6 socket address families "
+                         "are supported")
+    if type != socket.SOCK_STREAM:
+        raise ValueError("Only SOCK_STREAM socket type is supported")
+    if proto != 0:
+        raise ValueError("Only protocol zero is supported")
+
     # We create a connected TCP socket. Note the trick with setblocking(0)
     # that prevents us from having to create a thread.
     lsock = socket.socket(family, type, proto)
-    lsock.bind(('localhost', 0))
+    lsock.bind((host, 0))
     lsock.listen(1)
-    addr, port = lsock.getsockname()
+    # On IPv6, ignore flow_info and scope_id
+    addr, port = lsock.getsockname()[:2]
     csock = socket.socket(family, type, proto)
     csock.setblocking(False)
     try:
index f01d1f38a8f5807a7f07cf3bda9d83dff2a8bad4..fd7022ff865236a3e301e95f49cb461ba5a960ea 100644 (file)
@@ -1326,6 +1326,30 @@ class EventLoopTestsMixin:
                     self.assertIn('address must be resolved',
                                   str(cm.exception))
 
+    def test_remove_fds_after_closing(self):
+        loop = self.create_event_loop()
+        callback = lambda: None
+        r, w = test_utils.socketpair()
+        self.addCleanup(r.close)
+        self.addCleanup(w.close)
+        loop.add_reader(r, callback)
+        loop.add_writer(w, callback)
+        loop.close()
+        self.assertFalse(loop.remove_reader(r))
+        self.assertFalse(loop.remove_writer(w))
+
+    def test_add_fds_after_closing(self):
+        loop = self.create_event_loop()
+        callback = lambda: None
+        r, w = test_utils.socketpair()
+        self.addCleanup(r.close)
+        self.addCleanup(w.close)
+        loop.close()
+        with self.assertRaises(RuntimeError):
+            loop.add_reader(r, callback)
+        with self.assertRaises(RuntimeError):
+            loop.add_writer(w, callback)
+
 
 class SubprocessTestsMixin:
 
@@ -1632,6 +1656,9 @@ if sys.platform == 'win32':
         def test_create_datagram_endpoint(self):
             raise unittest.SkipTest(
                 "IocpEventLoop does not have create_datagram_endpoint()")
+
+        def test_remove_fds_after_closing(self):
+            raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
 else:
     from asyncio import selectors
 
index 7616c73e45ca4ef0865285f0554641614646edc6..9daf4340a4a28dc7571376e747fd12d4c8c7fc54 100644 (file)
@@ -1,8 +1,10 @@
 """Tests for window_utils"""
 
+import socket
 import sys
 import test.support
 import unittest
+from test.support import IPV6_ENABLED
 from unittest import mock
 
 if sys.platform != 'win32':
@@ -16,23 +18,40 @@ from asyncio import _overlapped
 
 class WinsocketpairTests(unittest.TestCase):
 
-    def test_winsocketpair(self):
-        ssock, csock = windows_utils.socketpair()
-
+    def check_winsocketpair(self, ssock, csock):
         csock.send(b'xxx')
         self.assertEqual(b'xxx', ssock.recv(1024))
-
         csock.close()
         ssock.close()
 
+    def test_winsocketpair(self):
+        ssock, csock = windows_utils.socketpair()
+        self.check_winsocketpair(ssock, csock)
+
+    @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled')
+    def test_winsocketpair_ipv6(self):
+        ssock, csock = windows_utils.socketpair(family=socket.AF_INET6)
+        self.check_winsocketpair(ssock, csock)
+
     @mock.patch('asyncio.windows_utils.socket')
     def test_winsocketpair_exc(self, m_socket):
+        m_socket.AF_INET = socket.AF_INET
+        m_socket.SOCK_STREAM = socket.SOCK_STREAM
         m_socket.socket.return_value.getsockname.return_value = ('', 12345)
         m_socket.socket.return_value.accept.return_value = object(), object()
         m_socket.socket.return_value.connect.side_effect = OSError()
 
         self.assertRaises(OSError, windows_utils.socketpair)
 
+    def test_winsocketpair_invalid_args(self):
+        self.assertRaises(ValueError,
+                          windows_utils.socketpair, family=socket.AF_UNSPEC)
+        self.assertRaises(ValueError,
+                          windows_utils.socketpair, type=socket.SOCK_DGRAM)
+        self.assertRaises(ValueError,
+                          windows_utils.socketpair, proto=1)
+
+
 
 class PipeTests(unittest.TestCase):