]> granicus.if.org Git - python/commitdiff
Refactor importlib.__import__() and _gcd_import() to facilitate using
authorBrett Cannon <brett@python.org>
Thu, 16 Feb 2012 18:43:41 +0000 (13:43 -0500)
committerBrett Cannon <brett@python.org>
Thu, 16 Feb 2012 18:43:41 +0000 (13:43 -0500)
an __import__ implementation that takes care of basics in C and punts
to importlib for more complicated code.

Lib/importlib/_bootstrap.py

index f0650dd569ebf9adbf1a502d32ff258536215b96..ad1561efb66900f2f99371a1b2176fc22e394cf3 100644 (file)
@@ -861,19 +861,50 @@ class _ImportLockContext:
         imp.release_lock()
 
 
-_IMPLICIT_META_PATH = [BuiltinImporter, FrozenImporter, _DefaultPathFinder]
+def _resolve_name(name, package, level):
+    """Resolve a relative module name to an absolute one."""
+    dot = len(package)
+    for x in range(level, 1, -1):
+        try:
+            dot = package.rindex('.', 0, dot)
+        except ValueError:
+            raise ValueError("attempted relative import beyond "
+                             "top-level package")
+    if name:
+        return "{0}.{1}".format(package[:dot], name)
+    else:
+        return package[:dot]
+
+
+def _find_module(name, path):
+    """Find a module's loader."""
+    meta_path = sys.meta_path + _IMPLICIT_META_PATH
+    for finder in meta_path:
+        loader = finder.find_module(name, path)
+        if loader is not None:
+            # The parent import may have already imported this module.
+            if name not in sys.modules:
+                return loader
+            else:
+                return sys.modules[name].__loader__
+    else:
+        return None
 
-_ERR_MSG = 'No module named {!r}'
 
-def _gcd_import(name, package=None, level=0):
-    """Import and return the module based on its name, the package the call is
-    being made from, and the level adjustment.
+def _set___package__(module):
+    """Set __package__ on a module."""
+    # Watch out for what comes out of sys.modules to not be a module,
+    # e.g. an int.
+    try:
+        module.__package__ = module.__name__
+        if not hasattr(module, '__path__'):
+            module.__package__ = module.__package__.rpartition('.')[0]
+    except AttributeError:
+        pass
 
-    This function represents the greatest common denominator of functionality
-    between import_module and __import__. This includes setting __package__ if
-    the loader did not.
 
-    """
+def _sanity_check(name, package, level):
+    """Verify arguments are "sane"."""
     if package:
         if not hasattr(package, 'rindex'):
             raise ValueError("__package__ not set to a string")
@@ -883,18 +914,47 @@ def _gcd_import(name, package=None, level=0):
             raise SystemError(msg.format(package))
     if not name and level == 0:
         raise ValueError("Empty module name")
+
+
+def _find_search_path(name, import_):
+    """Find the search path for a module.
+
+    import_ is expected to be a callable which takes the name of a module to
+    import. It is required to decouple the function from importlib.
+
+    """
+    path = None
+    parent = name.rpartition('.')[0]
+    if parent:
+        if parent not in sys.modules:
+            import_(parent)
+        # Backwards-compatibility; be nicer to skip the dict lookup.
+        parent_module = sys.modules[parent]
+        try:
+            path = parent_module.__path__
+        except AttributeError:
+            msg = (_ERR_MSG + '; {} is not a package').format(name, parent)
+            raise ImportError(msg)
+    return parent, path
+
+
+
+_IMPLICIT_META_PATH = [BuiltinImporter, FrozenImporter, _DefaultPathFinder]
+
+_ERR_MSG = 'No module named {!r}'
+
+def _gcd_import(name, package=None, level=0):
+    """Import and return the module based on its name, the package the call is
+    being made from, and the level adjustment.
+
+    This function represents the greatest common denominator of functionality
+    between import_module and __import__. This includes setting __package__ if
+    the loader did not.
+
+    """
+    _sanity_check(name, package, level)
     if level > 0:
-        dot = len(package)
-        for x in range(level, 1, -1):
-            try:
-                dot = package.rindex('.', 0, dot)
-            except ValueError:
-                raise ValueError("attempted relative import beyond "
-                                 "top-level package")
-        if name:
-            name = "{0}.{1}".format(package[:dot], name)
-        else:
-            name = package[:dot]
+        name = _resolve_name(name, package, level)
     with _ImportLockContext():
         try:
             module = sys.modules[name]
@@ -905,70 +965,33 @@ def _gcd_import(name, package=None, level=0):
             return module
         except KeyError:
             pass
-        parent = name.rpartition('.')[0]
-        path = None
-        if parent:
-            if parent not in sys.modules:
-                _gcd_import(parent)
-            # Backwards-compatibility; be nicer to skip the dict lookup.
-            parent_module = sys.modules[parent]
-            try:
-                path = parent_module.__path__
-            except AttributeError:
-                msg = (_ERR_MSG + '; {} is not a package').format(name, parent)
-                raise ImportError(msg)
-        meta_path = sys.meta_path + _IMPLICIT_META_PATH
-        for finder in meta_path:
-            loader = finder.find_module(name, path)
-            if loader is not None:
-                # The parent import may have already imported this module.
-                if name not in sys.modules:
-                    loader.load_module(name)
-                break
-        else:
+        parent, path = _find_search_path(name, _gcd_import)
+        loader = _find_module(name, path)
+        if loader is None:
             raise ImportError(_ERR_MSG.format(name))
+        elif name not in sys.modules:
+            # The parent import may have already imported this module.
+            loader.load_module(name)
         # Backwards-compatibility; be nicer to skip the dict lookup.
         module = sys.modules[name]
         if parent:
             # Set the module as an attribute on its parent.
+            parent_module = sys.modules[parent]
             setattr(parent_module, name.rpartition('.')[2], module)
         # Set __package__ if the loader did not.
         if not hasattr(module, '__package__') or module.__package__ is None:
-            # Watch out for what comes out of sys.modules to not be a module,
-            # e.g. an int.
-            try:
-                module.__package__ = module.__name__
-                if not hasattr(module, '__path__'):
-                    module.__package__ = module.__package__.rpartition('.')[0]
-            except AttributeError:
-                pass
+            _set___package__(module)
         return module
 
 
-def __import__(name, globals={}, locals={}, fromlist=[], level=0):
-    """Import a module.
+def _return_module(module, name, fromlist, level, import_):
+    """Figure out what __import__ should return.
 
-    The 'globals' argument is used to infer where the import is occuring from
-    to handle relative imports. The 'locals' argument is ignored. The
-    'fromlist' argument specifies what should exist as attributes on the module
-    being imported (e.g. ``from module import <fromlist>``).  The 'level'
-    argument represents the package location to import from in a relative
-    import (e.g. ``from ..pkg import mod`` would have a 'level' of 2).
+    The import_ parameter is a callable which takes the name of module to
+    import. It is required to decouple the function from assuming importlib's
+    import implementation is desired.
 
     """
-    if not hasattr(name, 'rpartition'):
-        raise TypeError("module name must be str, not {}".format(type(name)))
-    if level == 0:
-        module = _gcd_import(name)
-    else:
-        # __package__ is not guaranteed to be defined or could be set to None
-        # to represent that its proper value is unknown
-        package = globals.get('__package__')
-        if package is None:
-            package = globals['__name__']
-            if '__path__' not in globals:
-                package = package.rpartition('.')[0]
-        module = _gcd_import(name, package, level)
     # The hell that is fromlist ...
     if not fromlist:
         # Return up to the first dot in 'name'. This is complicated by the fact
@@ -989,12 +1012,48 @@ def __import__(name, globals={}, locals={}, fromlist=[], level=0):
                 fromlist.extend(module.__all__)
             for x in (y for y in fromlist if not hasattr(module,y)):
                 try:
-                    _gcd_import('{0}.{1}'.format(module.__name__, x))
+                    import_('{0}.{1}'.format(module.__name__, x))
                 except ImportError:
                     pass
         return module
 
 
+def _calc___package__(globals):
+    """Calculate what __package__ should be.
+
+    __package__ is not guaranteed to be defined or could be set to None
+    to represent that its proper value is unknown.
+
+    """
+    package = globals.get('__package__')
+    if package is None:
+        package = globals['__name__']
+        if '__path__' not in globals:
+            package = package.rpartition('.')[0]
+    return package
+
+
+def __import__(name, globals={}, locals={}, fromlist=[], level=0):
+    """Import a module.
+
+    The 'globals' argument is used to infer where the import is occuring from
+    to handle relative imports. The 'locals' argument is ignored. The
+    'fromlist' argument specifies what should exist as attributes on the module
+    being imported (e.g. ``from module import <fromlist>``).  The 'level'
+    argument represents the package location to import from in a relative
+    import (e.g. ``from ..pkg import mod`` would have a 'level' of 2).
+
+    """
+    if not hasattr(name, 'rpartition'):
+        raise TypeError("module name must be str, not {}".format(type(name)))
+    if level == 0:
+        module = _gcd_import(name)
+    else:
+        package = _calc___package__(globals)
+        module = _gcd_import(name, package, level)
+    return _return_module(module, name, fromlist, level, _gcd_import)
+
+
 def _setup(sys_module, imp_module):
     """Setup importlib by importing needed built-in modules and injecting them
     into the global namespace.