From 622be340fdf4110c77e1f86bd13a01fc30c2bb65 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 6 Feb 2014 22:06:16 -0500 Subject: [PATCH] asyncio.tasks: Fix as_completed, gather & wait to work with duplicate coroutines --- Lib/asyncio/tasks.py | 7 ++-- Lib/test/test_asyncio/test_tasks.py | 55 ++++++++++++++++++++++++----- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index a5708b4c2c..5ad06520e9 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -364,7 +364,7 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): if loop is None: loop = events.get_event_loop() - fs = set(async(f, loop=loop) for f in fs) + fs = {async(f, loop=loop) for f in set(fs)} if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): raise ValueError('Invalid return_when value: {}'.format(return_when)) @@ -476,7 +476,7 @@ def as_completed(fs, *, loop=None, timeout=None): """ loop = loop if loop is not None else events.get_event_loop() deadline = None if timeout is None else loop.time() + timeout - todo = set(async(f, loop=loop) for f in fs) + todo = {async(f, loop=loop) for f in set(fs)} completed = collections.deque() @coroutine @@ -568,7 +568,8 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): prevent the cancellation of one child to cause other children to be cancelled.) """ - children = [async(fut, loop=loop) for fut in coros_or_futures] + arg_to_fut = {arg: async(arg, loop=loop) for arg in set(coros_or_futures)} + children = [arg_to_fut[arg] for arg in coros_or_futures] n = len(children) if n == 0: outer = futures.Future(loop=loop) diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py index f54a0a06de..d4d4e6390c 100644 --- a/Lib/test/test_asyncio/test_tasks.py +++ b/Lib/test/test_asyncio/test_tasks.py @@ -483,6 +483,21 @@ class TaskTests(unittest.TestCase): self.assertEqual(res, 42) + def test_wait_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('test') + + task = asyncio.Task( + asyncio.wait([c, c, coro('spam')], loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + + self.assertFalse(pending) + self.assertEqual(set(f.result() for f in done), {'test', 'spam'}) + def test_wait_errors(self): self.assertRaises( ValueError, self.loop.run_until_complete, @@ -757,14 +772,10 @@ class TaskTests(unittest.TestCase): def test_as_completed_with_timeout(self): def gen(): - when = yield - self.assertAlmostEqual(0.12, when) - when = yield 0 - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(0.15, when) - when = yield 0.1 - self.assertAlmostEqual(0.12, when) + yield + yield 0 + yield 0 + yield 0.1 yield 0.02 loop = test_utils.TestLoop(gen) @@ -840,6 +851,25 @@ class TaskTests(unittest.TestCase): done, pending = loop.run_until_complete(waiter) self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + def test_as_completed_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + + @asyncio.coroutine + def runner(): + result = [] + c = coro('ham') + for f in asyncio.as_completed({c, c, coro('spam')}, loop=self.loop): + result.append((yield from f)) + return result + + fut = asyncio.Task(runner(), loop=self.loop) + self.loop.run_until_complete(fut) + result = fut.result() + self.assertEqual(set(result), {'ham', 'spam'}) + self.assertEqual(len(result), 2) + def test_sleep(self): def gen(): @@ -1505,6 +1535,15 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): gen3.close() gen4.close() + def test_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('abc') + fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop) + self._run_loop(self.one_loop) + self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc']) + def test_cancellation_broadcast(self): # Cancelling outer() cancels all children. proof = 0 -- 2.40.0