]> granicus.if.org Git - python/commitdiff
asyncio: Fix from Anthony Baire for CPython issue 19566 (replaces earlier fix).
authorGuido van Rossum <guido@dropbox.com>
Wed, 13 Nov 2013 23:50:08 +0000 (15:50 -0800)
committerGuido van Rossum <guido@dropbox.com>
Wed, 13 Nov 2013 23:50:08 +0000 (15:50 -0800)
Lib/asyncio/unix_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_unix_events.py

index 9fced334531f2475d23ea3411418a385b4908948..b611efd17d14171be8a0e512b423b653ed39a4dd 100644 (file)
@@ -440,10 +440,13 @@ class AbstractChildWatcher:
 
         raise NotImplementedError()
 
-    def set_loop(self, loop):
-        """Reattach the watcher to another event loop.
+    def attach_loop(self, loop):
+        """Attach the watcher to an event loop.
 
-        Note: loop may be None
+        If the watcher was previously attached to an event loop, then it is
+        first detached before attaching to the new loop.
+
+        Note: loop may be None.
         """
         raise NotImplementedError()
 
@@ -467,15 +470,11 @@ class AbstractChildWatcher:
 
 class BaseChildWatcher(AbstractChildWatcher):
 
-    def __init__(self, loop):
+    def __init__(self):
         self._loop = None
-        self._callbacks = {}
-
-        self.set_loop(loop)
 
     def close(self):
-        self.set_loop(None)
-        self._callbacks.clear()
+        self.attach_loop(None)
 
     def _do_waitpid(self, expected_pid):
         raise NotImplementedError()
@@ -483,7 +482,7 @@ class BaseChildWatcher(AbstractChildWatcher):
     def _do_waitpid_all(self):
         raise NotImplementedError()
 
-    def set_loop(self, loop):
+    def attach_loop(self, loop):
         assert loop is None or isinstance(loop, events.AbstractEventLoop)
 
         if self._loop is not None:
@@ -497,13 +496,6 @@ class BaseChildWatcher(AbstractChildWatcher):
             # during the switch.
             self._do_waitpid_all()
 
-    def remove_child_handler(self, pid):
-        try:
-            del self._callbacks[pid]
-            return True
-        except KeyError:
-            return False
-
     def _sig_chld(self):
         try:
             self._do_waitpid_all()
@@ -535,6 +527,14 @@ class SafeChildWatcher(BaseChildWatcher):
     big number of children (O(n) each time SIGCHLD is raised)
     """
 
+    def __init__(self):
+        super().__init__()
+        self._callbacks = {}
+
+    def close(self):
+        self._callbacks.clear()
+        super().close()
+
     def __enter__(self):
         return self
 
@@ -547,6 +547,13 @@ class SafeChildWatcher(BaseChildWatcher):
         # Prevent a race condition in case the child is already terminated.
         self._do_waitpid(pid)
 
+    def remove_child_handler(self, pid):
+        try:
+            del self._callbacks[pid]
+            return True
+        except KeyError:
+            return False
+
     def _do_waitpid_all(self):
 
         for pid in list(self._callbacks):
@@ -592,17 +599,17 @@ class FastChildWatcher(BaseChildWatcher):
     There is no noticeable overhead when handling a big number of children
     (O(1) each time a child terminates).
     """
-    def __init__(self, loop):
+    def __init__(self):
+        super().__init__()
+        self._callbacks = {}
         self._lock = threading.Lock()
         self._zombies = {}
         self._forks = 0
-        # Call base class constructor last because it calls back into
-        # the subclass (set_loop() calls _do_waitpid()).
-        super().__init__(loop)
 
     def close(self):
-        super().close()
+        self._callbacks.clear()
         self._zombies.clear()
+        super().close()
 
     def __enter__(self):
         with self._lock:
@@ -643,6 +650,13 @@ class FastChildWatcher(BaseChildWatcher):
         else:
             callback(pid, returncode, *args)
 
+    def remove_child_handler(self, pid):
+        try:
+            del self._callbacks[pid]
+            return True
+        except KeyError:
+            return False
+
     def _do_waitpid_all(self):
         # Because of signal coalescing, we must keep calling waitpid() as
         # long as we're able to reap a child.
@@ -687,25 +701,24 @@ class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy):
     def _init_watcher(self):
         with events._lock:
             if self._watcher is None:  # pragma: no branch
+                self._watcher = SafeChildWatcher()
                 if isinstance(threading.current_thread(),
                               threading._MainThread):
-                    self._watcher = SafeChildWatcher(self._local._loop)
-                else:
-                    self._watcher = SafeChildWatcher(None)
+                    self._watcher.attach_loop(self._local._loop)
 
     def set_event_loop(self, loop):
         """Set the event loop.
 
         As a side effect, if a child watcher was set before, then calling
-        .set_event_loop() from the main thread will call .set_loop(loop) on the
-        child watcher.
+        .set_event_loop() from the main thread will call .attach_loop(loop) on
+        the child watcher.
         """
 
         super().set_event_loop(loop)
 
         if self._watcher is not None and \
             isinstance(threading.current_thread(), threading._MainThread):
-            self._watcher.set_loop(loop)
+            self._watcher.attach_loop(loop)
 
     def get_child_watcher(self):
         """Get the child watcher
index 00bd4085c1381d35edb67fbb012bde39c79c1404..7b9839ce995427a5b16947785da529455201a380 100644 (file)
@@ -1311,7 +1311,9 @@ else:
     class UnixEventLoopTestsMixin(EventLoopTestsMixin):
         def setUp(self):
             super().setUp()
-            events.set_child_watcher(unix_events.SafeChildWatcher(self.loop))
+            watcher = unix_events.SafeChildWatcher()
+            watcher.attach_loop(self.loop)
+            events.set_child_watcher(watcher)
 
         def tearDown(self):
             events.set_child_watcher(None)
index 42eba8d6932643fa2b89f879b4b2d01108eff1e2..ea1c08cf753ad22f9bf6dc84fcac8e71e0b5ec49 100644 (file)
@@ -687,7 +687,7 @@ class AbstractChildWatcherTests(unittest.TestCase):
         self.assertRaises(
             NotImplementedError, watcher.remove_child_handler, f)
         self.assertRaises(
-            NotImplementedError, watcher.set_loop, f)
+            NotImplementedError, watcher.attach_loop, f)
         self.assertRaises(
             NotImplementedError, watcher.close)
         self.assertRaises(
@@ -700,7 +700,7 @@ class BaseChildWatcherTests(unittest.TestCase):
 
     def test_not_implemented(self):
         f = unittest.mock.Mock()
-        watcher = unix_events.BaseChildWatcher(None)
+        watcher = unix_events.BaseChildWatcher()
         self.assertRaises(
             NotImplementedError, watcher._do_waitpid, f)
 
@@ -720,10 +720,13 @@ class ChildWatcherTestsMixin:
 
         with unittest.mock.patch.object(
                 self.loop, "add_signal_handler") as self.m_add_signal_handler:
-            self.watcher = self.create_watcher(self.loop)
+            self.watcher = self.create_watcher()
+            self.watcher.attach_loop(self.loop)
 
-    def tearDown(self):
-        ChildWatcherTestsMixin.instance = None
+        def cleanup():
+            ChildWatcherTestsMixin.instance = None
+
+        self.addCleanup(cleanup)
 
     def waitpid(pid, flags):
         self = ChildWatcherTestsMixin.instance
@@ -1334,7 +1337,7 @@ class ChildWatcherTestsMixin:
                 self.loop,
                 "add_signal_handler") as m_new_add_signal_handler:
 
-            self.watcher.set_loop(self.loop)
+            self.watcher.attach_loop(self.loop)
 
             m_old_remove_signal_handler.assert_called_once_with(
                 signal.SIGCHLD)
@@ -1375,7 +1378,7 @@ class ChildWatcherTestsMixin:
         with unittest.mock.patch.object(
                 old_loop, "remove_signal_handler") as m_remove_signal_handler:
 
-            self.watcher.set_loop(None)
+            self.watcher.attach_loop(None)
 
             m_remove_signal_handler.assert_called_once_with(
                 signal.SIGCHLD)
@@ -1395,7 +1398,7 @@ class ChildWatcherTestsMixin:
         with unittest.mock.patch.object(
                 self.loop, "add_signal_handler") as m_add_signal_handler:
 
-            self.watcher.set_loop(self.loop)
+            self.watcher.attach_loop(self.loop)
 
             m_add_signal_handler.assert_called_once_with(
                 signal.SIGCHLD, self.watcher._sig_chld)
@@ -1457,13 +1460,13 @@ class ChildWatcherTestsMixin:
 
 
 class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
-    def create_watcher(self, loop):
-        return unix_events.SafeChildWatcher(loop)
+    def create_watcher(self):
+        return unix_events.SafeChildWatcher()
 
 
 class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
-    def create_watcher(self, loop):
-        return unix_events.FastChildWatcher(loop)
+    def create_watcher(self):
+        return unix_events.FastChildWatcher()
 
 
 class PolicyTests(unittest.TestCase):
@@ -1485,7 +1488,7 @@ class PolicyTests(unittest.TestCase):
 
     def test_get_child_watcher_after_set(self):
         policy = self.create_policy()
-        watcher = unix_events.FastChildWatcher(None)
+        watcher = unix_events.FastChildWatcher()
 
         policy.set_child_watcher(watcher)
         self.assertIs(policy._watcher, watcher)