]> granicus.if.org Git - python/commitdiff
Apply patch for 874900: threading module can deadlock after fork
authorJesse Noller <jnoller@gmail.com>
Wed, 16 Jul 2008 20:03:47 +0000 (20:03 +0000)
committerJesse Noller <jnoller@gmail.com>
Wed, 16 Jul 2008 20:03:47 +0000 (20:03 +0000)
Lib/test/test_threading.py
Lib/threading.py
Python/ceval.c

index db6b0d75044dbbc9b7bdcecf3e6d69aac63aacac..8c1ecce2dda04ee102f9e589f2bca156a02236a7 100644 (file)
@@ -323,6 +323,82 @@ class ThreadTests(unittest.TestCase):
                                sys.getrefcount(weak_raising_cyclic_object())))
 
 
+class ThreadJoinOnShutdown(unittest.TestCase):
+
+    def _run_and_join(self, script):
+        script = """if 1:
+            import sys, os, time, threading
+
+            # a thread, which waits for the main program to terminate
+            def joiningfunc(mainthread):
+                mainthread.join()
+                print 'end of thread'
+        \n""" + script
+
+        import subprocess
+        p = subprocess.Popen([sys.executable, "-c", script], stdout=subprocess.PIPE)
+        rc = p.wait()
+        self.assertEqual(p.stdout.read(), "end of main\nend of thread\n")
+        self.failIf(rc == 2, "interpreter was blocked")
+        self.failUnless(rc == 0, "Unexpected error")
+
+    def test_1_join_on_shutdown(self):
+        # The usual case: on exit, wait for a non-daemon thread
+        script = """if 1:
+            import os
+            t = threading.Thread(target=joiningfunc,
+                                 args=(threading.current_thread(),))
+            t.start()
+            time.sleep(0.1)
+            print 'end of main'
+            """
+        self._run_and_join(script)
+
+
+    def test_2_join_in_forked_process(self):
+        # Like the test above, but from a forked interpreter
+        import os
+        if not hasattr(os, 'fork'):
+            return
+        script = """if 1:
+            childpid = os.fork()
+            if childpid != 0:
+                os.waitpid(childpid, 0)
+                sys.exit(0)
+
+            t = threading.Thread(target=joiningfunc,
+                                 args=(threading.current_thread(),))
+            t.start()
+            print 'end of main'
+            """
+        self._run_and_join(script)
+
+    def test_3_join_in_forked_from_thread(self):
+        # Like the test above, but fork() was called from a worker thread
+        # In the forked process, the main Thread object must be marked as stopped.
+        import os
+        if not hasattr(os, 'fork'):
+            return
+        script = """if 1:
+            main_thread = threading.current_thread()
+            def worker():
+                childpid = os.fork()
+                if childpid != 0:
+                    os.waitpid(childpid, 0)
+                    sys.exit(0)
+
+                t = threading.Thread(target=joiningfunc,
+                                     args=(main_thread,))
+                print 'end of main'
+                t.start()
+                t.join() # Should not block: main_thread is already stopped
+
+            w = threading.Thread(target=worker)
+            w.start()
+            """
+        self._run_and_join(script)
+
+
 class ThreadingExceptionTests(unittest.TestCase):
     # A RuntimeError should be raised if Thread.start() is called
     # multiple times.
@@ -363,7 +439,9 @@ class ThreadingExceptionTests(unittest.TestCase):
 
 def test_main():
     test.test_support.run_unittest(ThreadTests,
-                                   ThreadingExceptionTests)
+                                   ThreadJoinOnShutdown,
+                                   ThreadingExceptionTests,
+                                   )
 
 if __name__ == "__main__":
     test_main()
index bfca44c065dfca47b9d62172d1e30c679aedd3a6..8a1de42bea6eb208ee13ed80a8e9bf6a2e1270bf 100644 (file)
@@ -825,6 +825,37 @@ except ImportError:
     from _threading_local import local
 
 
+def _after_fork():
+    # This function is called by Python/ceval.c:PyEval_ReInitThreads which
+    # is called from PyOS_AfterFork.  Here we cleanup threading module state
+    # that should not exist after a fork.
+
+    # Reset _active_limbo_lock, in case we forked while the lock was held
+    # by another (non-forked) thread.  http://bugs.python.org/issue874900
+    global _active_limbo_lock
+    _active_limbo_lock = _allocate_lock()
+
+    # fork() only copied the current thread; clear references to others.
+    new_active = {}
+    current = current_thread()
+    with _active_limbo_lock:
+        for ident, thread in _active.iteritems():
+            if thread is current:
+                # There is only one active thread.
+                new_active[ident] = thread
+            else:
+                # All the others are already stopped.
+                # We don't call _Thread__stop() because it tries to acquire
+                # thread._Thread__block which could also have been held while
+                # we forked.
+                thread._Thread__stopped = True
+
+        _limbo.clear()
+        _active.clear()
+        _active.update(new_active)
+        assert len(_active) == 1
+
+
 # Self-test code
 
 def _test():
index a9e37ae1fcbae0f6f72cb0ca69a56b0156963b8c..f61bcd51b2a68e2a8ee4360b15db38ce3a8040a0 100644 (file)
@@ -274,6 +274,9 @@ PyEval_ReleaseThread(PyThreadState *tstate)
 void
 PyEval_ReInitThreads(void)
 {
+       PyObject *threading, *result;
+       PyThreadState *tstate;
+
        if (!interpreter_lock)
                return;
        /*XXX Can't use PyThread_free_lock here because it does too
@@ -283,6 +286,23 @@ PyEval_ReInitThreads(void)
        interpreter_lock = PyThread_allocate_lock();
        PyThread_acquire_lock(interpreter_lock, 1);
        main_thread = PyThread_get_thread_ident();
+
+       /* Update the threading module with the new state.
+        */
+       tstate = PyThreadState_GET();
+       threading = PyMapping_GetItemString(tstate->interp->modules,
+                                           "threading");
+       if (threading == NULL) {
+               /* threading not imported */
+               PyErr_Clear();
+               return;
+       }
+       result = PyObject_CallMethod(threading, "_after_fork", NULL);
+       if (result == NULL)
+               PyErr_WriteUnraisable(threading);
+       else
+               Py_DECREF(result);
+       Py_DECREF(threading);
 }
 #endif