]> granicus.if.org Git - python/commitdiff
asyncio: Refactor tests: add a base TestCase class
authorVictor Stinner <victor.stinner@gmail.com>
Tue, 17 Jun 2014 23:36:32 +0000 (01:36 +0200)
committerVictor Stinner <victor.stinner@gmail.com>
Tue, 17 Jun 2014 23:36:32 +0000 (01:36 +0200)
13 files changed:
Lib/asyncio/test_utils.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_events.py
Lib/test/test_asyncio/test_futures.py
Lib/test/test_asyncio/test_locks.py
Lib/test/test_asyncio/test_proactor_events.py
Lib/test/test_asyncio/test_queues.py
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_streams.py
Lib/test/test_asyncio/test_subprocess.py
Lib/test/test_asyncio/test_tasks.py
Lib/test/test_asyncio/test_unix_events.py
Lib/test/test_asyncio/test_windows_events.py

index 1062bae132234a69ae6fd7506534b838c461fec5..d9c7ae2d114d11fe1ef6233a29d0e2430000716b 100644 (file)
@@ -11,6 +11,7 @@ import sys
 import tempfile
 import threading
 import time
+import unittest
 from unittest import mock
 
 from http.server import HTTPServer
@@ -379,3 +380,20 @@ def get_function_source(func):
     if source is None:
         raise ValueError("unable to get the source of %r" % (func,))
     return source
+
+
+class TestCase(unittest.TestCase):
+    def set_event_loop(self, loop, *, cleanup=True):
+        assert loop is not None
+        # ensure that the event loop is passed explicitly in asyncio
+        events.set_event_loop(None)
+        if cleanup:
+            self.addCleanup(loop.close)
+
+    def new_test_loop(self, gen=None):
+        loop = TestLoop(gen)
+        self.set_event_loop(loop)
+        return loop
+
+    def tearDown(self):
+        events.set_event_loop(None)
index fb28b87e715ce05607155c5e228ae48a28dccd09..059b41c329ccfb47c77041bd3800a3157907176f 100644 (file)
@@ -19,12 +19,12 @@ MOCK_ANY = mock.ANY
 PY34 = sys.version_info >= (3, 4)
 
 
-class BaseEventLoopTests(unittest.TestCase):
+class BaseEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
         self.loop = base_events.BaseEventLoop()
         self.loop._selector = mock.Mock()
-        asyncio.set_event_loop(None)
+        self.set_event_loop(self.loop)
 
     def test_not_implemented(self):
         m = mock.Mock()
@@ -548,14 +548,11 @@ class MyDatagramProto(asyncio.DatagramProtocol):
             self.done.set_result(None)
 
 
-class BaseEventLoopWithSelectorTests(unittest.TestCase):
+class BaseEventLoopWithSelectorTests(test_utils.TestCase):
 
     def setUp(self):
         self.loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.set_event_loop(self.loop)
 
     @mock.patch('asyncio.base_events.socket')
     def test_create_connection_multiple_errors(self, m_socket):
index 2262a75226b461935fc213273086e46354cee520..37e45e1df70842859c237253387b3b48ce677493 100644 (file)
@@ -224,7 +224,7 @@ class EventLoopTestsMixin:
     def setUp(self):
         super().setUp()
         self.loop = self.create_event_loop()
-        asyncio.set_event_loop(None)
+        self.set_event_loop(self.loop)
 
     def tearDown(self):
         # just in case if we have transport close callbacks
@@ -1629,14 +1629,14 @@ class SubprocessTestsMixin:
 
 if sys.platform == 'win32':
 
-    class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+    class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase):
 
         def create_event_loop(self):
             return asyncio.SelectorEventLoop()
 
     class ProactorEventLoopTests(EventLoopTestsMixin,
                                  SubprocessTestsMixin,
-                                 unittest.TestCase):
+                                 test_utils.TestCase):
 
         def create_event_loop(self):
             return asyncio.ProactorEventLoop()
@@ -1691,7 +1691,7 @@ else:
     if hasattr(selectors, 'KqueueSelector'):
         class KqueueEventLoopTests(UnixEventLoopTestsMixin,
                                    SubprocessTestsMixin,
-                                   unittest.TestCase):
+                                   test_utils.TestCase):
 
             def create_event_loop(self):
                 return asyncio.SelectorEventLoop(
@@ -1716,7 +1716,7 @@ else:
     if hasattr(selectors, 'EpollSelector'):
         class EPollEventLoopTests(UnixEventLoopTestsMixin,
                                   SubprocessTestsMixin,
-                                  unittest.TestCase):
+                                  test_utils.TestCase):
 
             def create_event_loop(self):
                 return asyncio.SelectorEventLoop(selectors.EpollSelector())
@@ -1724,7 +1724,7 @@ else:
     if hasattr(selectors, 'PollSelector'):
         class PollEventLoopTests(UnixEventLoopTestsMixin,
                                  SubprocessTestsMixin,
-                                 unittest.TestCase):
+                                 test_utils.TestCase):
 
             def create_event_loop(self):
                 return asyncio.SelectorEventLoop(selectors.PollSelector())
@@ -1732,7 +1732,7 @@ else:
     # Should always exist.
     class SelectEventLoopTests(UnixEventLoopTestsMixin,
                                SubprocessTestsMixin,
-                               unittest.TestCase):
+                               test_utils.TestCase):
 
         def create_event_loop(self):
             return asyncio.SelectorEventLoop(selectors.SelectSelector())
index 399e8f438e1a65b45b0cebe6953afffa6e0ebb47..a230d614761226d517bd63ddb14a88bd0fc1e126 100644 (file)
@@ -13,14 +13,10 @@ def _fakefunc(f):
     return f
 
 
-class FutureTests(unittest.TestCase):
+class FutureTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.loop = self.new_test_loop()
 
     def test_initial_state(self):
         f = asyncio.Future(loop=self.loop)
@@ -30,12 +26,9 @@ class FutureTests(unittest.TestCase):
         self.assertTrue(f.cancelled())
 
     def test_init_constructor_default_loop(self):
-        try:
-            asyncio.set_event_loop(self.loop)
-            f = asyncio.Future()
-            self.assertIs(f._loop, self.loop)
-        finally:
-            asyncio.set_event_loop(None)
+        asyncio.set_event_loop(self.loop)
+        f = asyncio.Future()
+        self.assertIs(f._loop, self.loop)
 
     def test_constructor_positional(self):
         # Make sure Future doesn't accept a positional argument
@@ -264,14 +257,10 @@ class FutureTests(unittest.TestCase):
         self.assertTrue(f2.cancelled())
 
 
-class FutureDoneCallbackTests(unittest.TestCase):
+class FutureDoneCallbackTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.loop = self.new_test_loop()
 
     def run_briefly(self):
         test_utils.run_briefly(self.loop)
index f542463ad23cb988fd101533368ac21706c68da2..9d50a71f4e18526410ab9479444d63c44ccd9c80 100644 (file)
@@ -17,14 +17,10 @@ STR_RGX_REPR = (
 RGX_REPR = re.compile(STR_RGX_REPR)
 
 
-class LockTests(unittest.TestCase):
+class LockTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
         loop = mock.Mock()
@@ -35,12 +31,9 @@ class LockTests(unittest.TestCase):
         self.assertIs(lock._loop, self.loop)
 
     def test_ctor_noloop(self):
-        try:
-            asyncio.set_event_loop(self.loop)
-            lock = asyncio.Lock()
-            self.assertIs(lock._loop, self.loop)
-        finally:
-            asyncio.set_event_loop(None)
+        asyncio.set_event_loop(self.loop)
+        lock = asyncio.Lock()
+        self.assertIs(lock._loop, self.loop)
 
     def test_repr(self):
         lock = asyncio.Lock(loop=self.loop)
@@ -240,14 +233,10 @@ class LockTests(unittest.TestCase):
         self.assertFalse(lock.locked())
 
 
-class EventTests(unittest.TestCase):
+class EventTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
         loop = mock.Mock()
@@ -258,12 +247,9 @@ class EventTests(unittest.TestCase):
         self.assertIs(ev._loop, self.loop)
 
     def test_ctor_noloop(self):
-        try:
-            asyncio.set_event_loop(self.loop)
-            ev = asyncio.Event()
-            self.assertIs(ev._loop, self.loop)
-        finally:
-            asyncio.set_event_loop(None)
+        asyncio.set_event_loop(self.loop)
+        ev = asyncio.Event()
+        self.assertIs(ev._loop, self.loop)
 
     def test_repr(self):
         ev = asyncio.Event(loop=self.loop)
@@ -376,14 +362,10 @@ class EventTests(unittest.TestCase):
         self.assertTrue(t.result())
 
 
-class ConditionTests(unittest.TestCase):
+class ConditionTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
         loop = mock.Mock()
@@ -394,12 +376,9 @@ class ConditionTests(unittest.TestCase):
         self.assertIs(cond._loop, self.loop)
 
     def test_ctor_noloop(self):
-        try:
-            asyncio.set_event_loop(self.loop)
-            cond = asyncio.Condition()
-            self.assertIs(cond._loop, self.loop)
-        finally:
-            asyncio.set_event_loop(None)
+        asyncio.set_event_loop(self.loop)
+        cond = asyncio.Condition()
+        self.assertIs(cond._loop, self.loop)
 
     def test_wait(self):
         cond = asyncio.Condition(loop=self.loop)
@@ -678,14 +657,10 @@ class ConditionTests(unittest.TestCase):
         self.assertFalse(cond.locked())
 
 
-class SemaphoreTests(unittest.TestCase):
+class SemaphoreTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.loop = self.new_test_loop()
 
     def test_ctor_loop(self):
         loop = mock.Mock()
@@ -696,12 +671,9 @@ class SemaphoreTests(unittest.TestCase):
         self.assertIs(sem._loop, self.loop)
 
     def test_ctor_noloop(self):
-        try:
-            asyncio.set_event_loop(self.loop)
-            sem = asyncio.Semaphore()
-            self.assertIs(sem._loop, self.loop)
-        finally:
-            asyncio.set_event_loop(None)
+        asyncio.set_event_loop(self.loop)
+        sem = asyncio.Semaphore()
+        self.assertIs(sem._loop, self.loop)
 
     def test_initial_value_zero(self):
         sem = asyncio.Semaphore(0, loop=self.loop)
index 5bf24a4503d10a014c5da299ef58092aca583e87..ddfceae14a66a045cd1f16493b83791110cef9fd 100644 (file)
@@ -12,10 +12,10 @@ from asyncio.proactor_events import _ProactorDuplexPipeTransport
 from asyncio import test_utils
 
 
-class ProactorSocketTransportTests(unittest.TestCase):
+class ProactorSocketTransportTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.proactor = mock.Mock()
         self.loop._proactor = self.proactor
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
@@ -343,7 +343,7 @@ class ProactorSocketTransportTests(unittest.TestCase):
         tr.close()
 
 
-class BaseProactorEventLoopTests(unittest.TestCase):
+class BaseProactorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
         self.sock = mock.Mock(socket.socket)
@@ -356,6 +356,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
                 return (self.ssock, self.csock)
 
         self.loop = EventLoop(self.proactor)
+        self.set_event_loop(self.loop, cleanup=False)
 
     @mock.patch.object(BaseProactorEventLoop, 'call_soon')
     @mock.patch.object(BaseProactorEventLoop, '_socketpair')
index 820234dfe1df2cf5eab807917ebea0b2196a25df..32c90f4737e87867cc98c4743a0a688af870f527 100644 (file)
@@ -7,14 +7,10 @@ import asyncio
 from asyncio import test_utils
 
 
-class _QueueTestBase(unittest.TestCase):
+class _QueueTestBase(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.loop = self.new_test_loop()
 
 
 class QueueBasicTests(_QueueTestBase):
@@ -32,8 +28,7 @@ class QueueBasicTests(_QueueTestBase):
             self.assertAlmostEqual(0.2, when)
             yield 0.1
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         q = asyncio.Queue(loop=loop)
         self.assertTrue(fn(q).startswith('<Queue'), fn(q))
@@ -80,12 +75,9 @@ class QueueBasicTests(_QueueTestBase):
         self.assertIs(q._loop, self.loop)
 
     def test_ctor_noloop(self):
-        try:
-            asyncio.set_event_loop(self.loop)
-            q = asyncio.Queue()
-            self.assertIs(q._loop, self.loop)
-        finally:
-            asyncio.set_event_loop(None)
+        asyncio.set_event_loop(self.loop)
+        q = asyncio.Queue()
+        self.assertIs(q._loop, self.loop)
 
     def test_repr(self):
         self._test_repr_or_str(repr, True)
@@ -126,8 +118,7 @@ class QueueBasicTests(_QueueTestBase):
             self.assertAlmostEqual(0.02, when)
             yield 0.01
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         q = asyncio.Queue(maxsize=2, loop=loop)
         self.assertEqual(2, q.maxsize)
@@ -194,8 +185,7 @@ class QueueGetTests(_QueueTestBase):
             self.assertAlmostEqual(0.01, when)
             yield 0.01
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         q = asyncio.Queue(loop=loop)
         started = asyncio.Event(loop=loop)
@@ -241,8 +231,7 @@ class QueueGetTests(_QueueTestBase):
             self.assertAlmostEqual(0.061, when)
             yield 0.05
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         q = asyncio.Queue(loop=loop)
 
@@ -302,8 +291,7 @@ class QueuePutTests(_QueueTestBase):
             self.assertAlmostEqual(0.01, when)
             yield 0.01
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         q = asyncio.Queue(maxsize=1, loop=loop)
         started = asyncio.Event(loop=loop)
index 36f65085bf5c5d0addaac7a2bcb64b2f393fda50..b1148d2e36d8a6ad399d2cf774205a3eb57eb85f 100644 (file)
@@ -37,11 +37,12 @@ def list_to_buffer(l=()):
     return bytearray().join(l)
 
 
-class BaseSelectorEventLoopTests(unittest.TestCase):
+class BaseSelectorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
         selector = mock.Mock()
         self.loop = TestBaseSelectorEventLoop(selector)
+        self.set_event_loop(self.loop, cleanup=False)
 
     def test_make_socket_transport(self):
         m = mock.Mock()
@@ -597,10 +598,10 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
         self.loop.remove_writer.assert_called_with(1)
 
 
-class SelectorTransportTests(unittest.TestCase):
+class SelectorTransportTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.sock = mock.Mock(socket.socket)
         self.sock.fileno.return_value = 7
@@ -684,14 +685,14 @@ class SelectorTransportTests(unittest.TestCase):
         self.assertEqual(2, sys.getrefcount(self.protocol),
                          pprint.pformat(gc.get_referrers(self.protocol)))
         self.assertIsNone(tr._loop)
-        self.assertEqual(2, sys.getrefcount(self.loop),
+        self.assertEqual(3, sys.getrefcount(self.loop),
                          pprint.pformat(gc.get_referrers(self.loop)))
 
 
-class SelectorSocketTransportTests(unittest.TestCase):
+class SelectorSocketTransportTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.sock = mock.Mock(socket.socket)
         self.sock_fd = self.sock.fileno.return_value = 7
@@ -1061,10 +1062,10 @@ class SelectorSocketTransportTests(unittest.TestCase):
 
 
 @unittest.skipIf(ssl is None, 'No ssl module')
-class SelectorSslTransportTests(unittest.TestCase):
+class SelectorSslTransportTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.sock = mock.Mock(socket.socket)
         self.sock.fileno.return_value = 7
@@ -1396,10 +1397,10 @@ class SelectorSslWithoutSslTransportTests(unittest.TestCase):
             _SelectorSslTransport(Mock(), Mock(), Mock(), Mock())
 
 
-class SelectorDatagramTransportTests(unittest.TestCase):
+class SelectorDatagramTransportTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
         self.sock = mock.Mock(spec_set=socket.socket)
         self.sock.fileno.return_value = 7
index 1ecc8eb1fa0d235a9f426af95e1e6275d330dae9..73a375aba42d2fbfe7ce3673f441e31add69ea50 100644 (file)
@@ -15,13 +15,13 @@ import asyncio
 from asyncio import test_utils
 
 
-class StreamReaderTests(unittest.TestCase):
+class StreamReaderTests(test_utils.TestCase):
 
     DATA = b'line1\nline2\nline3\n'
 
     def setUp(self):
         self.loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(None)
+        self.set_event_loop(self.loop)
 
     def tearDown(self):
         # just in case if we have transport close callbacks
@@ -29,6 +29,7 @@ class StreamReaderTests(unittest.TestCase):
 
         self.loop.close()
         gc.collect()
+        super().tearDown()
 
     @mock.patch('asyncio.streams.events')
     def test_ctor_global_loop(self, m_events):
index 14fd17e61a421fd049f39b8d7f31080a948fd0b1..3b962bf9bdcc37cd4ac8aa911e767180a9fad7ac 100644 (file)
@@ -1,4 +1,5 @@
 from asyncio import subprocess
+from asyncio import test_utils
 import asyncio
 import signal
 import sys
@@ -151,21 +152,21 @@ if sys.platform != 'win32':
             policy = asyncio.get_event_loop_policy()
             policy.set_child_watcher(None)
             self.loop.close()
-            policy.set_event_loop(None)
+            super().tearDown()
 
     class SubprocessSafeWatcherTests(SubprocessWatcherMixin,
-                                     unittest.TestCase):
+                                     test_utils.TestCase):
 
         Watcher = unix_events.SafeChildWatcher
 
     class SubprocessFastWatcherTests(SubprocessWatcherMixin,
-                                     unittest.TestCase):
+                                     test_utils.TestCase):
 
         Watcher = unix_events.FastChildWatcher
 
 else:
     # Windows
-    class SubprocessProactorTests(SubprocessMixin, unittest.TestCase):
+    class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
 
         def setUp(self):
             policy = asyncio.get_event_loop_policy()
@@ -178,6 +179,7 @@ else:
             policy = asyncio.get_event_loop_policy()
             self.loop.close()
             policy.set_event_loop(None)
+            super().tearDown()
 
 
 if __name__ == '__main__':
index dcc81234d67d3b92d8437c3d362425aa297aadfb..0ed2f941b55182010adc0c15c3034b146ce42d5f 100644 (file)
@@ -30,15 +30,10 @@ class Dummy:
         pass
 
 
-class TaskTests(unittest.TestCase):
+class TaskTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
-        gc.collect()
+        self.loop = self.new_test_loop()
 
     def test_task_class(self):
         @asyncio.coroutine
@@ -51,6 +46,7 @@ class TaskTests(unittest.TestCase):
         self.assertIs(t._loop, self.loop)
 
         loop = asyncio.new_event_loop()
+        self.set_event_loop(loop)
         t = asyncio.Task(notmuch(), loop=loop)
         self.assertIs(t._loop, loop)
         loop.close()
@@ -66,6 +62,7 @@ class TaskTests(unittest.TestCase):
         self.assertIs(t._loop, self.loop)
 
         loop = asyncio.new_event_loop()
+        self.set_event_loop(loop)
         t = asyncio.async(notmuch(), loop=loop)
         self.assertIs(t._loop, loop)
         loop.close()
@@ -81,6 +78,7 @@ class TaskTests(unittest.TestCase):
         self.assertIs(f, f_orig)
 
         loop = asyncio.new_event_loop()
+        self.set_event_loop(loop)
 
         with self.assertRaises(ValueError):
             f = asyncio.async(f_orig, loop=loop)
@@ -102,6 +100,7 @@ class TaskTests(unittest.TestCase):
         self.assertIs(t, t_orig)
 
         loop = asyncio.new_event_loop()
+        self.set_event_loop(loop)
 
         with self.assertRaises(ValueError):
             t = asyncio.async(t_orig, loop=loop)
@@ -220,8 +219,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(10.0, when)
             yield 0
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         @asyncio.coroutine
         def task():
@@ -346,7 +344,7 @@ class TaskTests(unittest.TestCase):
 
     def test_cancel_current_task(self):
         loop = asyncio.new_event_loop()
-        self.addCleanup(loop.close)
+        self.set_event_loop(loop)
 
         @asyncio.coroutine
         def task():
@@ -374,8 +372,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.3, when)
             yield 0.1
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         x = 0
         waiters = []
@@ -410,8 +407,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.1, when)
             when = yield 0.1
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         foo_running = None
 
@@ -436,8 +432,7 @@ class TaskTests(unittest.TestCase):
         self.assertEqual(foo_running, False)
 
     def test_wait_for_blocking(self):
-        loop = test_utils.TestLoop()
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop()
 
         @asyncio.coroutine
         def coro():
@@ -457,8 +452,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.01, when)
             yield 0.01
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         @asyncio.coroutine
         def foo():
@@ -486,8 +480,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.15, when)
             yield 0.15
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
         b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@@ -517,8 +510,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.015, when)
             yield 0.015
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop)
         b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop)
@@ -531,11 +523,8 @@ class TaskTests(unittest.TestCase):
             return 42
 
         asyncio.set_event_loop(loop)
-        try:
-            res = loop.run_until_complete(
-                asyncio.Task(foo(), loop=loop))
-        finally:
-            asyncio.set_event_loop(None)
+        res = loop.run_until_complete(
+            asyncio.Task(foo(), loop=loop))
 
         self.assertEqual(res, 42)
 
@@ -573,8 +562,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.1, when)
             yield 0.1
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
         b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
@@ -629,8 +617,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(10.0, when)
             yield 0
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         # first_exception, task already has exception
         a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
@@ -663,8 +650,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.01, when)
             yield 0.01
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         # first_exception, exception during waiting
         a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop)
@@ -696,8 +682,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.15, when)
             yield 0.15
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
 
@@ -733,8 +718,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.11, when)
             yield 0.11
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
         b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@@ -764,8 +748,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.1, when)
             yield 0.1
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop)
         b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop)
@@ -789,8 +772,7 @@ class TaskTests(unittest.TestCase):
             yield 0.01
             yield 0
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
         completed = set()
         time_shifted = False
 
@@ -833,8 +815,7 @@ class TaskTests(unittest.TestCase):
             yield 0
             yield 0.1
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.sleep(0.1, 'a', loop=loop)
         b = asyncio.sleep(0.15, 'b', loop=loop)
@@ -870,8 +851,7 @@ class TaskTests(unittest.TestCase):
             yield 0
             yield 0.01
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.sleep(0.01, 'a', loop=loop)
 
@@ -890,8 +870,7 @@ class TaskTests(unittest.TestCase):
             yield 0.05
             yield 0
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.sleep(0.05, 'a', loop=loop)
         b = asyncio.sleep(0.10, 'b', loop=loop)
@@ -916,8 +895,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.05, when)
             yield 0.05
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         a = asyncio.sleep(0.05, 'a', loop=loop)
         b = asyncio.sleep(0.05, 'b', loop=loop)
@@ -958,8 +936,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(0.1, when)
             yield 0.05
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         @asyncio.coroutine
         def sleeper(dt, arg):
@@ -980,8 +957,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(10.0, when)
             yield 0
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop),
                          loop=loop)
@@ -1012,8 +988,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(5000, when)
             yield 0.1
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         @asyncio.coroutine
         def sleep(dt):
@@ -1123,8 +1098,7 @@ class TaskTests(unittest.TestCase):
             self.assertAlmostEqual(10.0, when)
             yield 0
 
-        loop = test_utils.TestLoop(gen)
-        self.addCleanup(loop.close)
+        loop = self.new_test_loop(gen)
 
         @asyncio.coroutine
         def sleeper():
@@ -1536,12 +1510,9 @@ class TaskTests(unittest.TestCase):
 class GatherTestsBase:
 
     def setUp(self):
-        self.one_loop = test_utils.TestLoop()
-        self.other_loop = test_utils.TestLoop()
-
-    def tearDown(self):
-        self.one_loop.close()
-        self.other_loop.close()
+        self.one_loop = self.new_test_loop()
+        self.other_loop = self.new_test_loop()
+        self.set_event_loop(self.one_loop, cleanup=False)
 
     def _run_loop(self, loop):
         while loop._ready:
@@ -1633,7 +1604,7 @@ class GatherTestsBase:
         self.assertEqual(stdout.rstrip(), b'False')
 
 
-class FutureGatherTests(GatherTestsBase, unittest.TestCase):
+class FutureGatherTests(GatherTestsBase, test_utils.TestCase):
 
     def wrap_futures(self, *futures):
         return futures
@@ -1717,16 +1688,12 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
         cb.assert_called_once_with(fut)
 
 
-class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
+class CoroutineGatherTests(GatherTestsBase, test_utils.TestCase):
 
     def setUp(self):
         super().setUp()
         asyncio.set_event_loop(self.one_loop)
 
-    def tearDown(self):
-        asyncio.set_event_loop(None)
-        super().tearDown()
-
     def wrap_futures(self, *futures):
         coros = []
         for fut in futures:
index cec7a110a3c010596b04f9aa2d405381e9262b14..89a4c10368eff6dffe196f273b9491aa41eb5995 100644 (file)
@@ -29,14 +29,11 @@ MOCK_ANY = mock.ANY
 
 
 @unittest.skipUnless(signal, 'Signals are not supported')
-class SelectorEventLoopSignalTests(unittest.TestCase):
+class SelectorEventLoopSignalTests(test_utils.TestCase):
 
     def setUp(self):
         self.loop = asyncio.SelectorEventLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.set_event_loop(self.loop)
 
     def test_check_signal(self):
         self.assertRaises(
@@ -208,14 +205,11 @@ class SelectorEventLoopSignalTests(unittest.TestCase):
 
 @unittest.skipUnless(hasattr(socket, 'AF_UNIX'),
                      'UNIX Sockets are not supported')
-class SelectorEventLoopUnixSocketTests(unittest.TestCase):
+class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
 
     def setUp(self):
         self.loop = asyncio.SelectorEventLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
+        self.set_event_loop(self.loop)
 
     def test_create_unix_server_existing_path_sock(self):
         with test_utils.unix_socket_path() as path:
@@ -304,10 +298,10 @@ class SelectorEventLoopUnixSocketTests(unittest.TestCase):
             self.loop.run_until_complete(coro)
 
 
-class UnixReadPipeTransportTests(unittest.TestCase):
+class UnixReadPipeTransportTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
         self.pipe = mock.Mock(spec_set=io.RawIOBase)
         self.pipe.fileno.return_value = 5
@@ -451,7 +445,7 @@ class UnixReadPipeTransportTests(unittest.TestCase):
         self.assertEqual(2, sys.getrefcount(self.protocol),
                          pprint.pformat(gc.get_referrers(self.protocol)))
         self.assertIsNone(tr._loop)
-        self.assertEqual(4, sys.getrefcount(self.loop),
+        self.assertEqual(5, sys.getrefcount(self.loop),
                          pprint.pformat(gc.get_referrers(self.loop)))
 
     def test__call_connection_lost_with_err(self):
@@ -468,14 +462,14 @@ class UnixReadPipeTransportTests(unittest.TestCase):
         self.assertEqual(2, sys.getrefcount(self.protocol),
                          pprint.pformat(gc.get_referrers(self.protocol)))
         self.assertIsNone(tr._loop)
-        self.assertEqual(4, sys.getrefcount(self.loop),
+        self.assertEqual(5, sys.getrefcount(self.loop),
                          pprint.pformat(gc.get_referrers(self.loop)))
 
 
-class UnixWritePipeTransportTests(unittest.TestCase):
+class UnixWritePipeTransportTests(test_utils.TestCase):
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
         self.pipe = mock.Mock(spec_set=io.RawIOBase)
         self.pipe.fileno.return_value = 5
@@ -737,7 +731,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
         self.assertEqual(2, sys.getrefcount(self.protocol),
                          pprint.pformat(gc.get_referrers(self.protocol)))
         self.assertIsNone(tr._loop)
-        self.assertEqual(4, sys.getrefcount(self.loop),
+        self.assertEqual(5, sys.getrefcount(self.loop),
                          pprint.pformat(gc.get_referrers(self.loop)))
 
     def test__call_connection_lost_with_err(self):
@@ -753,7 +747,7 @@ class UnixWritePipeTransportTests(unittest.TestCase):
         self.assertEqual(2, sys.getrefcount(self.protocol),
                          pprint.pformat(gc.get_referrers(self.protocol)))
         self.assertIsNone(tr._loop)
-        self.assertEqual(4, sys.getrefcount(self.loop),
+        self.assertEqual(5, sys.getrefcount(self.loop),
                          pprint.pformat(gc.get_referrers(self.loop)))
 
     def test_close(self):
@@ -834,7 +828,7 @@ class ChildWatcherTestsMixin:
     ignore_warnings = mock.patch.object(log.logger, "warning")
 
     def setUp(self):
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         self.running = False
         self.zombies = {}
 
@@ -1392,7 +1386,7 @@ class ChildWatcherTestsMixin:
 
         # attach a new loop
         old_loop = self.loop
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
         patch = mock.patch.object
 
         with patch(old_loop, "remove_signal_handler") as m_old_remove, \
@@ -1447,7 +1441,7 @@ class ChildWatcherTestsMixin:
         self.assertFalse(callback3.called)
 
         # attach a new loop
-        self.loop = test_utils.TestLoop()
+        self.loop = self.new_test_loop()
 
         with mock.patch.object(
                 self.loop, "add_signal_handler") as m_add_signal_handler:
@@ -1505,12 +1499,12 @@ class ChildWatcherTestsMixin:
                     self.assertFalse(self.watcher._zombies)
 
 
-class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
+class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
     def create_watcher(self):
         return asyncio.SafeChildWatcher()
 
 
-class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase):
+class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase):
     def create_watcher(self):
         return asyncio.FastChildWatcher()
 
index f65225863972699d3a38d843aa4432af09f618b6..ca79c437c7bd9bdb41990c2b1ca682578514f3f4 100644 (file)
@@ -26,15 +26,11 @@ class UpperProto(asyncio.Protocol):
             self.trans.close()
 
 
-class ProactorTests(unittest.TestCase):
+class ProactorTests(test_utils.TestCase):
 
     def setUp(self):
         self.loop = asyncio.ProactorEventLoop()
-        asyncio.set_event_loop(None)
-
-    def tearDown(self):
-        self.loop.close()
-        self.loop = None
+        self.set_event_loop(self.loop)
 
     def test_close(self):
         a, b = self.loop._socketpair()