granicus.if.org Git - python/commitdiff
bpo-28699: fix abnormal behaviour of pools in multiprocessing.pool (GH-882)
author     Xiang Zhang <angwerzx@126.com>
Wed, 29 Mar 2017 04:50:28 +0000 (12:50 +0800)
committer  GitHub <noreply@github.com>
Wed, 29 Mar 2017 04:50:28 +0000 (12:50 +0800)
an exception raised at the very start of an iterable would cause pools to behave
abnormally (swallow the exception or hang)
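
A minimal sketch of the symptom (illustrative names, not part of the patch): on an
interpreter without this fix the iterable's error could be swallowed or the call could
hang, as described above; with the fix it propagates to the caller.

    import multiprocessing

    def failing_iterable():
        raise ValueError('boom')        # fails before producing the first item
        yield 0                         # unreachable; makes this a generator function

    def sqr(x):
        return x * x

    if __name__ == '__main__':
        with multiprocessing.Pool(2) as pool:
            it = pool.imap(sqr, failing_iterable())
            # Pre-fix the error could be lost or the call could block forever;
            # with this commit the ValueError is re-raised here.
            next(it)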

Lib/multiprocessing/pool.py
Lib/test/_test_multiprocessing.py
Misc/NEWS

index ae8cec44796b2115fd51dd4f10329283f6dc3c50..a545f3c1a189613b9d95bf9a921c1c2abd80692c 100644 (file)
@@ -118,7 +118,7 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
         try:
             result = (True, func(*args, **kwds))
         except Exception as e:
-            if wrap_exception:
+            if wrap_exception and func is not _helper_reraises_exception:
                 e = ExceptionWithTraceback(e, e.__traceback__)
             result = (False, e)
         try:
@@ -133,6 +133,10 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None,
         completed += 1
     util.debug('worker exiting after %d tasks' % completed)
 
+def _helper_reraises_exception(ex):
+    'Pickle-able helper function for use by _guarded_task_generation.'
+    raise ex
+
 #
 # Class representing a process pool
 #
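
The guard added above keeps errors that originate in the caller's iterable from being
wrapped as remote exceptions, and the helper is a module-level (hence pickle-able)
function. A small check of those two properties, assuming an interpreter that already
includes this commit:

    import pickle
    from multiprocessing.pool import _helper_reraises_exception

    err = KeyError('missing')
    # Module-level functions pickle by reference, so the task survives the queue.
    func, args = pickle.loads(pickle.dumps((_helper_reraises_exception, (err,))))
    try:
        func(*args)                      # what the worker does: re-raise the error
    except KeyError as exc:
        assert exc.args == ('missing',)
        assert exc.__cause__ is None     # deliberately not an ExceptionWithTraceback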
@@ -277,6 +281,17 @@ class Pool(object):
         return self._map_async(func, iterable, starmapstar, chunksize,
                                callback, error_callback)
 
+    def _guarded_task_generation(self, result_job, func, iterable):
+        '''Provides a generator of tasks for imap and imap_unordered with
+        appropriate handling for iterables which throw exceptions during
+        iteration.'''
+        try:
+            i = -1
+            for i, x in enumerate(iterable):
+                yield (result_job, i, func, (x,), {})
+        except Exception as e:
+            yield (result_job, i+1, _helper_reraises_exception, (e,), {})
+
     def imap(self, func, iterable, chunksize=1):
         '''
         Equivalent of `map()` -- can be MUCH slower than `Pool.map()`.
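
A runnable sketch of what the new generator produces (hypothetical names mirroring the
method above): a failing iterable is turned into ordinary tasks for the items it did
yield, plus one final task that re-raises the error inside a worker.

    def _reraise(ex):                   # stand-in for _helper_reraises_exception
        raise ex

    def guarded_task_generation(result_job, func, iterable):
        try:
            i = -1
            for i, x in enumerate(iterable):
                yield (result_job, i, func, (x,), {})
        except Exception as e:
            yield (result_job, i + 1, _reraise, (e,), {})

    def broken_iterable():
        yield 1
        yield 2
        raise RuntimeError('iterable broke')

    tasks = list(guarded_task_generation(42, abs, broken_iterable()))
    # Two normal tasks, then (42, 2, _reraise, (RuntimeError('iterable broke'),), {})
    print(tasks)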
@@ -285,15 +300,23 @@ class Pool(object):
             raise ValueError("Pool not running")
         if chunksize == 1:
             result = IMapIterator(self._cache)
-            self._taskqueue.put((((result._job, i, func, (x,), {})
-                         for i, x in enumerate(iterable)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job, func, iterable),
+                    result._set_length
+                ))
             return result
         else:
             assert chunksize > 1
             task_batches = Pool._get_tasks(func, iterable, chunksize)
             result = IMapIterator(self._cache)
-            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
-                     for i, x in enumerate(task_batches)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job,
+                                                  mapstar,
+                                                  task_batches),
+                    result._set_length
+                ))
             return (item for chunk in result for item in chunk)
 
     def imap_unordered(self, func, iterable, chunksize=1):
@@ -304,15 +327,23 @@ class Pool(object):
             raise ValueError("Pool not running")
         if chunksize == 1:
             result = IMapUnorderedIterator(self._cache)
-            self._taskqueue.put((((result._job, i, func, (x,), {})
-                         for i, x in enumerate(iterable)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job, func, iterable),
+                    result._set_length
+                ))
             return result
         else:
             assert chunksize > 1
             task_batches = Pool._get_tasks(func, iterable, chunksize)
             result = IMapUnorderedIterator(self._cache)
-            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
-                     for i, x in enumerate(task_batches)), result._set_length))
+            self._taskqueue.put(
+                (
+                    self._guarded_task_generation(result._job,
+                                                  mapstar,
+                                                  task_batches),
+                    result._set_length
+                ))
             return (item for chunk in result for item in chunk)
 
     def apply_async(self, func, args=(), kwds={}, callback=None,
@@ -323,7 +354,7 @@ class Pool(object):
         if self._state != RUN:
             raise ValueError("Pool not running")
         result = ApplyResult(self._cache, callback, error_callback)
-        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
+        self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
         return result
 
     def map_async(self, func, iterable, chunksize=None, callback=None,
@@ -354,8 +385,14 @@ class Pool(object):
         task_batches = Pool._get_tasks(func, iterable, chunksize)
         result = MapResult(self._cache, chunksize, len(iterable), callback,
                            error_callback=error_callback)
-        self._taskqueue.put((((result._job, i, mapper, (x,), {})
-                              for i, x in enumerate(task_batches)), None))
+        self._taskqueue.put(
+            (
+                self._guarded_task_generation(result._job,
+                                              mapper,
+                                              task_batches),
+                None
+            )
+        )
         return result
 
     @staticmethod
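
map() and map_async() convert iterables without __len__ to a list before any of the
above runs, so the case this hunk covers is an object that reports a length but fails
while being iterated (the SpecialIterable test added below exercises exactly that). A
hedged usage sketch, with illustrative names:

    import multiprocessing

    class LyingIterable:
        """Claims one item but raises as soon as it is iterated."""
        def __len__(self):
            return 1
        def __iter__(self):
            return self
        def __next__(self):
            raise RuntimeError('no items after all')

    def sqr(x):
        return x * x

    if __name__ == '__main__':
        with multiprocessing.Pool(2) as pool:
            try:
                pool.map(sqr, LyingIterable(), chunksize=1)
            except RuntimeError as exc:
                # With this commit the error is re-raised here instead of hanging.
                print('propagated:', exc)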
@@ -377,33 +414,27 @@ class Pool(object):
 
         for taskseq, set_length in iter(taskqueue.get, None):
             task = None
-            i = -1
             try:
-                for i, task in enumerate(taskseq):
+                # iterating taskseq cannot fail
+                for task in taskseq:
                     if thread._state:
                         util.debug('task handler found thread._state != RUN')
                         break
                     try:
                         put(task)
                     except Exception as e:
-                        job, ind = task[:2]
+                        job, idx = task[:2]
                         try:
-                            cache[job]._set(ind, (False, e))
+                            cache[job]._set(idx, (False, e))
                         except KeyError:
                             pass
                 else:
                     if set_length:
                         util.debug('doing set_length()')
-                        set_length(i+1)
+                        idx = task[1] if task else -1
+                        set_length(idx + 1)
                     continue
                 break
-            except Exception as ex:
-                job, ind = task[:2] if task else (0, 0)
-                if job in cache:
-                    cache[job]._set(ind + 1, (False, ex))
-                if set_length:
-                    util.debug('doing set_length()')
-                    set_length(i+1)
             finally:
                 task = taskseq = job = None
         else:
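
Because every task sequence now comes from _guarded_task_generation (or is a plain
list), iterating taskseq in the handler can no longer raise, which is why the old
except block above is removed. The expected result count is instead derived from the
index stored in the last task, and apply_async's single task now carries index 0
instead of None, so every task's second field is an integer. A tiny illustration with
made-up task tuples:

    # (job, index, func, args, kwds) -- the last index may belong to a re-raise task
    tasks = [(7, 0, str, (10,), {}),
             (7, 1, str, (20,), {}),
             (7, 2, str, (30,), {})]
    task = None
    for task in tasks:
        pass                            # stand-in for put(task)
    idx = task[1] if task else -1
    assert idx + 1 == 3                 # i.e. set_length(3): three results expected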
index 1d3bb0f8bae78161374b0b6d480f25bfb342d86b..771bbf24265b91204df4fc7dacbd3fccc6c624d0 100644 (file)
@@ -1755,6 +1755,8 @@ class CountedObject(object):
 class SayWhenError(ValueError): pass
 
 def exception_throwing_generator(total, when):
+    if when == -1:
+        raise SayWhenError("Somebody said when")
     for i in range(total):
         if i == when:
             raise SayWhenError("Somebody said when")
@@ -1833,6 +1835,32 @@ class _TestPool(BaseTestCase):
         except multiprocessing.TimeoutError:
             self.fail("pool.map_async with chunksize stalled on null list")
 
+    def test_map_handle_iterable_exception(self):
+        if self.TYPE == 'manager':
+            self.skipTest('test not appropriate for {}'.format(self.TYPE))
+
+        # SayWhenError seen at the very start of the iterable
+        with self.assertRaises(SayWhenError):
+            self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
+        # again, make sure it's reentrant
+        with self.assertRaises(SayWhenError):
+            self.pool.map(sqr, exception_throwing_generator(1, -1), 1)
+
+        with self.assertRaises(SayWhenError):
+            self.pool.map(sqr, exception_throwing_generator(10, 3), 1)
+
+        class SpecialIterable:
+            def __iter__(self):
+                return self
+            def __next__(self):
+                raise SayWhenError
+            def __len__(self):
+                return 1
+        with self.assertRaises(SayWhenError):
+            self.pool.map(sqr, SpecialIterable(), 1)
+        with self.assertRaises(SayWhenError):
+            self.pool.map(sqr, SpecialIterable(), 1)
+
     def test_async(self):
         res = self.pool.apply_async(sqr, (7, TIMEOUT1,))
         get = TimingWrapper(res.get)
@@ -1863,6 +1891,13 @@ class _TestPool(BaseTestCase):
         if self.TYPE == 'manager':
             self.skipTest('test not appropriate for {}'.format(self.TYPE))
 
+        # SayWhenError seen at the very start of the iterable
+        it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
+        self.assertRaises(SayWhenError, it.__next__)
+        # again, make sure it's reentrant
+        it = self.pool.imap(sqr, exception_throwing_generator(1, -1), 1)
+        self.assertRaises(SayWhenError, it.__next__)
+
         it = self.pool.imap(sqr, exception_throwing_generator(10, 3), 1)
         for i in range(3):
             self.assertEqual(next(it), i*i)
@@ -1889,6 +1924,17 @@ class _TestPool(BaseTestCase):
         if self.TYPE == 'manager':
             self.skipTest('test not appropriate for {}'.format(self.TYPE))
 
+        # SayWhenError seen at the very start of the iterable
+        it = self.pool.imap_unordered(sqr,
+                                      exception_throwing_generator(1, -1),
+                                      1)
+        self.assertRaises(SayWhenError, it.__next__)
+        # again, make sure it's reentrant
+        it = self.pool.imap_unordered(sqr,
+                                      exception_throwing_generator(1, -1),
+                                      1)
+        self.assertRaises(SayWhenError, it.__next__)
+
         it = self.pool.imap_unordered(sqr,
                                       exception_throwing_generator(10, 3),
                                       1)
@@ -1970,7 +2016,7 @@ class _TestPool(BaseTestCase):
                 except Exception as e:
                     exc = e
                 else:
-                    raise AssertionError('expected RuntimeError')
+                    self.fail('expected RuntimeError')
             self.assertIs(type(exc), RuntimeError)
             self.assertEqual(exc.args, (123,))
             cause = exc.__cause__
@@ -1984,6 +2030,17 @@ class _TestPool(BaseTestCase):
                     sys.excepthook(*sys.exc_info())
             self.assertIn('raise RuntimeError(123) # some comment',
                           f1.getvalue())
+            # _helper_reraises_exception should not make the error
+            # a remote exception
+            with self.Pool(1) as p:
+                try:
+                    p.map(sqr, exception_throwing_generator(1, -1), 1)
+                except Exception as e:
+                    exc = e
+                else:
+                    self.fail('expected SayWhenError')
+                self.assertIs(type(exc), SayWhenError)
+                self.assertIs(exc.__cause__, None)
 
     @classmethod
     def _test_wrapped_exception(cls):
index 45c5c7828701f87016f66ea61611e2074398ebff..fbaa840638ff97b1b3860a6c4d76d18f8dcc0445 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -27,6 +27,10 @@ Core and Builtins
 Library
 -------
 
+- bpo-28699: Fixed a bug in multiprocessing.pool where an exception raised
+  at the very start of an iterable could be swallowed by the pool or make
+  the program hang. Patch by Davin Potts and Xiang Zhang.
+
 - bpo-25803: Avoid incorrect errors raised by Path.mkdir(exist_ok=True)
   when the OS gives priority to errors such as EACCES over EEXIST.