]> granicus.if.org Git - python/commitdiff
bpo-31234: Add test.support.wait_threads_exit() (#3578)
authorVictor Stinner <victor.stinner@gmail.com>
Thu, 14 Sep 2017 20:07:24 +0000 (13:07 -0700)
committerGitHub <noreply@github.com>
Thu, 14 Sep 2017 20:07:24 +0000 (13:07 -0700)
Use _thread.count() to wait until threads exit. The new context
manager prevents the "dangling thread" warning.

Lib/test/lock_tests.py
Lib/test/support/__init__.py
Lib/test/test_socket.py
Lib/test/test_thread.py
Lib/test/test_threading.py
Lib/test/test_threadsignals.py

index e8fa4f999a7ce8f9c64eaba2eb25124189d12545..a1ea96d42ce8106bd86ef5e112224cb90f7c8a86 100644 (file)
@@ -31,6 +31,9 @@ class Bunch(object):
         self.started = []
         self.finished = []
         self._can_exit = not wait_before_exit
+        self.wait_thread = support.wait_threads_exit()
+        self.wait_thread.__enter__()
+
         def task():
             tid = threading.get_ident()
             self.started.append(tid)
@@ -40,6 +43,7 @@ class Bunch(object):
                 self.finished.append(tid)
                 while not self._can_exit:
                     _wait()
+
         try:
             for i in range(n):
                 start_new_thread(task, ())
@@ -54,13 +58,8 @@ class Bunch(object):
     def wait_for_finished(self):
         while len(self.finished) < self.n:
             _wait()
-        # Wait a little bit longer to prevent the "threading_cleanup()
-        # failed to cleanup X threads" warning. The loop above is a weak
-        # synchronization. At the C level, t_bootstrap() can still be
-        # running and so _thread.count() still accounts the "almost dead"
-        # thead.
-        for _ in range(self.n):
-            _wait()
+        # Wait for threads exit
+        self.wait_thread.__exit__(None, None, None)
 
     def do_finish(self):
         self._can_exit = True
@@ -227,20 +226,23 @@ class LockTests(BaseLockTests):
         # Lock needs to be released before re-acquiring.
         lock = self.locktype()
         phase = []
+
         def f():
             lock.acquire()
             phase.append(None)
             lock.acquire()
             phase.append(None)
-        start_new_thread(f, ())
-        while len(phase) == 0:
-            _wait()
-        _wait()
-        self.assertEqual(len(phase), 1)
-        lock.release()
-        while len(phase) == 1:
+
+        with support.wait_threads_exit():
+            start_new_thread(f, ())
+            while len(phase) == 0:
+                _wait()
             _wait()
-        self.assertEqual(len(phase), 2)
+            self.assertEqual(len(phase), 1)
+            lock.release()
+            while len(phase) == 1:
+                _wait()
+            self.assertEqual(len(phase), 2)
 
     def test_different_thread(self):
         # Lock can be released from a different thread.
index df235050ae2c7cef4b56e0c0f68d8310f178afba..63f7a910710b47ad499d00f0ff2626b68d4abc98 100644 (file)
@@ -2072,6 +2072,41 @@ def reap_threads(func):
     return decorator
 
 
+@contextlib.contextmanager
+def wait_threads_exit(timeout=60.0):
+    """
+    bpo-31234: Context manager to wait until all threads created in the with
+    statement exit.
+
+    Use _thread.count() to check if threads exited. Indirectly, wait until
+    threads exit the internal t_bootstrap() C function of the _thread module.
+
+    threading_setup() and threading_cleanup() are designed to emit a warning
+    if a test leaves running threads in the background. This context manager
+    is designed to cleanup threads started by the _thread.start_new_thread()
+    which doesn't allow to wait for thread exit, whereas thread.Thread has a
+    join() method.
+    """
+    old_count = _thread._count()
+    try:
+        yield
+    finally:
+        start_time = time.monotonic()
+        deadline = start_time + timeout
+        while True:
+            count = _thread._count()
+            if count <= old_count:
+                break
+            if time.monotonic() > deadline:
+                dt = time.monotonic() - start_time
+                msg = (f"wait_threads() failed to cleanup {count - old_count} "
+                       f"threads after {dt:.1f} seconds "
+                       f"(count: {count}, old count: {old_count})")
+                raise AssertionError(msg)
+            time.sleep(0.010)
+            gc_collect()
+
+
 def reap_children():
     """Use this function at the end of test_main() whenever sub-processes
     are started.  This will help ensure that no extra children (zombies)
index 05d8761241e8a4733da9ad9441de7efbf6c3bf6e..01502c805cc896aa910f62b76b2638666075d341 100644 (file)
@@ -271,6 +271,9 @@ class ThreadableTest:
         self.server_ready.set()
 
     def _setUp(self):
+        self.wait_threads = support.wait_threads_exit()
+        self.wait_threads.__enter__()
+
         self.server_ready = threading.Event()
         self.client_ready = threading.Event()
         self.done = threading.Event()
@@ -297,6 +300,7 @@ class ThreadableTest:
     def _tearDown(self):
         self.__tearDown()
         self.done.wait()
+        self.wait_threads.__exit__(None, None, None)
 
         if self.queue.qsize():
             exc = self.queue.get()
index 2dd1593eaa0cb1399df46fc8ea3005d16729f75c..52f6c798b8785470f17661987be887c5875f8847 100644 (file)
@@ -59,12 +59,13 @@ class ThreadRunningTests(BasicThreadTest):
                 self.done_mutex.release()
 
     def test_starting_threads(self):
-        # Basic test for thread creation.
-        for i in range(NUMTASKS):
-            self.newtask()
-        verbose_print("waiting for tasks to complete...")
-        self.done_mutex.acquire()
-        verbose_print("all tasks done")
+        with support.wait_threads_exit():
+            # Basic test for thread creation.
+            for i in range(NUMTASKS):
+                self.newtask()
+            verbose_print("waiting for tasks to complete...")
+            self.done_mutex.acquire()
+            verbose_print("all tasks done")
 
     def test_stack_size(self):
         # Various stack size tests.
@@ -94,12 +95,13 @@ class ThreadRunningTests(BasicThreadTest):
             verbose_print("trying stack_size = (%d)" % tss)
             self.next_ident = 0
             self.created = 0
-            for i in range(NUMTASKS):
-                self.newtask()
+            with support.wait_threads_exit():
+                for i in range(NUMTASKS):
+                    self.newtask()
 
-            verbose_print("waiting for all tasks to complete")
-            self.done_mutex.acquire()
-            verbose_print("all tasks done")
+                verbose_print("waiting for all tasks to complete")
+                self.done_mutex.acquire()
+                verbose_print("all tasks done")
 
         thread.stack_size(0)
 
@@ -109,25 +111,28 @@ class ThreadRunningTests(BasicThreadTest):
         mut = thread.allocate_lock()
         mut.acquire()
         started = []
+
         def task():
             started.append(None)
             mut.acquire()
             mut.release()
-        thread.start_new_thread(task, ())
-        while not started:
-            time.sleep(POLL_SLEEP)
-        self.assertEqual(thread._count(), orig + 1)
-        # Allow the task to finish.
-        mut.release()
-        # The only reliable way to be sure that the thread ended from the
-        # interpreter's point of view is to wait for the function object to be
-        # destroyed.
-        done = []
-        wr = weakref.ref(task, lambda _: done.append(None))
-        del task
-        while not done:
-            time.sleep(POLL_SLEEP)
-        self.assertEqual(thread._count(), orig)
+
+        with support.wait_threads_exit():
+            thread.start_new_thread(task, ())
+            while not started:
+                time.sleep(POLL_SLEEP)
+            self.assertEqual(thread._count(), orig + 1)
+            # Allow the task to finish.
+            mut.release()
+            # The only reliable way to be sure that the thread ended from the
+            # interpreter's point of view is to wait for the function object to be
+            # destroyed.
+            done = []
+            wr = weakref.ref(task, lambda _: done.append(None))
+            del task
+            while not done:
+                time.sleep(POLL_SLEEP)
+            self.assertEqual(thread._count(), orig)
 
     def test_save_exception_state_on_error(self):
         # See issue #14474
@@ -140,16 +145,14 @@ class ThreadRunningTests(BasicThreadTest):
             except ValueError:
                 pass
             real_write(self, *args)
-        c = thread._count()
         started = thread.allocate_lock()
         with support.captured_output("stderr") as stderr:
             real_write = stderr.write
             stderr.write = mywrite
             started.acquire()
-            thread.start_new_thread(task, ())
-            started.acquire()
-            while thread._count() > c:
-                time.sleep(POLL_SLEEP)
+            with support.wait_threads_exit():
+                thread.start_new_thread(task, ())
+                started.acquire()
         self.assertIn("Traceback", stderr.getvalue())
 
 
@@ -181,13 +184,14 @@ class Barrier:
 class BarrierTest(BasicThreadTest):
 
     def test_barrier(self):
-        self.bar = Barrier(NUMTASKS)
-        self.running = NUMTASKS
-        for i in range(NUMTASKS):
-            thread.start_new_thread(self.task2, (i,))
-        verbose_print("waiting for tasks to end")
-        self.done_mutex.acquire()
-        verbose_print("tasks done")
+        with support.wait_threads_exit():
+            self.bar = Barrier(NUMTASKS)
+            self.running = NUMTASKS
+            for i in range(NUMTASKS):
+                thread.start_new_thread(self.task2, (i,))
+            verbose_print("waiting for tasks to end")
+            self.done_mutex.acquire()
+            verbose_print("tasks done")
 
     def task2(self, ident):
         for i in range(NUMTRIPS):
@@ -225,11 +229,10 @@ class TestForkInThread(unittest.TestCase):
     @unittest.skipUnless(hasattr(os, 'fork'), 'need os.fork')
     @support.reap_threads
     def test_forkinthread(self):
-        running = True
         status = "not set"
 
         def thread1():
-            nonlocal running, status
+            nonlocal status
 
             # fork in a thread
             pid = os.fork()
@@ -244,13 +247,11 @@ class TestForkInThread(unittest.TestCase):
                 # parent
                 os.close(self.write_fd)
                 pid, status = os.waitpid(pid, 0)
-                running = False
 
-        thread.start_new_thread(thread1, ())
-        self.assertEqual(os.read(self.read_fd, 2), b"OK",
-                         "Unable to fork() in thread")
-        while running:
-            time.sleep(POLL_SLEEP)
+        with support.wait_threads_exit():
+            thread.start_new_thread(thread1, ())
+            self.assertEqual(os.read(self.read_fd, 2), b"OK",
+                             "Unable to fork() in thread")
         self.assertEqual(status, 0)
 
     def tearDown(self):
index ab383c23332b0992b63f9dae4ebe4b215f04915f..af6796cd2e6242d6ffcf8a0e7729a4dfa5f871f7 100644 (file)
@@ -125,9 +125,10 @@ class ThreadTests(BaseTestCase):
             done.set()
         done = threading.Event()
         ident = []
-        _thread.start_new_thread(f, ())
-        done.wait()
-        self.assertIsNotNone(ident[0])
+        with support.wait_threads_exit():
+            tid = _thread.start_new_thread(f, ())
+            done.wait()
+            self.assertEqual(ident[0], tid)
         # Kill the "immortal" _DummyThread
         del threading._active[ident[0]]
 
@@ -165,9 +166,10 @@ class ThreadTests(BaseTestCase):
 
         mutex = threading.Lock()
         mutex.acquire()
-        tid = _thread.start_new_thread(f, (mutex,))
-        # Wait for the thread to finish.
-        mutex.acquire()
+        with support.wait_threads_exit():
+            tid = _thread.start_new_thread(f, (mutex,))
+            # Wait for the thread to finish.
+            mutex.acquire()
         self.assertIn(tid, threading._active)
         self.assertIsInstance(threading._active[tid], threading._DummyThread)
         #Issue 29376
index 9d927423756e666d6f5d2b7095a42accf0c976cf..f93dd772555bbffbd12c3350d7ea74c450ae8769 100644 (file)
@@ -4,8 +4,8 @@ import unittest
 import signal
 import os
 import sys
-from test.support import run_unittest, import_module
-thread = import_module('_thread')
+from test import support
+thread = support.import_module('_thread')
 import time
 
 if (sys.platform[:3] == 'win'):
@@ -39,13 +39,15 @@ def send_signals():
 class ThreadSignals(unittest.TestCase):
 
     def test_signals(self):
-        # Test signal handling semantics of threads.
-        # We spawn a thread, have the thread send two signals, and
-        # wait for it to finish. Check that we got both signals
-        # and that they were run by the main thread.
-        signalled_all.acquire()
-        self.spawnSignallingThread()
-        signalled_all.acquire()
+        with support.wait_threads_exit():
+            # Test signal handling semantics of threads.
+            # We spawn a thread, have the thread send two signals, and
+            # wait for it to finish. Check that we got both signals
+            # and that they were run by the main thread.
+            signalled_all.acquire()
+            self.spawnSignallingThread()
+            signalled_all.acquire()
+
         # the signals that we asked the kernel to send
         # will come back, but we don't know when.
         # (it might even be after the thread exits
@@ -115,17 +117,19 @@ class ThreadSignals(unittest.TestCase):
             # thread.
             def other_thread():
                 rlock.acquire()
-            thread.start_new_thread(other_thread, ())
-            # Wait until we can't acquire it without blocking...
-            while rlock.acquire(blocking=False):
-                rlock.release()
-                time.sleep(0.01)
-            signal.alarm(1)
-            t1 = time.time()
-            self.assertRaises(KeyboardInterrupt, rlock.acquire, timeout=5)
-            dt = time.time() - t1
-            # See rationale above in test_lock_acquire_interruption
-            self.assertLess(dt, 3.0)
+
+            with support.wait_threads_exit():
+                thread.start_new_thread(other_thread, ())
+                # Wait until we can't acquire it without blocking...
+                while rlock.acquire(blocking=False):
+                    rlock.release()
+                    time.sleep(0.01)
+                signal.alarm(1)
+                t1 = time.time()
+                self.assertRaises(KeyboardInterrupt, rlock.acquire, timeout=5)
+                dt = time.time() - t1
+                # See rationale above in test_lock_acquire_interruption
+                self.assertLess(dt, 3.0)
         finally:
             signal.signal(signal.SIGALRM, oldalrm)
 
@@ -133,6 +137,7 @@ class ThreadSignals(unittest.TestCase):
         self.sig_recvd = False
         def my_handler(signal, frame):
             self.sig_recvd = True
+
         old_handler = signal.signal(signal.SIGUSR1, my_handler)
         try:
             def other_thread():
@@ -147,14 +152,16 @@ class ThreadSignals(unittest.TestCase):
                 # the lock acquisition.  Then we'll let it run.
                 time.sleep(0.5)
                 lock.release()
-            thread.start_new_thread(other_thread, ())
-            # Wait until we can't acquire it without blocking...
-            while lock.acquire(blocking=False):
-                lock.release()
-                time.sleep(0.01)
-            result = lock.acquire()  # Block while we receive a signal.
-            self.assertTrue(self.sig_recvd)
-            self.assertTrue(result)
+
+            with support.wait_threads_exit():
+                thread.start_new_thread(other_thread, ())
+                # Wait until we can't acquire it without blocking...
+                while lock.acquire(blocking=False):
+                    lock.release()
+                    time.sleep(0.01)
+                result = lock.acquire()  # Block while we receive a signal.
+                self.assertTrue(self.sig_recvd)
+                self.assertTrue(result)
         finally:
             signal.signal(signal.SIGUSR1, old_handler)
 
@@ -193,19 +200,20 @@ class ThreadSignals(unittest.TestCase):
                     os.kill(process_pid, signal.SIGUSR1)
                 done.release()
 
-            # Send the signals from the non-main thread, since the main thread
-            # is the only one that can process signals.
-            thread.start_new_thread(send_signals, ())
-            timed_acquire()
-            # Wait for thread to finish
-            done.acquire()
-            # This allows for some timing and scheduling imprecision
-            self.assertLess(self.end - self.start, 2.0)
-            self.assertGreater(self.end - self.start, 0.3)
-            # If the signal is received several times before PyErr_CheckSignals()
-            # is called, the handler will get called less than 40 times. Just
-            # check it's been called at least once.
-            self.assertGreater(self.sigs_recvd, 0)
+            with support.wait_threads_exit():
+                # Send the signals from the non-main thread, since the main thread
+                # is the only one that can process signals.
+                thread.start_new_thread(send_signals, ())
+                timed_acquire()
+                # Wait for thread to finish
+                done.acquire()
+                # This allows for some timing and scheduling imprecision
+                self.assertLess(self.end - self.start, 2.0)
+                self.assertGreater(self.end - self.start, 0.3)
+                # If the signal is received several times before PyErr_CheckSignals()
+                # is called, the handler will get called less than 40 times. Just
+                # check it's been called at least once.
+                self.assertGreater(self.sigs_recvd, 0)
         finally:
             signal.signal(signal.SIGUSR1, old_handler)
 
@@ -219,7 +227,7 @@ def test_main():
 
     oldsigs = registerSignals(handle_signals, handle_signals, handle_signals)
     try:
-        run_unittest(ThreadSignals)
+        support.run_unittest(ThreadSignals)
     finally:
         registerSignals(*oldsigs)