]> granicus.if.org Git - python/commitdiff
bpo-38136: Updates await_count and call_count to be different things (GH-16192)
authorLisa Roach <lisaroach14@gmail.com>
Tue, 24 Sep 2019 03:49:40 +0000 (20:49 -0700)
committerGitHub <noreply@github.com>
Tue, 24 Sep 2019 03:49:40 +0000 (20:49 -0700)
Doc/library/unittest.mock.rst
Lib/unittest/mock.py
Lib/unittest/test/testmock/testasync.py
Lib/unittest/test/testmock/testmock.py
Misc/NEWS.d/next/Library/2019-09-16-09-54-42.bpo-38136.MdI-Zb.rst [new file with mode: 0644]

index b446ddb3598f2939432c7f2cf462047898fa9bc2..300f28c8e2cec785976ab575ea5688778c6224d4 100644 (file)
@@ -514,21 +514,6 @@ the *new_callable* argument to :func:`patch`.
             >>> mock.call_count
             2
 
-        For :class:`AsyncMock` the :attr:`call_count` is only iterated if the function
-        has been awaited:
-
-            >>> mock = AsyncMock()
-            >>> mock()  # doctest: +SKIP
-            <coroutine object AsyncMockMixin._mock_call at ...>
-            >>> mock.call_count
-            0
-            >>> async def main():
-            ...     await mock()
-            ...
-            >>> asyncio.run(main())
-            >>> mock.call_count
-            1
-
     .. attribute:: return_value
 
         Set this to configure the value returned by calling the mock:
@@ -907,19 +892,22 @@ object::
 
   .. method:: assert_awaited()
 
-      Assert that the mock was awaited at least once.
+      Assert that the mock was awaited at least once. Note that this is separate
+      from the object having been called, the ``await`` keyword must be used:
 
           >>> mock = AsyncMock()
-          >>> async def main():
-          ...     await mock()
+          >>> async def main(coroutine_mock):
+          ...     await coroutine_mock
           ...
-          >>> asyncio.run(main())
+          >>> coroutine_mock = mock()
+          >>> mock.called
+          True
           >>> mock.assert_awaited()
-          >>> mock_2 = AsyncMock()
-          >>> mock_2.assert_awaited()
           Traceback (most recent call last):
           ...
           AssertionError: Expected mock to have been awaited.
+          >>> asyncio.run(main(coroutine_mock))
+          >>> mock.assert_awaited()
 
   .. method:: assert_awaited_once()
 
@@ -1004,14 +992,15 @@ object::
         ...     await mock(*args, **kwargs)
         ...
         >>> calls = [call("foo"), call("bar")]
-        >>> mock.assert_has_calls(calls)
+        >>> mock.assert_has_awaits(calls)
         Traceback (most recent call last):
         ...
-        AssertionError: Calls not found.
+        AssertionError: Awaits not found.
         Expected: [call('foo'), call('bar')]
+        Actual: []
         >>> asyncio.run(main('foo'))
         >>> asyncio.run(main('bar'))
-        >>> mock.assert_has_calls(calls)
+        >>> mock.assert_has_awaits(calls)
 
   .. method:: assert_not_awaited()
 
index 0a16e26f1d8a14895b2283beec2860defb03584c..22d63a45884a65213942230b1793714603209f61 100644 (file)
@@ -1076,14 +1076,20 @@ class CallableMixin(Base):
         # can't use self in-case a function / method we are mocking uses self
         # in the signature
         self._mock_check_sig(*args, **kwargs)
+        self._increment_mock_call(*args, **kwargs)
         return self._mock_call(*args, **kwargs)
 
 
     def _mock_call(self, /, *args, **kwargs):
+        return self._execute_mock_call(*args, **kwargs)
+
+    def _increment_mock_call(self, /, *args, **kwargs):
         self.called = True
         self.call_count += 1
 
         # handle call_args
+        # needs to be set here so assertions on call arguments pass before
+        # execution in the case of awaited calls
         _call = _Call((args, kwargs), two=True)
         self.call_args = _call
         self.call_args_list.append(_call)
@@ -1123,6 +1129,10 @@ class CallableMixin(Base):
             # follow the parental chain:
             _new_parent = _new_parent._mock_new_parent
 
+    def _execute_mock_call(self, /, *args, **kwargs):
+        # seperate from _increment_mock_call so that awaited functions are
+        # executed seperately from their call
+
         effect = self.side_effect
         if effect is not None:
             if _is_exception(effect):
index 0d68c9757980fcb63259faa0f1eec605941d5a82..aca1cd0bbecc2bdc5e39ec3ee988bf27a2378d3b 100644 (file)
@@ -3,7 +3,7 @@ import inspect
 import unittest
 
 from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock,
-                           create_autospec, _AwaitEvent)
+                           create_autospec, _AwaitEvent, sentinel, _CallList)
 
 
 def tearDownModule():
@@ -595,11 +595,173 @@ class AsyncMockAssert(unittest.TestCase):
     def setUp(self):
         self.mock = AsyncMock()
 
-    async def _runnable_test(self, *args):
-        if not args:
-            await self.mock()
-        else:
-            await self.mock(*args)
+    async def _runnable_test(self, *args, **kwargs):
+        await self.mock(*args, **kwargs)
+
+    async def _await_coroutine(self, coroutine):
+        return await coroutine
+
+    def test_assert_called_but_not_awaited(self):
+        mock = AsyncMock(AsyncClass)
+        with self.assertWarns(RuntimeWarning):
+            # Will raise a warning because never awaited
+            mock.async_method()
+        self.assertTrue(asyncio.iscoroutinefunction(mock.async_method))
+        mock.async_method.assert_called()
+        mock.async_method.assert_called_once()
+        mock.async_method.assert_called_once_with()
+        with self.assertRaises(AssertionError):
+            mock.assert_awaited()
+        with self.assertRaises(AssertionError):
+            mock.async_method.assert_awaited()
+
+    def test_assert_called_then_awaited(self):
+        mock = AsyncMock(AsyncClass)
+        mock_coroutine = mock.async_method()
+        mock.async_method.assert_called()
+        mock.async_method.assert_called_once()
+        mock.async_method.assert_called_once_with()
+        with self.assertRaises(AssertionError):
+            mock.async_method.assert_awaited()
+
+        asyncio.run(self._await_coroutine(mock_coroutine))
+        # Assert we haven't re-called the function
+        mock.async_method.assert_called_once()
+        mock.async_method.assert_awaited()
+        mock.async_method.assert_awaited_once()
+        mock.async_method.assert_awaited_once_with()
+
+    def test_assert_called_and_awaited_at_same_time(self):
+        with self.assertRaises(AssertionError):
+            self.mock.assert_awaited()
+
+        with self.assertRaises(AssertionError):
+            self.mock.assert_called()
+
+        asyncio.run(self._runnable_test())
+        self.mock.assert_called_once()
+        self.mock.assert_awaited_once()
+
+    def test_assert_called_twice_and_awaited_once(self):
+        mock = AsyncMock(AsyncClass)
+        coroutine = mock.async_method()
+        with self.assertWarns(RuntimeWarning):
+            # The first call will be awaited so no warning there
+            # But this call will never get awaited, so it will warn here
+            mock.async_method()
+        with self.assertRaises(AssertionError):
+            mock.async_method.assert_awaited()
+        mock.async_method.assert_called()
+        asyncio.run(self._await_coroutine(coroutine))
+        mock.async_method.assert_awaited()
+        mock.async_method.assert_awaited_once()
+
+    def test_assert_called_once_and_awaited_twice(self):
+        mock = AsyncMock(AsyncClass)
+        coroutine = mock.async_method()
+        mock.async_method.assert_called_once()
+        asyncio.run(self._await_coroutine(coroutine))
+        with self.assertRaises(RuntimeError):
+            # Cannot reuse already awaited coroutine
+            asyncio.run(self._await_coroutine(coroutine))
+        mock.async_method.assert_awaited()
+
+    def test_assert_awaited_but_not_called(self):
+        with self.assertRaises(AssertionError):
+            self.mock.assert_awaited()
+        with self.assertRaises(AssertionError):
+            self.mock.assert_called()
+        with self.assertRaises(TypeError):
+            # You cannot await an AsyncMock, it must be a coroutine
+            asyncio.run(self._await_coroutine(self.mock))
+
+        with self.assertRaises(AssertionError):
+            self.mock.assert_awaited()
+        with self.assertRaises(AssertionError):
+            self.mock.assert_called()
+
+    def test_assert_has_calls_not_awaits(self):
+        kalls = [call('foo')]
+        with self.assertWarns(RuntimeWarning):
+            # Will raise a warning because never awaited
+            self.mock('foo')
+        self.mock.assert_has_calls(kalls)
+        with self.assertRaises(AssertionError):
+            self.mock.assert_has_awaits(kalls)
+
+    def test_assert_has_mock_calls_on_async_mock_no_spec(self):
+        with self.assertWarns(RuntimeWarning):
+            # Will raise a warning because never awaited
+            self.mock()
+        kalls_empty = [('', (), {})]
+        self.assertEqual(self.mock.mock_calls, kalls_empty)
+
+        with self.assertWarns(RuntimeWarning):
+            # Will raise a warning because never awaited
+            self.mock('foo')
+            self.mock('baz')
+        mock_kalls = ([call(), call('foo'), call('baz')])
+        self.assertEqual(self.mock.mock_calls, mock_kalls)
+
+    def test_assert_has_mock_calls_on_async_mock_with_spec(self):
+        a_class_mock = AsyncMock(AsyncClass)
+        with self.assertWarns(RuntimeWarning):
+            # Will raise a warning because never awaited
+            a_class_mock.async_method()
+        kalls_empty = [('', (), {})]
+        self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty)
+        self.assertEqual(a_class_mock.mock_calls, [call.async_method()])
+
+        with self.assertWarns(RuntimeWarning):
+            # Will raise a warning because never awaited
+            a_class_mock.async_method(1, 2, 3, a=4, b=5)
+        method_kalls = [call(), call(1, 2, 3, a=4, b=5)]
+        mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)]
+        self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls)
+        self.assertEqual(a_class_mock.mock_calls, mock_kalls)
+
+    def test_async_method_calls_recorded(self):
+        with self.assertWarns(RuntimeWarning):
+            # Will raise warnings because never awaited
+            self.mock.something(3, fish=None)
+            self.mock.something_else.something(6, cake=sentinel.Cake)
+
+        self.assertEqual(self.mock.method_calls, [
+            ("something", (3,), {'fish': None}),
+            ("something_else.something", (6,), {'cake': sentinel.Cake})
+        ],
+            "method calls not recorded correctly")
+        self.assertEqual(self.mock.something_else.method_calls,
+                         [("something", (6,), {'cake': sentinel.Cake})],
+                         "method calls not recorded correctly")
+
+    def test_async_arg_lists(self):
+        def assert_attrs(mock):
+            names = ('call_args_list', 'method_calls', 'mock_calls')
+            for name in names:
+                attr = getattr(mock, name)
+                self.assertIsInstance(attr, _CallList)
+                self.assertIsInstance(attr, list)
+                self.assertEqual(attr, [])
+
+        assert_attrs(self.mock)
+        with self.assertWarns(RuntimeWarning):
+            # Will raise warnings because never awaited
+            self.mock()
+            self.mock(1, 2)
+            self.mock(a=3)
+
+        self.mock.reset_mock()
+        assert_attrs(self.mock)
+
+        a_mock = AsyncMock(AsyncClass)
+        with self.assertWarns(RuntimeWarning):
+            # Will raise warnings because never awaited
+            a_mock.async_method()
+            a_mock.async_method(1, a=3)
+
+        a_mock.reset_mock()
+        assert_attrs(a_mock)
 
     def test_assert_awaited(self):
         with self.assertRaises(AssertionError):
@@ -645,20 +807,20 @@ class AsyncMockAssert(unittest.TestCase):
 
     def test_assert_any_wait(self):
         with self.assertRaises(AssertionError):
-            self.mock.assert_any_await('NormalFoo')
+            self.mock.assert_any_await('foo')
 
-        asyncio.run(self._runnable_test('foo'))
+        asyncio.run(self._runnable_test('baz'))
         with self.assertRaises(AssertionError):
-            self.mock.assert_any_await('NormalFoo')
+            self.mock.assert_any_await('foo')
 
-        asyncio.run(self._runnable_test('NormalFoo'))
-        self.mock.assert_any_await('NormalFoo')
+        asyncio.run(self._runnable_test('foo'))
+        self.mock.assert_any_await('foo')
 
         asyncio.run(self._runnable_test('SomethingElse'))
-        self.mock.assert_any_await('NormalFoo')
+        self.mock.assert_any_await('foo')
 
     def test_assert_has_awaits_no_order(self):
-        calls = [call('NormalFoo'), call('baz')]
+        calls = [call('foo'), call('baz')]
 
         with self.assertRaises(AssertionError) as cm:
             self.mock.assert_has_awaits(calls)
@@ -668,7 +830,7 @@ class AsyncMockAssert(unittest.TestCase):
         with self.assertRaises(AssertionError):
             self.mock.assert_has_awaits(calls)
 
-        asyncio.run(self._runnable_test('NormalFoo'))
+        asyncio.run(self._runnable_test('foo'))
         with self.assertRaises(AssertionError):
             self.mock.assert_has_awaits(calls)
 
@@ -703,7 +865,7 @@ class AsyncMockAssert(unittest.TestCase):
         mock_with_spec.assert_any_await(ANY, 1)
 
     def test_assert_has_awaits_ordered(self):
-        calls = [call('NormalFoo'), call('baz')]
+        calls = [call('foo'), call('baz')]
         with self.assertRaises(AssertionError):
             self.mock.assert_has_awaits(calls, any_order=True)
 
@@ -711,11 +873,11 @@ class AsyncMockAssert(unittest.TestCase):
         with self.assertRaises(AssertionError):
             self.mock.assert_has_awaits(calls, any_order=True)
 
-        asyncio.run(self._runnable_test('foo'))
+        asyncio.run(self._runnable_test('bamf'))
         with self.assertRaises(AssertionError):
             self.mock.assert_has_awaits(calls, any_order=True)
 
-        asyncio.run(self._runnable_test('NormalFoo'))
+        asyncio.run(self._runnable_test('foo'))
         self.mock.assert_has_awaits(calls, any_order=True)
 
         asyncio.run(self._runnable_test('qux'))
index 2bafa8266b63ba59c07ef768d7a1ea86c6087f24..ad67f98f87ccba47132767cc5f487b1103a00a9e 100644 (file)
@@ -850,6 +850,7 @@ class MockTest(unittest.TestCase):
     def test_setting_call(self):
         mock = Mock()
         def __call__(self, a):
+            self._increment_mock_call(a)
             return self._mock_call(a)
 
         type(mock).__call__ = __call__
@@ -2025,7 +2026,7 @@ class MockTest(unittest.TestCase):
             )
 
             mocks = [
-                Mock, MagicMock, NonCallableMock, NonCallableMagicMock
+                Mock, MagicMock, NonCallableMock, NonCallableMagicMock, AsyncMock
             ]
 
             for mock in mocks:
diff --git a/Misc/NEWS.d/next/Library/2019-09-16-09-54-42.bpo-38136.MdI-Zb.rst b/Misc/NEWS.d/next/Library/2019-09-16-09-54-42.bpo-38136.MdI-Zb.rst
new file mode 100644 (file)
index 0000000..78cad24
--- /dev/null
@@ -0,0 +1,3 @@
+Changes AsyncMock call count and await count to be two different counters.
+Now await count only counts when a coroutine has been awaited, not when it
+has been called, and vice-versa. Update the documentation around this.