__version__ = '1.0'
import asyncio
+import contextlib
import io
import inspect
import pprint
def __call__(self, func):
if isinstance(func, type):
return self.decorate_class(func)
+ if inspect.iscoroutinefunction(func):
+ return self.decorate_async_callable(func)
return self.decorate_callable(func)
return klass
+ @contextlib.contextmanager
+ def decoration_helper(self, patched, args, keywargs):
+ extra_args = []
+ entered_patchers = []
+ patching = None
+
+ exc_info = tuple()
+ try:
+ for patching in patched.patchings:
+ arg = patching.__enter__()
+ entered_patchers.append(patching)
+ if patching.attribute_name is not None:
+ keywargs.update(arg)
+ elif patching.new is DEFAULT:
+ extra_args.append(arg)
+
+ args += tuple(extra_args)
+ yield (args, keywargs)
+ except:
+ if (patching not in entered_patchers and
+ _is_started(patching)):
+ # the patcher may have been started, but an exception
+ # raised whilst entering one of its additional_patchers
+ entered_patchers.append(patching)
+ # Pass the exception to __exit__
+ exc_info = sys.exc_info()
+ # re-raise the exception
+ raise
+ finally:
+ for patching in reversed(entered_patchers):
+ patching.__exit__(*exc_info)
+
+
def decorate_callable(self, func):
+ # NB. Keep the method in sync with decorate_async_callable()
if hasattr(func, 'patchings'):
func.patchings.append(self)
return func
@wraps(func)
def patched(*args, **keywargs):
- extra_args = []
- entered_patchers = []
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return func(*newargs, **newkeywargs)
- exc_info = tuple()
- try:
- for patching in patched.patchings:
- arg = patching.__enter__()
- entered_patchers.append(patching)
- if patching.attribute_name is not None:
- keywargs.update(arg)
- elif patching.new is DEFAULT:
- extra_args.append(arg)
-
- args += tuple(extra_args)
- return func(*args, **keywargs)
- except:
- if (patching not in entered_patchers and
- _is_started(patching)):
- # the patcher may have been started, but an exception
- # raised whilst entering one of its additional_patchers
- entered_patchers.append(patching)
- # Pass the exception to __exit__
- exc_info = sys.exc_info()
- # re-raise the exception
- raise
- finally:
- for patching in reversed(entered_patchers):
- patching.__exit__(*exc_info)
+ patched.patchings = [self]
+ return patched
+
+
+ def decorate_async_callable(self, func):
+ # NB. Keep the method in sync with decorate_callable()
+ if hasattr(func, 'patchings'):
+ func.patchings.append(self)
+ return func
+
+ @wraps(func)
+ async def patched(*args, **keywargs):
+ with self.decoration_helper(patched,
+ args,
+ keywargs) as (newargs, newkeywargs):
+ return await func(*newargs, **newkeywargs)
patched.patchings = [self]
return patched
test_async()
+ def test_async_def_patch(self):
+ @patch(f"{__name__}.async_func", AsyncMock())
+ async def test_async():
+ self.assertIsInstance(async_func, AsyncMock)
+
+ asyncio.run(test_async())
+ self.assertTrue(inspect.iscoroutinefunction(async_func))
+
class AsyncPatchCMTest(unittest.TestCase):
def test_is_async_function_cm(self):
test_async()
+ def test_async_def_cm(self):
+ async def test_async():
+ with patch(f"{__name__}.async_func", AsyncMock()):
+ self.assertIsInstance(async_func, AsyncMock)
+ self.assertTrue(inspect.iscoroutinefunction(async_func))
+
+ asyncio.run(test_async())
+
class AsyncMockTest(unittest.TestCase):
def test_iscoroutinefunction_default(self):