]> granicus.if.org Git - python/commitdiff
asyncio: Add Task.current_task() class method.
authorGuido van Rossum <guido@python.org>
Fri, 6 Dec 2013 20:57:40 +0000 (12:57 -0800)
committerGuido van Rossum <guido@python.org>
Fri, 6 Dec 2013 20:57:40 +0000 (12:57 -0800)
Lib/asyncio/tasks.py
Lib/asyncio/test_utils.py
Lib/test/test_asyncio/test_tasks.py

index 999e9629bc692c172440296f0744e778b47fe0b9..cd9718f534a71210139da9c576a3ea448794fdba 100644 (file)
@@ -122,6 +122,22 @@ class Task(futures.Future):
     # Weak set containing all tasks alive.
     _all_tasks = weakref.WeakSet()
 
+    # Dictionary containing tasks that are currently active in
+    # all running event loops.  {EventLoop: Task}
+    _current_tasks = {}
+
+    @classmethod
+    def current_task(cls, loop=None):
+        """Return the currently running task in an event loop or None.
+
+        By default the current task for the current event loop is returned.
+
+        None is returned when called not in the context of a Task.
+        """
+        if loop is None:
+            loop = events.get_event_loop()
+        return cls._current_tasks.get(loop)
+
     @classmethod
     def all_tasks(cls, loop=None):
         """Return a set of all tasks for an event loop.
@@ -252,6 +268,8 @@ class Task(futures.Future):
             self._must_cancel = False
         coro = self._coro
         self._fut_waiter = None
+
+        self.__class__._current_tasks[self._loop] = self
         # Call either coro.throw(exc) or coro.send(value).
         try:
             if exc is not None:
@@ -302,6 +320,8 @@ class Task(futures.Future):
                     self._step, None,
                     RuntimeError(
                         'Task got bad yield: {!r}'.format(result)))
+        finally:
+            self.__class__._current_tasks.pop(self._loop)
         self = None
 
     def _wakeup(self, future):
index c26dd8818f09420a217786338f16a378da25cc35..131a546080b22348a6671b167b4a8d5cba22d63c 100644 (file)
@@ -88,7 +88,7 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
     class SSLWSGIServer(SilentWSGIServer):
         def finish_request(self, request, client_address):
             # The relative location of our test directory (which
-            # contains the sample key and certificate files) differs
+            # contains the ssl key and certificate files) differs
             # between the stdlib and stand-alone Tulip/asyncio.
             # Prefer our own if we can find it.
             here = os.path.join(os.path.dirname(__file__), '..', 'tests')
index 8f0d081554bb78bb36fb42b51bf60878035eccad..5470da15430613c07107e1147dd8f051b591b4e8 100644 (file)
@@ -1113,6 +1113,42 @@ class TaskTests(unittest.TestCase):
         self.assertEqual(res, 'test')
         self.assertIsNone(t2.result())
 
+    def test_current_task(self):
+        self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+        @tasks.coroutine
+        def coro(loop):
+            self.assertTrue(tasks.Task.current_task(loop=loop) is task)
+
+        task = tasks.Task(coro(self.loop), loop=self.loop)
+        self.loop.run_until_complete(task)
+        self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+
+    def test_current_task_with_interleaving_tasks(self):
+        self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+
+        fut1 = futures.Future(loop=self.loop)
+        fut2 = futures.Future(loop=self.loop)
+
+        @tasks.coroutine
+        def coro1(loop):
+            self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
+            yield from fut1
+            self.assertTrue(tasks.Task.current_task(loop=loop) is task1)
+            fut2.set_result(True)
+
+        @tasks.coroutine
+        def coro2(loop):
+            self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
+            fut1.set_result(True)
+            yield from fut2
+            self.assertTrue(tasks.Task.current_task(loop=loop) is task2)
+
+        task1 = tasks.Task(coro1(self.loop), loop=self.loop)
+        task2 = tasks.Task(coro2(self.loop), loop=self.loop)
+
+        self.loop.run_until_complete(tasks.wait((task1, task2), loop=self.loop))
+        self.assertIsNone(tasks.Task.current_task(loop=self.loop))
+
     # Some thorough tests for cancellation propagation through
     # coroutines, tasks and wait().