]> granicus.if.org Git - python/commitdiff
#22315: Use technique outlined in test_file_util
authorJason R. Coombs <jaraco@jaraco.com>
Sun, 31 Aug 2014 21:31:32 +0000 (17:31 -0400)
committerJason R. Coombs <jaraco@jaraco.com>
Sun, 31 Aug 2014 21:31:32 +0000 (17:31 -0400)
Lib/distutils/tests/test_dir_util.py

index eb83497fc6c143cc606756151568f4ca05746f8a..c9f789c8955385297f059ee152ae9351611e5c86 100644 (file)
@@ -3,7 +3,7 @@ import unittest
 import os
 import stat
 import sys
-import contextlib
+from unittest.mock import patch
 
 from distutils import dir_util, errors
 from distutils.dir_util import (mkpath, remove_tree, create_tree, copy_tree,
@@ -14,19 +14,6 @@ from distutils.tests import support
 from test.support import run_unittest
 
 
-@contextlib.context_manager
-def patch_obj(obj, attr, replacement):
-    """
-    A poor man's mock.patch.object
-    """
-    orig = getattr(obj, attr)
-    try:
-        setattr(obj, attr, replacement)
-        yield
-    finally:
-        setattr(obj, attr, orig)
-
-
 class DirUtilTestCase(support.TempdirManager, unittest.TestCase):
 
     def _log(self, msg, *args):
@@ -135,17 +122,13 @@ class DirUtilTestCase(support.TempdirManager, unittest.TestCase):
             self.assertEqual(ensure_relative('c:\\home\\foo'), 'c:home\\foo')
             self.assertEqual(ensure_relative('home\\foo'), 'home\\foo')
 
-    def test_copy_tree_exception_in_listdir(self):
+    @patch('os.listdir', side_effect=OSError())
+    def test_copy_tree_exception_in_listdir(self, listdir):
         """
         An exception in listdir should raise a DistutilsFileError
         """
-        def new_listdir(path):
-            raise OSError()
-        # simulate a transient network error or other failure invoking listdir
-        with patch_obj(os, 'listdir', new_listdir):
-            args = 'src', None
-            exc = errors.DistutilsFileError
-            self.assertRaises(exc, dir_util.copy_tree, *args)
+        with self.assertRaises(errors.DistutilsFileError):
+            dir_util.copy_tree('src', None)
 
 
 def test_suite():