]> granicus.if.org Git - python/commitdiff
Issue #28613: Fix get_event_loop() to return the current loop
authorYury Selivanov <yury@magic.io>
Fri, 4 Nov 2016 18:29:28 +0000 (14:29 -0400)
committerYury Selivanov <yury@magic.io>
Fri, 4 Nov 2016 18:29:28 +0000 (14:29 -0400)
when called from coroutines or callbacks.

18 files changed:
Lib/asyncio/base_events.py
Lib/asyncio/events.py
Lib/asyncio/test_utils.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_futures.py
Lib/test/test_asyncio/test_locks.py
Lib/test/test_asyncio/test_pep492.py
Lib/test/test_asyncio/test_proactor_events.py
Lib/test/test_asyncio/test_queues.py
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_sslproto.py
Lib/test/test_asyncio/test_streams.py
Lib/test/test_asyncio/test_subprocess.py
Lib/test/test_asyncio/test_tasks.py
Lib/test/test_asyncio/test_unix_events.py
Lib/test/test_asyncio/test_windows_events.py
Misc/NEWS

index 5597bcb7cdb0e9dfdfec8a1a21d6399d536eb31d..6488f23d3c89595af99efdd2167b49011b260d82 100644 (file)
@@ -393,7 +393,10 @@ class BaseEventLoop(events.AbstractEventLoop):
         """Run until stop() is called."""
         self._check_closed()
         if self.is_running():
-            raise RuntimeError('Event loop is running.')
+            raise RuntimeError('This event loop is already running')
+        if events._get_running_loop() is not None:
+            raise RuntimeError(
+                'Cannot run the event loop while another loop is running')
         self._set_coroutine_wrapper(self._debug)
         self._thread_id = threading.get_ident()
         if self._asyncgens is not None:
@@ -401,6 +404,7 @@ class BaseEventLoop(events.AbstractEventLoop):
             sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
                                    finalizer=self._asyncgen_finalizer_hook)
         try:
+            events._set_running_loop(self)
             while True:
                 self._run_once()
                 if self._stopping:
@@ -408,6 +412,7 @@ class BaseEventLoop(events.AbstractEventLoop):
         finally:
             self._stopping = False
             self._thread_id = None
+            events._set_running_loop(None)
             self._set_coroutine_wrapper(False)
             if self._asyncgens is not None:
                 sys.set_asyncgen_hooks(*old_agen_hooks)
index b89b4b205a4437b07ed5ce0288d4ae03a02f6a71..8575e2c1c4029c15f1f1e7043fabcbd245fce8ee 100644 (file)
@@ -607,6 +607,30 @@ _event_loop_policy = None
 _lock = threading.Lock()
 
 
+# A TLS for the running event loop, used by _get_running_loop.
+class _RunningLoop(threading.local):
+    _loop = None
+_running_loop = _RunningLoop()
+
+
+def _get_running_loop():
+    """Return the running event loop or None.
+
+    This is a low-level function intended to be used by event loops.
+    This function is thread-specific.
+    """
+    return _running_loop._loop
+
+
+def _set_running_loop(loop):
+    """Set the running event loop.
+
+    This is a low-level function intended to be used by event loops.
+    This function is thread-specific.
+    """
+    _running_loop._loop = loop
+
+
 def _init_event_loop_policy():
     global _event_loop_policy
     with _lock:
@@ -632,7 +656,17 @@ def set_event_loop_policy(policy):
 
 
 def get_event_loop():
-    """Equivalent to calling get_event_loop_policy().get_event_loop()."""
+    """Return an asyncio event loop.
+
+    When called from a coroutine or a callback (e.g. scheduled with call_soon
+    or similar API), this function will always return the running event loop.
+
+    If there is no running event loop set, the function will return
+    the result of `get_event_loop_policy().get_event_loop()` call.
+    """
+    current_loop = _get_running_loop()
+    if current_loop is not None:
+        return current_loop
     return get_event_loop_policy().get_event_loop()
 
 
index 307fffccc60a2356161085ab10eb5951e29b8b33..9d32822fa9e451402773f16dd49b67f08612517b 100644 (file)
@@ -449,7 +449,13 @@ class TestCase(unittest.TestCase):
         self.set_event_loop(loop)
         return loop
 
+    def setUp(self):
+        self._get_running_loop = events._get_running_loop
+        events._get_running_loop = lambda: None
+
     def tearDown(self):
+        events._get_running_loop = self._get_running_loop
+
         events.set_event_loop(None)
 
         # Detect CPython bug #23353: ensure that yield/yield-from is not used
index 39131256a0a0b94e63806e3fc5775b67ace27f53..cdbd58798d67a910b9f145b6c20c1799b66b4929 100644 (file)
@@ -154,6 +154,7 @@ class BaseEventTests(test_utils.TestCase):
 class BaseEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = base_events.BaseEventLoop()
         self.loop._selector = mock.Mock()
         self.loop._selector.select.return_value = ()
@@ -976,6 +977,7 @@ class MyDatagramProto(asyncio.DatagramProtocol):
 class BaseEventLoopWithSelectorTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.new_event_loop()
         self.set_event_loop(self.loop)
 
@@ -1692,5 +1694,23 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase):
                          "took .* seconds$")
 
 
+class RunningLoopTests(unittest.TestCase):
+
+    def test_running_loop_within_a_loop(self):
+        @asyncio.coroutine
+        def runner(loop):
+            loop.run_forever()
+
+        loop = asyncio.new_event_loop()
+        outer_loop = asyncio.new_event_loop()
+        try:
+            with self.assertRaisesRegex(RuntimeError,
+                                        'while another loop is running'):
+                outer_loop.run_until_complete(runner(loop))
+        finally:
+            loop.close()
+            outer_loop.close()
+
+
 if __name__ == '__main__':
     unittest.main()
index d8946e38f229b28a78cfebbb9f32cdd666b1a04b..4c18300bab25c886233e8558197eb4b38aeeb3b6 100644 (file)
@@ -2233,6 +2233,7 @@ def noop(*args, **kwargs):
 class HandleTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = mock.Mock()
         self.loop.get_debug.return_value = True
 
@@ -2411,6 +2412,7 @@ class HandleTests(test_utils.TestCase):
 class TimerTests(unittest.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = mock.Mock()
 
     def test_hash(self):
@@ -2719,6 +2721,27 @@ class PolicyTests(unittest.TestCase):
         self.assertIs(policy, asyncio.get_event_loop_policy())
         self.assertIsNot(policy, old_policy)
 
+    def test_get_event_loop_returns_running_loop(self):
+        class Policy(asyncio.DefaultEventLoopPolicy):
+            def get_event_loop(self):
+                raise NotImplementedError
+
+        loop = None
+
+        old_policy = asyncio.get_event_loop_policy()
+        try:
+            asyncio.set_event_loop_policy(Policy())
+            loop = asyncio.new_event_loop()
+
+            async def func():
+                self.assertIs(asyncio.get_event_loop(), loop)
+
+            loop.run_until_complete(func())
+        finally:
+            asyncio.set_event_loop_policy(old_policy)
+            if loop is not None:
+                loop.close()
+
 
 if __name__ == '__main__':
     unittest.main()
index d20eb687f9a2200f68e5cd49d8daaadaea885ca1..153b8ed707f1b3b6c77579335c643195290d3ff4 100644 (file)
@@ -79,6 +79,7 @@ class DuckFuture:
 class DuckTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.addCleanup(self.loop.close)
 
@@ -96,6 +97,7 @@ class DuckTests(test_utils.TestCase):
 class FutureTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.addCleanup(self.loop.close)
 
@@ -468,6 +470,7 @@ class FutureTests(test_utils.TestCase):
 class FutureDoneCallbackTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
 
     def run_briefly(self):
index e557212f9690aee6f62ae8709f8928343dd93142..152948c8138975096aaa243005e95f155f551d4b 100644 (file)
@@ -19,6 +19,7 @@ RGX_REPR = re.compile(STR_RGX_REPR)
 class LockTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
@@ -235,6 +236,7 @@ class LockTests(test_utils.TestCase):
 class EventTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
@@ -364,6 +366,7 @@ class EventTests(test_utils.TestCase):
 class ConditionTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
@@ -699,6 +702,7 @@ class ConditionTests(test_utils.TestCase):
 class SemaphoreTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
index 29aba817ec468b7e29e0d7976e4f9358359f8f1c..d5b852248bb3b1fa1b7ce639703651d6778074d7 100644 (file)
@@ -17,6 +17,7 @@ from asyncio import test_utils
 class BaseTest(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.BaseEventLoop()
         self.loop._process_events = mock.Mock()
         self.loop._selector = mock.Mock()
index 5a92b1e34a583ffd6ee06b0cf11b09931ad7e025..4dfc61259f873120cdb90eeaa38004ec488a2d4f 100644 (file)
@@ -24,6 +24,7 @@ def close_transport(transport):
 class ProactorSocketTransportTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.addCleanup(self.loop.close)
         self.proactor = mock.Mock()
@@ -436,6 +437,8 @@ class ProactorSocketTransportTests(test_utils.TestCase):
 class BaseProactorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
+
         self.sock = test_utils.mock_nonblocking_socket()
         self.proactor = mock.Mock()
 
index 591a9bb53516af6a84356c56be9b1a35093c8bd2..fe5a6dbfe34ecc019ee6d23389569682e546645d 100644 (file)
@@ -10,6 +10,7 @@ from asyncio import test_utils
 class _QueueTestBase(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
 
 
index 07de640c0b73d83a61087c71a97d4c9074c45651..6bf7862ecf9c0a0cdece16646812e42a3d7bd10a 100644 (file)
@@ -51,6 +51,7 @@ def close_transport(transport):
 class BaseSelectorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.selector = mock.Mock()
         self.selector.select.return_value = []
         self.loop = TestBaseSelectorEventLoop(self.selector)
@@ -698,6 +699,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
 class SelectorTransportTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.sock = mock.Mock(socket.socket)
@@ -793,6 +795,7 @@ class SelectorTransportTests(test_utils.TestCase):
 class SelectorSocketTransportTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.sock = mock.Mock(socket.socket)
@@ -1141,6 +1144,7 @@ class SelectorSocketTransportTests(test_utils.TestCase):
 class SelectorSslTransportTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.sock = mock.Mock(socket.socket)
@@ -1501,6 +1505,7 @@ class SelectorSslWithoutSslTransportTests(unittest.TestCase):
 class SelectorDatagramTransportTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
         self.sock = mock.Mock(spec_set=socket.socket)
index 7dfa6c2c63b2be27818b08130d194f075f4dd395..0ca6d1bf2aabb1506be807517c17733220e3a8ef 100644 (file)
@@ -18,6 +18,7 @@ from asyncio import test_utils
 class SslProtoHandshakeTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.new_event_loop()
         self.set_event_loop(self.loop)
 
index 35557c3ce4a6f8ac9ec9de0b85384feb0dc46688..b47433a4cfdbd41d11cf1bf59963a6f7f14fde65 100644 (file)
@@ -22,6 +22,7 @@ class StreamReaderTests(test_utils.TestCase):
     DATA = b'line1\nline2\nline3\n'
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.new_event_loop()
         self.set_event_loop(self.loop)
 
index 15310238727d7399501fbf0d9565d25a615954d3..bba688bb5a53c7eb1f35a861c1ba50f844a4fb9c 100644 (file)
@@ -35,6 +35,7 @@ class TestSubprocessTransport(base_subprocess.BaseSubprocessTransport):
 
 class SubprocessTransportTests(test_utils.TestCase):
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.set_event_loop(self.loop)
 
@@ -466,6 +467,7 @@ if sys.platform != 'win32':
         Watcher = None
 
         def setUp(self):
+            super().setUp()
             policy = asyncio.get_event_loop_policy()
             self.loop = policy.new_event_loop()
             self.set_event_loop(self.loop)
@@ -490,6 +492,7 @@ else:
     class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
 
         def setUp(self):
+            super().setUp()
             self.loop = asyncio.ProactorEventLoop()
             self.set_event_loop(self.loop)
 
index 1ceb9b28cb38ef901d2f7fd2de866eab130d4bba..22accf5d1edaa1d94a2901c86eb05af2024e33fe 100644 (file)
@@ -75,6 +75,7 @@ class Dummy:
 class TaskTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
 
     def test_other_loop_future(self):
@@ -1933,6 +1934,7 @@ class TaskTests(test_utils.TestCase):
 class GatherTestsBase:
 
     def setUp(self):
+        super().setUp()
         self.one_loop = self.new_test_loop()
         self.other_loop = self.new_test_loop()
         self.set_event_loop(self.one_loop, cleanup=False)
@@ -2216,6 +2218,7 @@ class RunCoroutineThreadsafeTests(test_utils.TestCase):
     """Test case for asyncio.run_coroutine_threadsafe."""
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.new_event_loop()
         self.set_event_loop(self.loop) # Will cleanup properly
 
@@ -2306,12 +2309,14 @@ class RunCoroutineThreadsafeTests(test_utils.TestCase):
 
 class SleepTests(test_utils.TestCase):
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.new_event_loop()
         asyncio.set_event_loop(None)
 
     def tearDown(self):
         self.loop.close()
         self.loop = None
+        super().tearDown()
 
     def test_sleep_zero(self):
         result = 0
index ce897ed6bdd03f4a5ee7c2516ee6df2959f80bc1..83a035edee48e04a4122b7bc57de8c79b963781f 100644 (file)
@@ -40,6 +40,7 @@ def close_pipe_transport(transport):
 class SelectorEventLoopSignalTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.SelectorEventLoop()
         self.set_event_loop(self.loop)
 
@@ -234,6 +235,7 @@ class SelectorEventLoopSignalTests(test_utils.TestCase):
 class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.SelectorEventLoop()
         self.set_event_loop(self.loop)
 
@@ -338,6 +340,7 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
 class UnixReadPipeTransportTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.pipe = mock.Mock(spec_set=io.RawIOBase)
@@ -487,6 +490,7 @@ class UnixReadPipeTransportTests(test_utils.TestCase):
 class UnixWritePipeTransportTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
         self.pipe = mock.Mock(spec_set=io.RawIOBase)
@@ -805,6 +809,7 @@ class ChildWatcherTestsMixin:
     ignore_warnings = mock.patch.object(log.logger, "warning")
 
     def setUp(self):
+        super().setUp()
         self.loop = self.new_test_loop()
         self.running = False
         self.zombies = {}
index 7fcf4023eecdcbbf7e7ceffa6f8aa47eea1101e3..1afcae107b0dbef8b443dcb0db1359e608b56b7b 100644 (file)
@@ -31,6 +31,7 @@ class UpperProto(asyncio.Protocol):
 class ProactorTests(test_utils.TestCase):
 
     def setUp(self):
+        super().setUp()
         self.loop = asyncio.ProactorEventLoop()
         self.set_event_loop(self.loop)
 
index 5c936bd483eaca2eeb574978fec36b38a9c4e993..3e2b34f41c2855bdeeea780dc3b2105d887a6560 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -443,6 +443,9 @@ Library
 
 - Issue #28600: Optimize loop.call_soon().
 
+- Issue #28613: Fix get_event_loop() return the current loop if 
+  called from coroutines/callbacks.
+
 IDLE
 ----