]> granicus.if.org Git - python/commitdiff
Adding patch.stopall method to unittest.mock
authorMichael Foord <michael@voidspace.org.uk>
Sun, 10 Jun 2012 19:36:32 +0000 (20:36 +0100)
committerMichael Foord <michael@voidspace.org.uk>
Sun, 10 Jun 2012 19:36:32 +0000 (20:36 +0100)
Doc/library/unittest.mock.rst
Lib/unittest/mock.py
Lib/unittest/test/testmock/testpatch.py

index 12b0275aab6fd713b0bbb2833fbbeaa550d10113..1b3969726130c1b6f378a7e2286dac5e3f37f73b 100644 (file)
@@ -1354,8 +1354,12 @@ method of a `TestCase`:
     As an added bonus you no longer need to keep a reference to the `patcher`
     object.
 
-In fact `start` and `stop` are just aliases for the context manager
-`__enter__` and `__exit__` methods.
+It is also possible to stop all patches which have been started by using
+`patch.stopall`.
+
+.. function:: patch.stopall
+
+    Stop all active patches.
 
 
 TEST_PREFIX
index 4ae3d16139e27d4fc99765a632404ca4d289e819..95570aa3a993f1f0b2fb2707cf7e8560a3c358c9 100644 (file)
@@ -1002,6 +1002,7 @@ def _is_started(patcher):
 class _patch(object):
 
     attribute_name = None
+    _active_patches = set()
 
     def __init__(
             self, getter, attribute, new, spec, create,
@@ -1270,8 +1271,18 @@ class _patch(object):
             if _is_started(patcher):
                 patcher.__exit__(*exc_info)
 
-    start = __enter__
-    stop = __exit__
+
+    def start(self):
+        """Activate a patch, returning any created mock."""
+        result = self.__enter__()
+        self._active_patches.add(self)
+        return result
+
+
+    def stop(self):
+        """Stop an active patch."""
+        self._active_patches.discard(self)
+        return self.__exit__()
 
 
 
@@ -1562,9 +1573,16 @@ def _clear_dict(in_dict):
             del in_dict[key]
 
 
+def _patch_stopall():
+    """Stop all active patches."""
+    for patch in list(_patch._active_patches):
+        patch.stop()
+
+
 patch.object = _patch_object
 patch.dict = _patch_dict
 patch.multiple = _patch_multiple
+patch.stopall = _patch_stopall
 patch.TEST_PREFIX = 'test'
 
 magic_methods = (
index 62568554da24b207f634bcf877a6d9a25d51ac61..c1091b4e9b37c9472c2c559070f670986e8293e7 100644 (file)
@@ -1762,6 +1762,24 @@ class PatchTest(unittest.TestCase):
             p.stop()
 
 
+    def test_patch_stopall(self):
+        unlink = os.unlink
+        chdir = os.chdir
+        path = os.path
+        patch('os.unlink', something).start()
+        patch('os.chdir', something_else).start()
+
+        @patch('os.path')
+        def patched(mock_path):
+            patch.stopall()
+            self.assertIs(os.path, mock_path)
+            self.assertIs(os.unlink, unlink)
+            self.assertIs(os.chdir, chdir)
+
+        patched()
+        self.assertIs(os.path, path)
+
+
 
 if __name__ == '__main__':
     unittest.main()