]> granicus.if.org Git - python/commitdiff
bpo-36996: Handle async functions when mock.patch is used as a decorator (GH-13562)
authorXtreak <tir.karthi@gmail.com>
Tue, 28 May 2019 07:07:39 +0000 (12:37 +0530)
committerMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>
Tue, 28 May 2019 07:07:38 +0000 (00:07 -0700)
Return a coroutine while patching async functions with a decorator.

Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
https://bugs.python.org/issue36996

Lib/unittest/mock.py
Lib/unittest/test/testmock/testasync.py
Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst [new file with mode: 0644]

index b91afd88dd132e16c444b92b1bed3375b311af52..fac4535747c4c81c6955cdbfce758be186690834 100644 (file)
@@ -26,6 +26,7 @@ __all__ = (
 __version__ = '1.0'
 
 import asyncio
+import contextlib
 import io
 import inspect
 import pprint
@@ -1220,6 +1221,8 @@ class _patch(object):
     def __call__(self, func):
         if isinstance(func, type):
             return self.decorate_class(func)
+        if inspect.iscoroutinefunction(func):
+            return self.decorate_async_callable(func)
         return self.decorate_callable(func)
 
 
@@ -1237,41 +1240,68 @@ class _patch(object):
         return klass
 
 
+    @contextlib.contextmanager
+    def decoration_helper(self, patched, args, keywargs):
+        extra_args = []
+        entered_patchers = []
+        patching = None
+
+        exc_info = tuple()
+        try:
+            for patching in patched.patchings:
+                arg = patching.__enter__()
+                entered_patchers.append(patching)
+                if patching.attribute_name is not None:
+                    keywargs.update(arg)
+                elif patching.new is DEFAULT:
+                    extra_args.append(arg)
+
+            args += tuple(extra_args)
+            yield (args, keywargs)
+        except:
+            if (patching not in entered_patchers and
+                _is_started(patching)):
+                # the patcher may have been started, but an exception
+                # raised whilst entering one of its additional_patchers
+                entered_patchers.append(patching)
+            # Pass the exception to __exit__
+            exc_info = sys.exc_info()
+            # re-raise the exception
+            raise
+        finally:
+            for patching in reversed(entered_patchers):
+                patching.__exit__(*exc_info)
+
+
     def decorate_callable(self, func):
+        # NB. Keep the method in sync with decorate_async_callable()
         if hasattr(func, 'patchings'):
             func.patchings.append(self)
             return func
 
         @wraps(func)
         def patched(*args, **keywargs):
-            extra_args = []
-            entered_patchers = []
+            with self.decoration_helper(patched,
+                                        args,
+                                        keywargs) as (newargs, newkeywargs):
+                return func(*newargs, **newkeywargs)
 
-            exc_info = tuple()
-            try:
-                for patching in patched.patchings:
-                    arg = patching.__enter__()
-                    entered_patchers.append(patching)
-                    if patching.attribute_name is not None:
-                        keywargs.update(arg)
-                    elif patching.new is DEFAULT:
-                        extra_args.append(arg)
-
-                args += tuple(extra_args)
-                return func(*args, **keywargs)
-            except:
-                if (patching not in entered_patchers and
-                    _is_started(patching)):
-                    # the patcher may have been started, but an exception
-                    # raised whilst entering one of its additional_patchers
-                    entered_patchers.append(patching)
-                # Pass the exception to __exit__
-                exc_info = sys.exc_info()
-                # re-raise the exception
-                raise
-            finally:
-                for patching in reversed(entered_patchers):
-                    patching.__exit__(*exc_info)
+        patched.patchings = [self]
+        return patched
+
+
+    def decorate_async_callable(self, func):
+        # NB. Keep the method in sync with decorate_callable()
+        if hasattr(func, 'patchings'):
+            func.patchings.append(self)
+            return func
+
+        @wraps(func)
+        async def patched(*args, **keywargs):
+            with self.decoration_helper(patched,
+                                        args,
+                                        keywargs) as (newargs, newkeywargs):
+                return await func(*newargs, **newkeywargs)
 
         patched.patchings = [self]
         return patched
index 0519d59696f6c6500854cfb5a890d7c3654a94b1..ccea4fe242dc074247742eecd0755bb518ca664e 100644 (file)
@@ -66,6 +66,14 @@ class AsyncPatchDecoratorTest(unittest.TestCase):
 
         test_async()
 
+    def test_async_def_patch(self):
+        @patch(f"{__name__}.async_func", AsyncMock())
+        async def test_async():
+            self.assertIsInstance(async_func, AsyncMock)
+
+        asyncio.run(test_async())
+        self.assertTrue(inspect.iscoroutinefunction(async_func))
+
 
 class AsyncPatchCMTest(unittest.TestCase):
     def test_is_async_function_cm(self):
@@ -91,6 +99,14 @@ class AsyncPatchCMTest(unittest.TestCase):
 
         test_async()
 
+    def test_async_def_cm(self):
+        async def test_async():
+            with patch(f"{__name__}.async_func", AsyncMock()):
+                self.assertIsInstance(async_func, AsyncMock)
+            self.assertTrue(inspect.iscoroutinefunction(async_func))
+
+        asyncio.run(test_async())
+
 
 class AsyncMockTest(unittest.TestCase):
     def test_iscoroutinefunction_default(self):
diff --git a/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst b/Misc/NEWS.d/next/Library/2019-05-22-22-55-18.bpo-36996.XQx08d.rst
new file mode 100644 (file)
index 0000000..69d18d9
--- /dev/null
@@ -0,0 +1 @@
+Handle :func:`unittest.mock.patch` used as a decorator on async functions.