]> granicus.if.org Git - python/commitdiff
Issue #25233: Rewrite the guts of Queue to be more understandable and correct.
authorGuido van Rossum <guido@python.org>
Mon, 28 Sep 2015 14:42:34 +0000 (07:42 -0700)
committerGuido van Rossum <guido@python.org>
Mon, 28 Sep 2015 14:42:34 +0000 (07:42 -0700)
Lib/asyncio/queues.py
Lib/test/test_asyncio/test_queues.py
Misc/NEWS

index 021043d6be2521aa69f71c8a097e6b0b0328e421..e3a1d5ed60e3b8686ce9fce40c7f12f2a585d517 100644 (file)
@@ -47,7 +47,7 @@ class Queue:
 
         # Futures.
         self._getters = collections.deque()
-        # Futures
+        # Futures.
         self._putters = collections.deque()
         self._unfinished_tasks = 0
         self._finished = locks.Event(loop=self._loop)
@@ -67,10 +67,13 @@ class Queue:
 
     # End of the overridable methods.
 
-    def __put_internal(self, item):
-        self._put(item)
-        self._unfinished_tasks += 1
-        self._finished.clear()
+    def _wakeup_next(self, waiters):
+        # Wake up the next waiter (if any) that isn't cancelled.
+        while waiters:
+            waiter = waiters.popleft()
+            if not waiter.done():
+                waiter.set_result(None)
+                break
 
     def __repr__(self):
         return '<{} at {:#x} {}>'.format(
@@ -91,16 +94,6 @@ class Queue:
             result += ' tasks={}'.format(self._unfinished_tasks)
         return result
 
-    def _consume_done_getters(self):
-        # Delete waiters at the head of the get() queue who've timed out.
-        while self._getters and self._getters[0].done():
-            self._getters.popleft()
-
-    def _consume_done_putters(self):
-        # Delete waiters at the head of the put() queue who've timed out.
-        while self._putters and self._putters[0].done():
-            self._putters.popleft()
-
     def qsize(self):
         """Number of items in the queue."""
         return len(self._queue)
@@ -134,47 +127,31 @@ class Queue:
 
         This method is a coroutine.
         """
-        self._consume_done_getters()
-        if self._getters:
-            assert not self._queue, (
-                'queue non-empty, why are getters waiting?')
-
-            getter = self._getters.popleft()
-            self.__put_internal(item)
-
-            # getter cannot be cancelled, we just removed done getters
-            getter.set_result(self._get())
-
-        elif self._maxsize > 0 and self._maxsize <= self.qsize():
-            waiter = futures.Future(loop=self._loop)
-
-            self._putters.append(waiter)
-            yield from waiter
-            self._put(item)
-
-        else:
-            self.__put_internal(item)
+        while self.full():
+            putter = futures.Future(loop=self._loop)
+            self._putters.append(putter)
+            try:
+                yield from putter
+            except:
+                putter.cancel()  # Just in case putter is not done yet.
+                if not self.full() and not putter.cancelled():
+                    # We were woken up by get_nowait(), but can't take
+                    # the call.  Wake up the next in line.
+                    self._wakeup_next(self._putters)
+                raise
+        return self.put_nowait(item)
 
     def put_nowait(self, item):
         """Put an item into the queue without blocking.
 
         If no free slot is immediately available, raise QueueFull.
         """
-        self._consume_done_getters()
-        if self._getters:
-            assert not self._queue, (
-                'queue non-empty, why are getters waiting?')
-
-            getter = self._getters.popleft()
-            self.__put_internal(item)
-
-            # getter cannot be cancelled, we just removed done getters
-            getter.set_result(self._get())
-
-        elif self._maxsize > 0 and self._maxsize <= self.qsize():
+        if self.full():
             raise QueueFull
-        else:
-            self.__put_internal(item)
+        self._put(item)
+        self._unfinished_tasks += 1
+        self._finished.clear()
+        self._wakeup_next(self._getters)
 
     @coroutine
     def get(self):
@@ -184,77 +161,30 @@ class Queue:
 
         This method is a coroutine.
         """
-        self._consume_done_putters()
-        if self._putters:
-            assert self.full(), 'queue not full, why are putters waiting?'
-            putter = self._putters.popleft()
-
-            # When a getter runs and frees up a slot so this putter can
-            # run, we need to defer the put for a tick to ensure that
-            # getters and putters alternate perfectly. See
-            # ChannelTest.test_wait.
-            self._loop.call_soon(putter._set_result_unless_cancelled, None)
-
-            return self._get()
-
-        elif self.qsize():
-            return self._get()
-        else:
-            waiter = futures.Future(loop=self._loop)
-            self._getters.append(waiter)
+        while self.empty():
+            getter = futures.Future(loop=self._loop)
+            self._getters.append(getter)
             try:
-                return (yield from waiter)
-            except futures.CancelledError:
-                # if we get CancelledError, it means someone cancelled this
-                # get() coroutine.  But there is a chance that the waiter
-                # already is ready and contains an item that has just been
-                # removed from the queue.  In this case, we need to put the item
-                # back into the front of the queue.  This get() must either
-                # succeed without fault or, if it gets cancelled, it must be as
-                # if it never happened.
-                if waiter.done():
-                    self._put_it_back(waiter.result())
+                yield from getter
+            except:
+                getter.cancel()  # Just in case getter is not done yet.
+                if not self.empty() and not getter.cancelled():
+                    # We were woken up by put_nowait(), but can't take
+                    # the call.  Wake up the next in line.
+                    self._wakeup_next(self._getters)
                 raise
-
-    def _put_it_back(self, item):
-        """
-        This is called when we have a waiter to get() an item and this waiter
-        gets cancelled.  In this case, we put the item back: wake up another
-        waiter or put it in the _queue.
-        """
-        self._consume_done_getters()
-        if self._getters:
-            assert not self._queue, (
-                'queue non-empty, why are getters waiting?')
-
-            getter = self._getters.popleft()
-            self.__put_internal(item)
-
-            # getter cannot be cancelled, we just removed done getters
-            getter.set_result(item)
-        else:
-            self._queue.appendleft(item)
+        return self.get_nowait()
 
     def get_nowait(self):
         """Remove and return an item from the queue.
 
         Return an item if one is immediately available, else raise QueueEmpty.
         """
-        self._consume_done_putters()
-        if self._putters:
-            assert self.full(), 'queue not full, why are putters waiting?'
-            putter = self._putters.popleft()
-            # Wake putter on next tick.
-
-            # getter cannot be cancelled, we just removed done putters
-            putter.set_result(None)
-
-            return self._get()
-
-        elif self.qsize():
-            return self._get()
-        else:
+        if self.empty():
             raise QueueEmpty
+        item = self._get()
+        self._wakeup_next(self._putters)
+        return item
 
     def task_done(self):
         """Indicate that a formerly enqueued task is complete.
index 8e38175e40476c6bbdc1cb12db64543228872e8a..591a9bb53516af6a84356c56be9b1a35093c8bd2 100644 (file)
@@ -271,6 +271,29 @@ class QueueGetTests(_QueueTestBase):
         self.assertEqual(self.loop.run_until_complete(q.get()), 'a')
         self.assertEqual(self.loop.run_until_complete(q.get()), 'b')
 
+    def test_why_are_getters_waiting(self):
+        # From issue #268.
+
+        @asyncio.coroutine
+        def consumer(queue, num_expected):
+            for _ in range(num_expected):
+                yield from queue.get()
+
+        @asyncio.coroutine
+        def producer(queue, num_items):
+            for i in range(num_items):
+                yield from queue.put(i)
+
+        queue_size = 1
+        producer_num_items = 5
+        q = asyncio.Queue(queue_size, loop=self.loop)
+
+        self.loop.run_until_complete(
+            asyncio.gather(producer(q, producer_num_items),
+                           consumer(q, producer_num_items),
+                           loop=self.loop),
+            )
+
 
 class QueuePutTests(_QueueTestBase):
 
@@ -377,13 +400,8 @@ class QueuePutTests(_QueueTestBase):
 
         loop.run_until_complete(reader3)
 
-        # reader2 will receive `2`, because it was added to the
-        # queue of pending readers *before* put_nowaits were called.
-        self.assertEqual(reader2.result(), 2)
-        # reader3 will receive `1`, because reader1 was cancelled
-        # before is had a chance to execute, and `2` was already
-        # pushed to reader2 by second `put_nowait`.
-        self.assertEqual(reader3.result(), 1)
+        # It is undefined in which order concurrent readers receive results.
+        self.assertEqual({reader2.result(), reader3.result()}, {1, 2})
 
     def test_put_cancel_drop(self):
 
@@ -479,6 +497,29 @@ class QueuePutTests(_QueueTestBase):
         self.loop.run_until_complete(q.put('a'))
         self.assertEqual(self.loop.run_until_complete(t), 'a')
 
+    def test_why_are_putters_waiting(self):
+        # From issue #265.
+
+        queue = asyncio.Queue(2, loop=self.loop)
+
+        @asyncio.coroutine
+        def putter(item):
+            yield from queue.put(item)
+
+        @asyncio.coroutine
+        def getter():
+            yield
+            num = queue.qsize()
+            for _ in range(num):
+                item = queue.get_nowait()
+
+        t0 = putter(0)
+        t1 = putter(1)
+        t2 = putter(2)
+        t3 = putter(3)
+        self.loop.run_until_complete(
+            asyncio.gather(getter(), t0, t1, t2, t3, loop=self.loop))
+
 
 class LifoQueueTests(_QueueTestBase):
 
index 63730faa2ef719ef54b5d5fd13f20084a4bc48c2..3b76bbdf9ae5f6179c872bb7163a982569f1d5ad 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -81,6 +81,8 @@ Core and Builtins
 Library
 -------
 
+- Issue #25233: Rewrite the guts of Queue to be more understandable and correct.
+
 - Issue #23600: Default implementation of tzinfo.fromutc() was returning
   wrong results in some cases.