granicus.if.org Git - python/commitdiff
Issue #9244: multiprocessing.pool: Worker crashes if result can't be encoded
author Ask Solem <askh@opera.com>
Tue, 9 Nov 2010 20:55:52 +0000 (20:55 +0000)
committer Ask Solem <askh@opera.com>
Tue, 9 Nov 2010 20:55:52 +0000 (20:55 +0000)
Lib/multiprocessing/pool.py
Lib/test/test_multiprocessing.py

index 7154d3c090a7a0f8d183e8454a531d58ea6860ff..c170cca11d3e678406bfa226e9b56b6639215217 100644 (file)
@@ -42,6 +42,23 @@ def mapstar(args):
 # Code run by worker processes
 #
 
+class MaybeEncodingError(Exception):
+    """Wraps possible unpickleable errors, so they can be
+    safely sent through the socket."""
+
+    def __init__(self, exc, value):
+        self.exc = repr(exc)
+        self.value = repr(value)
+        super(MaybeEncodingError, self).__init__(self.exc, self.value)
+
+    def __str__(self):
+        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
+                                                             self.exc)
+
+    def __repr__(self):
+        return "<MaybeEncodingError: %s>" % str(self)
+
+
 def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
     assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
@@ -70,7 +87,13 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
             result = (True, func(*args, **kwds))
         except Exception as e:
             result = (False, e)
-        put((job, i, result))
+        try:
+            put((job, i, result))
+        except Exception as e:
+            wrapped = MaybeEncodingError(e, result[1])
+            debug("Possible encoding error while sending result: %s" % (
+                wrapped))
+            put((job, i, (False, wrapped)))
         completed += 1
     debug('worker exiting after %d tasks' % completed)
 
@@ -235,16 +258,18 @@ class Pool(object):
                      for i, x in enumerate(task_batches)), result._set_length))
             return (item for chunk in result for item in chunk)
 
-    def apply_async(self, func, args=(), kwds={}, callback=None):
+    def apply_async(self, func, args=(), kwds={}, callback=None,
+            error_callback=None):
         '''
         Asynchronous version of `apply()` method.
         '''
         assert self._state == RUN
-        result = ApplyResult(self._cache, callback)
+        result = ApplyResult(self._cache, callback, error_callback)
         self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
         return result
 
-    def map_async(self, func, iterable, chunksize=None, callback=None):
+    def map_async(self, func, iterable, chunksize=None, callback=None,
+            error_callback=None):
         '''
         Asynchronous version of `map()` method.
         '''
@@ -260,7 +285,8 @@ class Pool(object):
             chunksize = 0
 
         task_batches = Pool._get_tasks(func, iterable, chunksize)
-        result = MapResult(self._cache, chunksize, len(iterable), callback)
+        result = MapResult(self._cache, chunksize, len(iterable), callback,
+                           error_callback=error_callback)
         self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                               for i, x in enumerate(task_batches)), None))
         return result
@@ -459,12 +485,13 @@ class Pool(object):
 
 class ApplyResult(object):
 
-    def __init__(self, cache, callback):
+    def __init__(self, cache, callback, error_callback):
         self._cond = threading.Condition(threading.Lock())
         self._job = next(job_counter)
         self._cache = cache
         self._ready = False
         self._callback = callback
+        self._error_callback = error_callback
         cache[self._job] = self
 
     def ready(self):
@@ -495,6 +522,8 @@ class ApplyResult(object):
         self._success, self._value = obj
         if self._callback and self._success:
             self._callback(self._value)
+        if self._error_callback and not self._success:
+            self._error_callback(self._value)
         self._cond.acquire()
         try:
             self._ready = True
@@ -509,8 +538,9 @@ class ApplyResult(object):
 
 class MapResult(ApplyResult):
 
-    def __init__(self, cache, chunksize, length, callback):
-        ApplyResult.__init__(self, cache, callback)
+    def __init__(self, cache, chunksize, length, callback, error_callback):
+        ApplyResult.__init__(self, cache, callback,
+                             error_callback=error_callback)
         self._success = True
         self._value = [None] * length
         self._chunksize = chunksize
@@ -535,10 +565,11 @@ class MapResult(ApplyResult):
                     self._cond.notify()
                 finally:
                     self._cond.release()
-
         else:
             self._success = False
             self._value = result
+            if self._error_callback:
+                self._error_callback(self._value)
             del self._cache[self._job]
             self._cond.acquire()
             try:
index 0b3f937ace9c2884152ec6bc566c32db332dfad8..bb0f06adf6abf3a2fbedf4d9a0d4874f21c3fa50 100644 (file)
@@ -1011,6 +1011,7 @@ class _TestContainers(BaseTestCase):
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1087,9 +1088,55 @@ class _TestPool(BaseTestCase):
         join()
         self.assertTrue(join.elapsed < 0.2)
 
-class _TestPoolWorkerLifetime(BaseTestCase):
+def raising():
+    raise KeyError("key")
+
+def unpickleable_result():
+    return lambda: 42
+
+class _TestPoolWorkerErrors(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
+    def test_async_error_callback(self):
+        p = multiprocessing.Pool(2)
+
+        scratchpad = [None]
+        def errback(exc):
+            scratchpad[0] = exc
+
+        res = p.apply_async(raising, error_callback=errback)
+        self.assertRaises(KeyError, res.get)
+        self.assertTrue(scratchpad[0])
+        self.assertIsInstance(scratchpad[0], KeyError)
+
+        p.close()
+        p.join()
+
+    def test_unpickleable_result(self):
+        from multiprocessing.pool import MaybeEncodingError
+        p = multiprocessing.Pool(2)
+
+        # Make sure we don't lose pool processes because of encoding errors.
+        for iteration in range(20):
+
+            scratchpad = [None]
+            def errback(exc):
+                scratchpad[0] = exc
+
+            res = p.apply_async(unpickleable_result, error_callback=errback)
+            self.assertRaises(MaybeEncodingError, res.get)
+            wrapped = scratchpad[0]
+            self.assertTrue(wrapped)
+            self.assertIsInstance(scratchpad[0], MaybeEncodingError)
+            self.assertIsNotNone(wrapped.exc)
+            self.assertIsNotNone(wrapped.value)
 
+        p.close()
+        p.join()
+
+class _TestPoolWorkerLifetime(BaseTestCase):
     ALLOWED_TYPES = ('processes', )
+
     def test_pool_worker_lifetime(self):
         p = multiprocessing.Pool(3, maxtasksperchild=10)
         self.assertEqual(3, len(p._pool))