]> granicus.if.org Git - python/commitdiff
Support dotted module names for test discovery paths in unittest. Issue 7780.
authorMichael Foord <fuzzyman@voidspace.org.uk>
Sat, 3 Apr 2010 01:15:21 +0000 (01:15 +0000)
committerMichael Foord <fuzzyman@voidspace.org.uk>
Sat, 3 Apr 2010 01:15:21 +0000 (01:15 +0000)
Lib/unittest/loader.py
Lib/unittest/test/test_discovery.py

index 022ed5781605a1f491a7f8a15df457f4be7468d6..e0b8585209b39adc18ee12201f7fa88341895fb9 100644 (file)
@@ -171,22 +171,41 @@ class TestLoader(object):
         packages can continue discovery themselves. top_level_dir is stored so
         load_tests does not need to pass this argument in to loader.discover().
         """
+        set_implicit_top = False
         if top_level_dir is None and self._top_level_dir is not None:
             # make top_level_dir optional if called from load_tests in a package
             top_level_dir = self._top_level_dir
         elif top_level_dir is None:
+            set_implicit_top = True
             top_level_dir = start_dir
 
-        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
-        start_dir = os.path.abspath(os.path.normpath(start_dir))
+        top_level_dir = os.path.abspath(top_level_dir)
 
         if not top_level_dir in sys.path:
             # all test modules must be importable from the top level directory
             sys.path.append(top_level_dir)
         self._top_level_dir = top_level_dir
 
-        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
-            # what about __init__.pyc or pyo (etc)
+        is_not_importable = False
+        if os.path.isdir(os.path.abspath(start_dir)):
+            start_dir = os.path.abspath(start_dir)
+            if start_dir != top_level_dir:
+                is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
+        else:
+            # support for discovery from dotted module names
+            try:
+                __import__(start_dir)
+            except ImportError:
+                is_not_importable = True
+            else:
+                the_module = sys.modules[start_dir]
+                top_part = start_dir.split('.')[0]
+                start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
+                if set_implicit_top:
+                    self._top_level_dir = os.path.abspath(os.path.dirname(os.path.dirname(sys.modules[top_part].__file__)))
+                    sys.path.remove(top_level_dir)
+
+        if is_not_importable:
             raise ImportError('Start directory is not importable: %r' % start_dir)
 
         tests = list(self._find_tests(start_dir, pattern))
index 0221fc2945bde51a68fdf4edbc6ed1948c221b3d..de6096826aac9779a8fab84560e2b9d8ade6a206 100644 (file)
@@ -132,6 +132,7 @@ class TestDiscovery(unittest.TestCase):
         loader = unittest.TestLoader()
 
         original_isfile = os.path.isfile
+        original_isdir = os.path.isdir
         def restore_isfile():
             os.path.isfile = original_isfile
 
@@ -151,6 +152,12 @@ class TestDiscovery(unittest.TestCase):
         self.assertIn(full_path, sys.path)
 
         os.path.isfile = lambda path: True
+        os.path.isdir = lambda path: True
+
+        def restore_isdir():
+            os.path.isdir = original_isdir
+        self.addCleanup(restore_isdir)
+
         _find_tests_args = []
         def _find_tests(start_dir, pattern):
             _find_tests_args.append((start_dir, pattern))
@@ -160,8 +167,8 @@ class TestDiscovery(unittest.TestCase):
 
         suite = loader.discover('/foo/bar/baz', 'pattern', '/foo/bar')
 
-        top_level_dir = os.path.abspath(os.path.normpath('/foo/bar'))
-        start_dir = os.path.abspath(os.path.normpath('/foo/bar/baz'))
+        top_level_dir = os.path.abspath('/foo/bar')
+        start_dir = os.path.abspath('/foo/bar/baz')
         self.assertEqual(suite, "['tests']")
         self.assertEqual(loader._top_level_dir, top_level_dir)
         self.assertEqual(_find_tests_args, [(start_dir, 'pattern')])