]> granicus.if.org Git - python/commitdiff
asyncio, Tulip issue 126: call_soon(), call_soon_threadsafe(), call_later(),
authorVictor Stinner <victor.stinner@gmail.com>
Tue, 11 Feb 2014 10:34:30 +0000 (11:34 +0100)
committerVictor Stinner <victor.stinner@gmail.com>
Tue, 11 Feb 2014 10:34:30 +0000 (11:34 +0100)
call_at() and run_in_executor() now raise a TypeError if the callback is a
coroutine function.

Lib/asyncio/base_events.py
Lib/asyncio/test_utils.py
Lib/test/test_asyncio/test_base_events.py
Lib/test/test_asyncio/test_proactor_events.py
Lib/test/test_asyncio/test_selector_events.py
Lib/test/test_asyncio/test_tasks.py

index 48b3ee3e9df255e069ae1c1cdb6de208b1384f40..4b7b161ecaabb3011ea9ab63f7dc1c4f959ca779 100644 (file)
@@ -227,6 +227,8 @@ class BaseEventLoop(events.AbstractEventLoop):
 
     def call_at(self, when, callback, *args):
         """Like call_later(), but uses an absolute time."""
+        if tasks.iscoroutinefunction(callback):
+            raise TypeError("coroutines cannot be used with call_at()")
         timer = events.TimerHandle(when, callback, args)
         heapq.heappush(self._scheduled, timer)
         return timer
@@ -241,6 +243,8 @@ class BaseEventLoop(events.AbstractEventLoop):
         Any positional arguments after the callback will be passed to
         the callback when it is called.
         """
+        if tasks.iscoroutinefunction(callback):
+            raise TypeError("coroutines cannot be used with call_soon()")
         handle = events.Handle(callback, args)
         self._ready.append(handle)
         return handle
@@ -252,6 +256,8 @@ class BaseEventLoop(events.AbstractEventLoop):
         return handle
 
     def run_in_executor(self, executor, callback, *args):
+        if tasks.iscoroutinefunction(callback):
+            raise TypeError("coroutines cannot be used with run_in_executor()")
         if isinstance(callback, events.Handle):
             assert not args
             assert not isinstance(callback, events.TimerHandle)
index 7c8e1dcbd6cfdd1c64e74a35bbe5269dd5f3b466..deab7c33122f066fd52d5a756c1894c9fef6d6d2 100644 (file)
@@ -135,7 +135,7 @@ def make_test_protocol(base):
         if name.startswith('__') and name.endswith('__'):
             # skip magic names
             continue
-        dct[name] = unittest.mock.Mock(return_value=None)
+        dct[name] = MockCallback(return_value=None)
     return type('TestProtocol', (base,) + base.__bases__, dct)()
 
 
@@ -274,3 +274,6 @@ class TestLoop(base_events.BaseEventLoop):
 
     def _write_to_self(self):
         pass
+
+def MockCallback(**kwargs):
+    return unittest.mock.Mock(spec=['__call__'], **kwargs)
index 5b05684723dde82040c09aa852f80e3c662b0986..c6950ab3fa8de3e81655b9f203b79e875deae46c 100644 (file)
@@ -567,6 +567,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
 
         m_socket.getaddrinfo.return_value = [
             (2, 1, 6, '', ('127.0.0.1', 10100))]
+        m_socket.getaddrinfo._is_coroutine = False
         m_sock = m_socket.socket.return_value = unittest.mock.Mock()
         m_sock.bind.side_effect = Err
 
@@ -577,6 +578,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
     @unittest.mock.patch('asyncio.base_events.socket')
     def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
         m_socket.getaddrinfo.return_value = []
+        m_socket.getaddrinfo._is_coroutine = False
 
         coro = self.loop.create_datagram_endpoint(
             MyDatagramProto, local_addr=('localhost', 0))
@@ -681,6 +683,22 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase):
                                                 unittest.mock.ANY,
                                                 MyProto, sock, None, None)
 
+    def test_call_coroutine(self):
+        @asyncio.coroutine
+        def coroutine_function():
+            pass
+
+        with self.assertRaises(TypeError):
+            self.loop.call_soon(coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.call_soon_threadsafe(coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.call_later(60, coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.call_at(self.loop.time() + 60, coroutine_function)
+        with self.assertRaises(TypeError):
+            self.loop.run_in_executor(None, coroutine_function)
+
 
 if __name__ == '__main__':
     unittest.main()
index 9964f425d21d9d7a2cd975b004cd35ce5ba0c4fd..6bea1a33685709cca9b2eddc15b7c224a8d78cb8 100644 (file)
@@ -402,7 +402,7 @@ class BaseProactorEventLoopTests(unittest.TestCase):
             NotImplementedError, BaseProactorEventLoop, self.proactor)
 
     def test_make_socket_transport(self):
-        tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
+        tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol())
         self.assertIsInstance(tr, _ProactorSocketTransport)
 
     def test_loop_self_reading(self):
index ad0b0be81e529c5caf25671cde5e1783ef3c31c4..855a8954e868291131f347e102e645fed8782f6d 100644 (file)
@@ -44,8 +44,8 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
     def test_make_socket_transport(self):
         m = unittest.mock.Mock()
         self.loop.add_reader = unittest.mock.Mock()
-        self.assertIsInstance(
-            self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
+        transport = self.loop._make_socket_transport(m, asyncio.Protocol())
+        self.assertIsInstance(transport, _SelectorSocketTransport)
 
     @unittest.skipIf(ssl is None, 'No ssl module')
     def test_make_ssl_transport(self):
@@ -54,8 +54,9 @@ class BaseSelectorEventLoopTests(unittest.TestCase):
         self.loop.add_writer = unittest.mock.Mock()
         self.loop.remove_reader = unittest.mock.Mock()
         self.loop.remove_writer = unittest.mock.Mock()
-        self.assertIsInstance(
-            self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
+        waiter = asyncio.Future(loop=self.loop)
+        transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter)
+        self.assertIsInstance(transport, _SelectorSslTransport)
 
     @unittest.mock.patch('asyncio.selector_events.ssl', None)
     def test_make_ssl_transport_without_ssl_error(self):
index 9abdfa5bc13e0d5f20fcd521e3efc40c92ae7baf..29bdaf5bd4fc2def25deb1eb858e208b65dd9043 100644 (file)
@@ -2,8 +2,6 @@
 
 import gc
 import unittest
-import unittest.mock
-from unittest.mock import Mock
 
 import asyncio
 from asyncio import test_utils
@@ -1358,7 +1356,7 @@ class GatherTestsBase:
     def _check_success(self, **kwargs):
         a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)]
         fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         b.set_result(1)
         a.set_result(2)
@@ -1380,7 +1378,7 @@ class GatherTestsBase:
     def test_one_exception(self):
         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
         fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e))
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         exc = ZeroDivisionError()
         a.set_result(1)
@@ -1399,7 +1397,7 @@ class GatherTestsBase:
         a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)]
         fut = asyncio.gather(*self.wrap_futures(a, b, c, d),
                              return_exceptions=True)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         exc = ZeroDivisionError()
         exc2 = RuntimeError()
@@ -1460,7 +1458,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
     def test_one_cancellation(self):
         a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)]
         fut = asyncio.gather(a, b, c, d, e)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         a.set_result(1)
         b.cancel()
@@ -1479,7 +1477,7 @@ class FutureGatherTests(GatherTestsBase, unittest.TestCase):
         a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop)
                             for i in range(6)]
         fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True)
-        cb = Mock()
+        cb = test_utils.MockCallback()
         fut.add_done_callback(cb)
         a.set_result(1)
         zde = ZeroDivisionError()