]> granicus.if.org Git - python/commitdiff
asyncio.tasks: Fix as_completed, gather & wait to work with duplicate coroutines
authorYury Selivanov <yselivanov@sprymix.com>
Fri, 7 Feb 2014 03:06:16 +0000 (22:06 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Fri, 7 Feb 2014 03:06:16 +0000 (22:06 -0500)
Lib/asyncio/tasks.py
Lib/test/test_asyncio/test_tasks.py

index a5708b4c2c51cd6bb03440f5dbb9058e32138458..5ad06520e9a4e5cd92ebed42d2ce102597244844 100644 (file)
@@ -364,7 +364,7 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED):
     if loop is None:
         loop = events.get_event_loop()
 
-    fs = set(async(f, loop=loop) for f in fs)
+    fs = {async(f, loop=loop) for f in set(fs)}
 
     if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED):
         raise ValueError('Invalid return_when value: {}'.format(return_when))
@@ -476,7 +476,7 @@ def as_completed(fs, *, loop=None, timeout=None):
     """
     loop = loop if loop is not None else events.get_event_loop()
     deadline = None if timeout is None else loop.time() + timeout
-    todo = set(async(f, loop=loop) for f in fs)
+    todo = {async(f, loop=loop) for f in set(fs)}
     completed = collections.deque()
 
     @coroutine
@@ -568,7 +568,8 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False):
     prevent the cancellation of one child to cause other children to
     be cancelled.)
     """
-    children = [async(fut, loop=loop) for fut in coros_or_futures]
+    arg_to_fut = {arg: async(arg, loop=loop) for arg in set(coros_or_futures)}
+    children = [arg_to_fut[arg] for arg in coros_or_futures]
     n = len(children)
     if n == 0:
         outer = futures.Future(loop=loop)
index f54a0a06decb200eae7b2387ea109cb14e0ace22..d4d4e6390c222f6acabfed994eff2f6f13f93ba3 100644 (file)
@@ -483,6 +483,21 @@ class TaskTests(unittest.TestCase):
 
         self.assertEqual(res, 42)
 
+    def test_wait_duplicate_coroutines(self):
+        @asyncio.coroutine
+        def coro(s):
+            return s
+        c = coro('test')
+
+        task = asyncio.Task(
+            asyncio.wait([c, c, coro('spam')], loop=self.loop),
+            loop=self.loop)
+
+        done, pending = self.loop.run_until_complete(task)
+
+        self.assertFalse(pending)
+        self.assertEqual(set(f.result() for f in done), {'test', 'spam'})
+
     def test_wait_errors(self):
         self.assertRaises(
             ValueError, self.loop.run_until_complete,
@@ -757,14 +772,10 @@ class TaskTests(unittest.TestCase):
     def test_as_completed_with_timeout(self):
 
         def gen():
-            when = yield
-            self.assertAlmostEqual(0.12, when)
-            when = yield 0
-            self.assertAlmostEqual(0.1, when)
-            when = yield 0
-            self.assertAlmostEqual(0.15, when)
-            when = yield 0.1
-            self.assertAlmostEqual(0.12, when)
+            yield
+            yield 0
+            yield 0
+            yield 0.1
             yield 0.02
 
         loop = test_utils.TestLoop(gen)
@@ -840,6 +851,25 @@ class TaskTests(unittest.TestCase):
         done, pending = loop.run_until_complete(waiter)
         self.assertEqual(set(f.result() for f in done), {'a', 'b'})
 
+    def test_as_completed_duplicate_coroutines(self):
+        @asyncio.coroutine
+        def coro(s):
+            return s
+
+        @asyncio.coroutine
+        def runner():
+            result = []
+            c = coro('ham')
+            for f in asyncio.as_completed({c, c, coro('spam')}, loop=self.loop):
+                result.append((yield from f))
+            return result
+
+        fut = asyncio.Task(runner(), loop=self.loop)
+        self.loop.run_until_complete(fut)
+        result = fut.result()
+        self.assertEqual(set(result), {'ham', 'spam'})
+        self.assertEqual(len(result), 2)
+
     def test_sleep(self):
 
         def gen():
@@ -1505,6 +1535,15 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
         gen3.close()
         gen4.close()
 
+    def test_duplicate_coroutines(self):
+        @asyncio.coroutine
+        def coro(s):
+            return s
+        c = coro('abc')
+        fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop)
+        self._run_loop(self.one_loop)
+        self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc'])
+
     def test_cancellation_broadcast(self):
         # Cancelling outer() cancels all children.
         proof = 0