]> granicus.if.org Git - python/commitdiff
issue 8777
authorKristján Valur Jónsson <kristjan@ccpgames.com>
Thu, 28 Oct 2010 09:43:10 +0000 (09:43 +0000)
committerKristján Valur Jónsson <kristjan@ccpgames.com>
Thu, 28 Oct 2010 09:43:10 +0000 (09:43 +0000)
Add threading.Barrier

Doc/library/threading.rst
Lib/test/lock_tests.py
Lib/test/test_threading.py
Lib/threading.py

index 7c8f709d7a483d8b1f6a1169fa0edbccc9a205c9..64aa14b2d192d8ad5bcd7a015938f26eb4f72db9 100644 (file)
@@ -768,6 +768,110 @@ For example::
       only work if the timer is still in its waiting stage.
 
 
+.. _barrier-objects
+
+Barrier Objects
+---------------
+
+This class provides a simple synchronization primitive for use by a fixed
+number of threads that need to wait for each other.  Each of the threads
+tries to pass the barrier by calling the :meth:`wait` method and will block
+until all of the threads have made the call.
+At this points, the threads are released simultanously.
+
+The barrier can be reused any number of times for the same number of threads.
+
+As an example, here is a simple way to synchronize a client and server thread::
+
+   b = Barrier(2, timeout=5)
+   server():
+       start_server()
+       b.wait()
+       while True:
+           connection = accept_connection()
+           process_server_connection(connection)
+
+   client():
+       b.wait()
+       while True:
+          connection = make_connection()
+          process_client_connection(connection)
+
+.. class:: Barrier(parties, action=None, timeout=None)
+
+   Create a barrier object for *parties* number of threads. An *action*,
+   when provided, is a callable to be called by one of the threads when
+   they are released.  *timeout* is the default timeout value if none
+   is specified for the :meth:`wait` method.
+
+   .. method:: wait(timeout=None)
+
+      Pass the barrier.  When all the threads party to the barrier have called
+      this function, they are all released simultaneously.  If a *timeout*
+      is provided, is is used in preference to any that was supplied to the
+      class constructor.
+
+      The return value is an integer in the range 0 to *parties*-1, different
+      for each thrad.  This can be used to select a thread to do some special
+      housekeeping, eg:
+
+         i = barrier.wait()
+         if i == 0:
+            # Only one thread needs to print this
+            print("passed the barrier")
+
+      If an *action* was
+      provided to the constructor, one of the threads will have called it
+      prior to being released.  Should this call raise an error, the barrier
+      is put into the broken state.
+
+      If the call times out, the barrier is put into the broken state.
+
+      This method may raise a :class:`BrokenBarrierError` exception if the
+      barrier is broken or reset while a thread is waiting
+
+   .. method:: reset()
+
+      Return the barrier to the default, empty state.  Any threads waiting on
+      it will receive the :class:`BrokenBarrierError` exception.
+
+      Note that using this function may can require some external
+      synchronization if there are other threads whose state is unknown.
+      If a barrier is broken it may be better to just leave it and create a
+      new one.
+
+   .. method:: abort()
+
+      Put the barrier into a broken state.  This causes any active or future
+      calls to :meth:`wait` to fail with the :class:`BrokenBarrierError`.
+      Use this for example if one of the needs to abort, to avoid deadlocking
+      the application.
+
+      It may be preferable to simply create the barrier with a sensible
+      *timeout* value to automatically guard against one of the threads
+      going awry.
+
+   .. attribute:: parties
+
+      The number of threads required to pass the barrier.
+
+   .. attribute:: n_waiting
+
+      The number of threads currently waiting in the barrier.
+
+   .. attribute:: broken
+
+      A boolean that is ``True`` if the barrier is in the broken state.
+
+   .. versionadded:: 3.2
+
+.. class:: BrokenBarrierError(RuntimeError)
+
+   The exception raised when the :class:`Barrier` object is reset or broken.
+
+   .. versionadded:: 3.2
+
+
 .. _with-locks:
 
 Using locks, conditions, and semaphores in the :keyword:`with` statement
index 1ff6af0a7abdea4b5411eb3acbf1e84402fb1cfd..f256a807fad0e6e98d68e2284165303f886a4f54 100644 (file)
@@ -597,3 +597,193 @@ class BoundedSemaphoreTests(BaseSemaphoreTests):
         sem.acquire()
         sem.release()
         self.assertRaises(ValueError, sem.release)
+
+
+class BarrierTests(BaseTestCase):
+    """
+    Tests for Barrier objects.
+    """
+    N = 5
+
+    def setUp(self):
+        self.barrier = self.barriertype(self.N, timeout=0.1)
+    def tearDown(self):
+        self.barrier.abort()
+
+    def run_threads(self, f):
+        b = Bunch(f, self.N-1)
+        f()
+        b.wait_for_finished()
+
+    def multipass(self, results, n):
+        m = self.barrier.parties
+        self.assertEqual(m, self.N)
+        for i in range(n):
+            results[0].append(True)
+            self.assertEqual(len(results[1]), i * m)
+            self.barrier.wait()
+            results[1].append(True)
+            self.assertEqual(len(results[0]), (i + 1) * m)
+            self.barrier.wait()
+        self.assertEqual(self.barrier.n_waiting, 0)
+        self.assertFalse(self.barrier.broken)
+
+    def test_barrier(self, passes=1):
+        """
+        Test that a barrier is passed in lockstep
+        """
+        results = [[],[]]
+        def f():
+            self.multipass(results, passes)
+        self.run_threads(f)
+
+    def test_barrier_10(self):
+        """
+        Test that a barrier works for 10 consecutive runs
+        """
+        return self.test_barrier(10)
+
+    def test_wait_return(self):
+        """
+        test the return value from barrier.wait
+        """
+        results = []
+        def f():
+            r = self.barrier.wait()
+            results.append(r)
+
+        self.run_threads(f)
+        self.assertEqual(sum(results), sum(range(self.N)))
+
+    def test_action(self):
+        """
+        Test the 'action' callback
+        """
+        results = []
+        def action():
+            results.append(True)
+        barrier = self.barriertype(self.N, action)
+        def f():
+            barrier.wait()
+            self.assertEqual(len(results), 1)
+
+        self.run_threads(f)
+
+    def test_abort(self):
+        """
+        Test that an abort will put the barrier in a broken state
+        """
+        results1 = []
+        results2 = []
+        def f():
+            try:
+                i = self.barrier.wait()
+                if i == self.N//2:
+                    raise RuntimeError
+                self.barrier.wait()
+                results1.append(True)
+            except threading.BrokenBarrierError:
+                results2.append(True)
+            except RuntimeError:
+                self.barrier.abort()
+                pass
+
+        self.run_threads(f)
+        self.assertEqual(len(results1), 0)
+        self.assertEqual(len(results2), self.N-1)
+        self.assertTrue(self.barrier.broken)
+
+    def test_reset(self):
+        """
+        Test that a 'reset' on a barrier frees the waiting threads
+        """
+        results1 = []
+        results2 = []
+        results3 = []
+        def f():
+            i = self.barrier.wait()
+            if i == self.N//2:
+                # Wait until the other threads are all in the barrier.
+                while self.barrier.n_waiting < self.N-1:
+                    time.sleep(0.001)
+                self.barrier.reset()
+            else:
+                try:
+                    self.barrier.wait()
+                    results1.append(True)
+                except threading.BrokenBarrierError:
+                    results2.append(True)
+            # Now, pass the barrier again
+            self.barrier.wait()
+            results3.append(True)
+
+        self.run_threads(f)
+        self.assertEqual(len(results1), 0)
+        self.assertEqual(len(results2), self.N-1)
+        self.assertEqual(len(results3), self.N)
+
+
+    def test_abort_and_reset(self):
+        """
+        Test that a barrier can be reset after being broken.
+        """
+        results1 = []
+        results2 = []
+        results3 = []
+        barrier2 = self.barriertype(self.N)
+        def f():
+            try:
+                i = self.barrier.wait()
+                if i == self.N//2:
+                    raise RuntimeError
+                self.barrier.wait()
+                results1.append(True)
+            except threading.BrokenBarrierError:
+                results2.append(True)
+            except RuntimeError:
+                self.barrier.abort()
+                pass
+            # Synchronize and reset the barrier.  Must synchronize first so
+            # that everyone has left it when we reset, and after so that no
+            # one enters it before the reset.
+            if barrier2.wait() == self.N//2:
+                self.barrier.reset()
+            barrier2.wait()
+            self.barrier.wait()
+            results3.append(True)
+
+        self.run_threads(f)
+        self.assertEqual(len(results1), 0)
+        self.assertEqual(len(results2), self.N-1)
+        self.assertEqual(len(results3), self.N)
+
+    def test_timeout(self):
+        """
+        Test wait(timeout)
+        """
+        def f():
+            i = self.barrier.wait()
+            if i == self.N // 2:
+                # One thread is late!
+                time.sleep(0.1)
+            # Default timeout is 0.1, so this is shorter.
+            self.assertRaises(threading.BrokenBarrierError,
+                              self.barrier.wait, 0.05)
+        self.run_threads(f)
+
+    def test_default_timeout(self):
+        """
+        Test the barrier's default timeout
+        """
+        def f():
+            i = self.barrier.wait()
+            if i == self.N // 2:
+                # One thread is later than the default timeout of 0.1s.
+                time.sleep(0.15)
+            self.assertRaises(threading.BrokenBarrierError, self.barrier.wait)
+        self.run_threads(f)
+
+    def test_single_thread(self):
+        b = self.barriertype(1)
+        b.wait()
+        b.wait()
index 62ad4af7ec53f2314b1c70f374c200f73eb0a15a..a453ccc490973e5ad4893997f8ece6513a9abee5 100644 (file)
@@ -555,6 +555,8 @@ class SemaphoreTests(lock_tests.SemaphoreTests):
 class BoundedSemaphoreTests(lock_tests.BoundedSemaphoreTests):
     semtype = staticmethod(threading.BoundedSemaphore)
 
+class BarrierTests(lock_tests.BarrierTests):
+    barriertype = staticmethod(threading.Barrier)
 
 def test_main():
     test.support.run_unittest(LockTests, PyRLockTests, CRLockTests, EventTests,
@@ -563,6 +565,7 @@ def test_main():
                               ThreadTests,
                               ThreadJoinOnShutdown,
                               ThreadingExceptionTests,
+                              BarrierTests
                               )
 
 if __name__ == "__main__":
index 238a5c4508ff7482a3ac726e8bbedda53aaa6878..41956edce7aeb784b12ca0605bbf6d00e2cf2fb3 100644 (file)
@@ -392,6 +392,178 @@ class _Event(_Verbose):
         finally:
             self._cond.release()
 
+
+# A barrier class.  Inspired in part by the pthread_barrier_* api and
+# the CyclicBarrier class from Java.  See
+# http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and
+# http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/
+#        CyclicBarrier.html
+# for information.
+# We maintain two main states, 'filling' and 'draining' enabling the barrier
+# to be cyclic.  Threads are not allowed into it until it has fully drained
+# since the previous cycle.  In addition, a 'resetting' state exists which is
+# similar to 'draining' except that threads leave with a BrokenBarrierError,
+# and a 'broken' state in which all threads get get the exception.
+class Barrier(_Verbose):
+    """
+    Barrier.  Useful for synchronizing a fixed number of threads
+    at known synchronization points.  Threads block on 'wait()' and are
+    simultaneously once they have all made that call.
+    """
+    def __init__(self, parties, action=None, timeout=None, verbose=None):
+        """
+        Create a barrier, initialised to 'parties' threads.
+        'action' is a callable which, when supplied, will be called
+        by one of the threads after they have all entered the
+        barrier and just prior to releasing them all.
+        If a 'timeout' is provided, it is uses as the default for
+        all subsequent 'wait()' calls.
+        """
+        _Verbose.__init__(self, verbose)
+        self._cond = Condition(Lock())
+        self._action = action
+        self._timeout = timeout
+        self._parties = parties
+        self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken
+        self._count = 0
+
+    def wait(self, timeout=None):
+        """
+        Wait for the barrier.  When the specified number of threads have
+        started waiting, they are all simultaneously awoken. If an 'action'
+        was provided for the barrier, one of the threads will have executed
+        that callback prior to returning.
+        Returns an individual index number from 0 to 'parties-1'.
+        """
+        if timeout is None:
+            timeout = self._timeout
+        with self._cond:
+            self._enter() # Block while the barrier drains.
+            index = self._count
+            self._count += 1
+            try:
+                if index + 1 == self._parties:
+                    # We release the barrier
+                    self._release()
+                else:
+                    # We wait until someone releases us
+                    self._wait(timeout)
+                return index
+            finally:
+                self._count -= 1
+                # Wake up any threads waiting for barrier to drain.
+                self._exit()
+
+    # Block until the barrier is ready for us, or raise an exception
+    # if it is broken.
+    def _enter(self):
+        while self._state in (-1, 1):
+            # It is draining or resetting, wait until done
+            self._cond.wait()
+        #see if the barrier is in a broken state
+        if self._state < 0:
+            raise BrokenBarrierError
+        assert self._state == 0
+
+    # Optionally run the 'action' and release the threads waiting
+    # in the barrier.
+    def _release(self):
+        try:
+            if self._action:
+                self._action()
+            # enter draining state
+            self._state = 1
+            self._cond.notify_all()
+        except:
+            #an exception during the _action handler.  Break and reraise
+            self._break()
+            raise
+
+    # Wait in the barrier until we are relased.  Raise an exception
+    # if the barrier is reset or broken.
+    def _wait(self, timeout):
+        while self._state == 0:
+            if self._cond.wait(timeout) is False:
+                #timed out.  Break the barrier
+                self._break()
+                raise BrokenBarrierError
+            if self._state < 0:
+                raise BrokenBarrierError
+        assert self._state == 1
+
+    # If we are the last thread to exit the barrier, signal any threads
+    # waiting for the barrier to drain.
+    def _exit(self):
+        if self._count == 0:
+            if self._state in (-1, 1):
+                #resetting or draining
+                self._state = 0
+                self._cond.notify_all()
+
+    def reset(self):
+        """
+        Reset the barrier to the initial state.
+        Any threads currently waiting will get the BrokenBarrier exception
+        raised.
+        """
+        with self._cond:
+            if self._count > 0:
+                if self._state == 0:
+                    #reset the barrier, waking up threads
+                    self._state = -1
+                elif self._state == -2:
+                    #was broken, set it to reset state
+                    #which clears when the last thread exits
+                    self._state = -1
+            else:
+                self._state = 0
+            self._cond.notify_all()
+
+    def abort(self):
+        """
+        Place the barrier into a 'broken' state.
+        Useful in case of error.  Any currently waiting threads and
+        threads attempting to 'wait()' will have BrokenBarrierError
+        raised.
+        """
+        with self._cond:
+            self._break()
+
+    def _break(self):
+        # An internal error was detected.  The barrier is set to
+        # a broken state all parties awakened.
+        self._state = -2
+        self._cond.notify_all()
+
+    @property
+    def parties(self):
+        """
+        Return the number of threads required to trip the barrier.
+        """
+        return self._parties
+
+    @property
+    def n_waiting(self):
+        """
+        Return the number of threads that are currently waiting at the barrier.
+        """
+        # We don't need synchronization here since this is an ephemeral result
+        # anyway.  It returns the correct value in the steady state.
+        if self._state == 0:
+            return self._count
+        return 0
+
+    @property
+    def broken(self):
+        """
+        Return True if the barrier is in a broken state
+        """
+        return self._state == -2
+
+#exception raised by the Barrier class
+class BrokenBarrierError(RuntimeError): pass
+
+
 # Helper to generate new thread names
 _counter = 0
 def _newname(template="Thread-%d"):