From 4b03b6863575c7e85cc7ec3e8e25034b37043986 Mon Sep 17 00:00:00 2001 From: Brett Cannon Date: Thu, 23 Feb 2012 20:47:57 -0500 Subject: [PATCH] Turn _return_module() into _handle_fromlist(). --- Lib/importlib/_bootstrap.py | 50 ++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/Lib/importlib/_bootstrap.py b/Lib/importlib/_bootstrap.py index ccdea85df6..d880d31aac 100644 --- a/Lib/importlib/_bootstrap.py +++ b/Lib/importlib/_bootstrap.py @@ -1001,7 +1001,7 @@ def _gcd_import(name, package=None, level=0): return _find_and_load(name, _gcd_import) -def _return_module(module, name, fromlist, level, import_): +def _handle_fromlist(module, fromlist, import_): """Figure out what __import__ should return. The import_ parameter is a callable which takes the name of module to @@ -1010,29 +1010,18 @@ def _return_module(module, name, fromlist, level, import_): """ # The hell that is fromlist ... - if not fromlist: - # Return up to the first dot in 'name'. This is complicated by the fact - # that 'name' may be relative. - if level == 0: - return sys.modules[name.partition('.')[0]] - elif not name: - return module - else: - cut_off = len(name) - len(name.partition('.')[0]) - return sys.modules[module.__name__[:-cut_off]] - else: - # If a package was imported, try to import stuff from fromlist. - if hasattr(module, '__path__'): - if '*' in fromlist and hasattr(module, '__all__'): - fromlist = list(fromlist) - fromlist.remove('*') - fromlist.extend(module.__all__) - for x in (y for y in fromlist if not hasattr(module,y)): - try: - import_('{0}.{1}'.format(module.__name__, x)) - except ImportError: - pass - return module + # If a package was imported, try to import stuff from fromlist. + if hasattr(module, '__path__'): + if '*' in fromlist and hasattr(module, '__all__'): + fromlist = list(fromlist) + fromlist.remove('*') + fromlist.extend(module.__all__) + for x in (y for y in fromlist if not hasattr(module,y)): + try: + import_('{0}.{1}'.format(module.__name__, x)) + except ImportError: + pass + return module def _calc___package__(globals): @@ -1066,7 +1055,18 @@ def __import__(name, globals={}, locals={}, fromlist=[], level=0): else: package = _calc___package__(globals) module = _gcd_import(name, package, level) - return _return_module(module, name, fromlist, level, _gcd_import) + if not fromlist: + # Return up to the first dot in 'name'. This is complicated by the fact + # that 'name' may be relative. + if level == 0: + return sys.modules[name.partition('.')[0]] + elif not name: + return module + else: + cut_off = len(name) - len(name.partition('.')[0]) + return sys.modules[module.__name__[:-cut_off]] + else: + return _handle_fromlist(module, fromlist, _gcd_import) def _setup(sys_module, imp_module): -- 2.40.0