]> granicus.if.org Git - python/commitdiff
Fix SF#1516184 and add a test to prevent regression.
authorPhillip J. Eby <pje@telecommunity.com>
Mon, 10 Jul 2006 19:03:29 +0000 (19:03 +0000)
committerPhillip J. Eby <pje@telecommunity.com>
Mon, 10 Jul 2006 19:03:29 +0000 (19:03 +0000)
Lib/inspect.py
Lib/test/test_inspect.py

index bf7f006bad2f6624144dad639ff0bb1f057f2747..311fe7ec697ff242be49d2f3f646f71b4d69b673 100644 (file)
@@ -355,40 +355,37 @@ def getsourcefile(object):
             return None
     if os.path.exists(filename):
         return filename
-    # Ugly but necessary - '<stdin>' and '<string>' mean that getmodule()
-    # would infinitely recurse, because they're not real files nor loadable
-    # Note that this means that writing a PEP 302 loader that uses '<'
-    # at the start of a filename is now not a good idea.  :(
-    if filename[:1]!='<' and hasattr(getmodule(object), '__loader__'):
+    # only return a non-existent filename if the module has a PEP 302 loader
+    if hasattr(getmodule(object, filename), '__loader__'):
         return filename
 
-def getabsfile(object):
+def getabsfile(object, _filename=None):
     """Return an absolute path to the source or compiled file for an object.
 
     The idea is for each object to have a unique origin, so this routine
     normalizes the result as much as possible."""
     return os.path.normcase(
-        os.path.abspath(getsourcefile(object) or getfile(object)))
+        os.path.abspath(_filename or getsourcefile(object) or getfile(object)))
 
 modulesbyfile = {}
 
-def getmodule(object):
+def getmodule(object, _filename=None):
     """Return the module an object was defined in, or None if not found."""
     if ismodule(object):
         return object
     if hasattr(object, '__module__'):
         return sys.modules.get(object.__module__)
     try:
-        file = getabsfile(object)
+        file = getabsfile(object, _filename)
     except TypeError:
         return None
     if file in modulesbyfile:
         return sys.modules.get(modulesbyfile[file])
     for module in sys.modules.values():
         if ismodule(module) and hasattr(module, '__file__'):
-            modulesbyfile[
-                os.path.realpath(
-                        getabsfile(module))] = module.__name__
+            f = getabsfile(module)
+            modulesbyfile[f] = modulesbyfile[
+                os.path.realpath(f)] = module.__name__
     if file in modulesbyfile:
         return sys.modules.get(modulesbyfile[file])
     main = sys.modules['__main__']
index 62c40eba4d78b4056d79b017778119e8cbbd2a82..d100d22ab560b3368abf61f2491fd4069bbd13b7 100644 (file)
@@ -178,6 +178,16 @@ class TestRetrievingSourceCode(GetSourceBase):
     def test_getfile(self):
         self.assertEqual(inspect.getfile(mod.StupidGit), mod.__file__)
 
+    def test_getmodule_recursion(self):
+        from new import module
+        name = '__inspect_dummy'
+        m = sys.modules[name] = module(name)
+        m.__file__ = "<string>" # hopefully not a real filename... 
+        m.__loader__ = "dummy"  # pretend the filename is understood by a loader
+        exec "def x(): pass" in m.__dict__
+        self.assertEqual(inspect.getsourcefile(m.x.func_code), '<string>')
+        del sys.modules[name]
+
 class TestDecorators(GetSourceBase):
     fodderFile = mod2