]> granicus.if.org Git - python/commitdiff
Issue #23799: Added test.support.start_threads() for running and cleaning up
authorSerhiy Storchaka <storchaka@gmail.com>
Wed, 1 Apr 2015 10:01:14 +0000 (13:01 +0300)
committerSerhiy Storchaka <storchaka@gmail.com>
Wed, 1 Apr 2015 10:01:14 +0000 (13:01 +0300)
multiple threads.

Lib/test/support/__init__.py
Lib/test/test_bz2.py
Lib/test/test_capi.py
Lib/test/test_gc.py
Lib/test/test_io.py
Lib/test/test_threaded_import.py
Lib/test/test_threadedtempfile.py
Lib/test/test_threading_local.py
Misc/NEWS

index 10c48b4b7ef60fed6fe9c952b71921bbc99b2c5a..75fff2157ed0930e1ab9a0102de9ae1e3697471d 100644 (file)
@@ -6,6 +6,7 @@ if __name__ != 'test.support':
 import collections.abc
 import contextlib
 import errno
+import faulthandler
 import fnmatch
 import functools
 import gc
@@ -96,7 +97,7 @@ __all__ = [
     # logging
     "TestHandler",
     # threads
-    "threading_setup", "threading_cleanup",
+    "threading_setup", "threading_cleanup", "reap_threads", "start_threads",
     # miscellaneous
     "check_warnings", "EnvironmentVarGuard", "run_with_locale", "swap_item",
     "swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict",
@@ -1940,6 +1941,42 @@ def reap_children():
             except:
                 break
 
+@contextlib.contextmanager
+def start_threads(threads, unlock=None):
+    threads = list(threads)
+    started = []
+    try:
+        try:
+            for t in threads:
+                t.start()
+                started.append(t)
+        except:
+            if verbose:
+                print("Can't start %d threads, only %d threads started" %
+                      (len(threads), len(started)))
+            raise
+        yield
+    finally:
+        try:
+            if unlock:
+                unlock()
+            endtime = starttime = time.time()
+            for timeout in range(1, 16):
+                endtime += 60
+                for t in started:
+                    t.join(max(endtime - time.time(), 0.01))
+                started = [t for t in started if t.isAlive()]
+                if not started:
+                    break
+                if verbose:
+                    print('Unable to join %d threads during a period of '
+                          '%d minutes' % (len(started), timeout))
+        finally:
+            started = [t for t in started if t.isAlive()]
+            if started:
+                faulthandler.dump_traceback(sys.stdout)
+                raise AssertionError('Unable to join %d threads' % len(started))
+
 @contextlib.contextmanager
 def swap_attr(obj, attr, new_val):
     """Temporary swap out an attribute with a new object.
index 1535e8e6695bca3345f9323ab48bd024d93e2575..beef27593001b5f9472935cbce1bbe371813f8a8 100644 (file)
@@ -493,10 +493,8 @@ class BZ2FileTest(BaseTest):
                 for i in range(5):
                     f.write(data)
             threads = [threading.Thread(target=comp) for i in range(nthreads)]
-            for t in threads:
-                t.start()
-            for t in threads:
-                t.join()
+            with support.start_threads(threads):
+                pass
 
     def testWithoutThreading(self):
         module = support.import_fresh_module("bz2", blocked=("threading",))
index ba7c38db27cf908f9b3820743d5129889a8d72e2..36c62376b1577cff0d81113fb707c383ffaad2ae 100644 (file)
@@ -202,15 +202,11 @@ class TestPendingCalls(unittest.TestCase):
         context.lock = threading.Lock()
         context.event = threading.Event()
 
-        for i in range(context.nThreads):
-            t = threading.Thread(target=self.pendingcalls_thread, args = (context,))
-            t.start()
-            threads.append(t)
-
-        self.pendingcalls_wait(context.l, n, context)
-
-        for t in threads:
-            t.join()
+        threads = [threading.Thread(target=self.pendingcalls_thread,
+                                    args=(context,))
+                   for i in range(context.nThreads)]
+        with support.start_threads(threads):
+            self.pendingcalls_wait(context.l, n, context)
 
     def pendingcalls_thread(self, context):
         try:
index c025512790cd451500698abdfba20a736cb20ec1..2ac1d4bb64259d6a15d49723454e0140e2e9a7c1 100644 (file)
@@ -1,6 +1,6 @@
 import unittest
 from test.support import (verbose, refcount_test, run_unittest,
-                            strip_python_stderr, cpython_only)
+                            strip_python_stderr, cpython_only, start_threads)
 from test.script_helper import assert_python_ok, make_script, temp_dir
 
 import sys
@@ -397,19 +397,13 @@ class GCTests(unittest.TestCase):
         old_switchinterval = sys.getswitchinterval()
         sys.setswitchinterval(1e-5)
         try:
-            exit = False
+            exit = []
             threads = []
             for i in range(N_THREADS):
                 t = threading.Thread(target=run_thread)
                 threads.append(t)
-            try:
-                for t in threads:
-                    t.start()
-            finally:
+            with start_threads(threads, lambda: exit.append(1)):
                 time.sleep(1.0)
-                exit = True
-            for t in threads:
-                t.join()
         finally:
             sys.setswitchinterval(old_switchinterval)
         gc.collect()
index ec19562006f615e069668113f5febf43c3eba6b5..95277d9c4768ba6ba4e46aed721edc8e4fae8731 100644 (file)
@@ -1070,11 +1070,8 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests):
                         errors.append(e)
                         raise
                 threads = [threading.Thread(target=f) for x in range(20)]
-                for t in threads:
-                    t.start()
-                time.sleep(0.02) # yield
-                for t in threads:
-                    t.join()
+                with support.start_threads(threads):
+                    time.sleep(0.02) # yield
                 self.assertFalse(errors,
                     "the following exceptions were caught: %r" % errors)
                 s = b''.join(results)
@@ -1393,11 +1390,8 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
                         errors.append(e)
                         raise
                 threads = [threading.Thread(target=f) for x in range(20)]
-                for t in threads:
-                    t.start()
-                time.sleep(0.02) # yield
-                for t in threads:
-                    t.join()
+                with support.start_threads(threads):
+                    time.sleep(0.02) # yield
                 self.assertFalse(errors,
                     "the following exceptions were caught: %r" % errors)
                 bufio.close()
@@ -2691,14 +2685,10 @@ class TextIOWrapperTest(unittest.TestCase):
                 text = "Thread%03d\n" % n
                 event.wait()
                 f.write(text)
-            threads = [threading.Thread(target=lambda n=x: run(n))
+            threads = [threading.Thread(target=run, args=(x,))
                        for x in range(20)]
-            for t in threads:
-                t.start()
-            time.sleep(0.02)
-            event.set()
-            for t in threads:
-                t.join()
+            with support.start_threads(threads, event.set):
+                time.sleep(0.02)
         with self.open(support.TESTFN) as f:
             content = f.read()
             for n in range(20):
@@ -3402,11 +3392,11 @@ class SignalsTest(unittest.TestCase):
             # handlers, which in this case will invoke alarm_interrupt().
             signal.alarm(1)
             try:
-                self.assertRaises(ZeroDivisionError,
-                            wio.write, item * (support.PIPE_MAX_SIZE // len(item) + 1))
+                with self.assertRaises(ZeroDivisionError):
+                    wio.write(item * (support.PIPE_MAX_SIZE // len(item) + 1))
             finally:
                 signal.alarm(0)
-            t.join()
+                t.join()
             # We got one byte, get another one and check that it isn't a
             # repeat of the first one.
             read_results.append(os.read(r, 1))
index 192fa08a2659cfe1734b5ebec2480acd0c7a2000..4be615a5a809434286bb9306ab312587e2ef129c 100644 (file)
@@ -14,7 +14,7 @@ import shutil
 import unittest
 from test.support import (
     verbose, import_module, run_unittest, TESTFN, reap_threads,
-    forget, unlink, rmtree)
+    forget, unlink, rmtree, start_threads)
 threading = import_module('threading')
 
 def task(N, done, done_tasks, errors):
@@ -115,10 +115,10 @@ class ThreadedImportTests(unittest.TestCase):
             errors = []
             done_tasks = []
             done.clear()
-            for i in range(N):
-                t = threading.Thread(target=task,
-                                     args=(N, done, done_tasks, errors,))
-                t.start()
+            with start_threads(threading.Thread(target=task,
+                                                args=(N, done, done_tasks, errors,))
+                               for i in range(N)):
+                pass
             self.assertTrue(done.wait(60))
             self.assertFalse(errors)
             if verbose:
index 2dfd3a08db876046b54adbeef762bdad7b1cfcd6..b7420360eae1d4b0ae65227b938d6c035892bb23 100644 (file)
@@ -18,7 +18,7 @@ FILES_PER_THREAD = 50
 
 import tempfile
 
-from test.support import threading_setup, threading_cleanup, run_unittest, import_module
+from test.support import start_threads, import_module
 threading = import_module('threading')
 import unittest
 import io
@@ -46,33 +46,17 @@ class TempFileGreedy(threading.Thread):
 
 class ThreadedTempFileTest(unittest.TestCase):
     def test_main(self):
-        threads = []
-        thread_info = threading_setup()
-
-        for i in range(NUM_THREADS):
-            t = TempFileGreedy()
-            threads.append(t)
-            t.start()
-
-        startEvent.set()
-
-        ok = 0
-        errors = []
-        for t in threads:
-            t.join()
-            ok += t.ok_count
-            if t.error_count:
-                errors.append(str(t.name) + str(t.errors.getvalue()))
-
-        threading_cleanup(*thread_info)
+        threads = [TempFileGreedy() for i in range(NUM_THREADS)]
+        with start_threads(threads, startEvent.set):
+            pass
+        ok = sum(t.ok_count for t in threads)
+        errors = [str(t.name) + str(t.errors.getvalue())
+                  for t in threads if t.error_count]
 
         msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok,
             '\n'.join(errors))
         self.assertEqual(errors, [], msg)
         self.assertEqual(ok, NUM_THREADS * FILES_PER_THREAD)
 
-def test_main():
-    run_unittest(ThreadedTempFileTest)
-
 if __name__ == "__main__":
-    test_main()
+    unittest.main()
index c886a25d8ab6d64d2ac52c452cd7cfbc02d2d814..c7f394cf60b3bc47a0507a71e45ac59744351ef0 100644 (file)
@@ -64,14 +64,9 @@ class BaseLocalTest:
             # Simply check that the variable is correctly set
             self.assertEqual(local.x, i)
 
-        threads= []
-        for i in range(10):
-            t = threading.Thread(target=f, args=(i,))
-            t.start()
-            threads.append(t)
-
-        for t in threads:
-            t.join()
+        with support.start_threads(threading.Thread(target=f, args=(i,))
+                                   for i in range(10)):
+            pass
 
     def test_derived_cycle_dealloc(self):
         # http://bugs.python.org/issue6990
index cd72813dad3fc2396ae7e826f0d82d6d1f61e502..33fc0ee3b0267e380b2d5ef908140ef267cda986 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -137,6 +137,9 @@ Library
 Tests
 -----
 
+- Issue #23799: Added test.support.start_threads() for running and
+  cleaning up multiple threads.
+
 - Issue #22390: test.regrtest now emits a warning if temporary files or
   directories are left after running a test.