]> granicus.if.org Git - python/commitdiff
Issue #1180193: When importing a module from a .pyc (or .pyo) file with
authorAntoine Pitrou <solipsis@pitrou.net>
Tue, 6 Jan 2009 18:10:47 +0000 (18:10 +0000)
committerAntoine Pitrou <solipsis@pitrou.net>
Tue, 6 Jan 2009 18:10:47 +0000 (18:10 +0000)
an existing .py counterpart, override the co_filename attributes of all
code objects if the original filename is obsolete (which can happen if the
file has been renamed, moved, or if it is accessed through different paths).
Patch by Ziga Seilnacht and Jean-Paul Calderone.

Lib/test/test_import.py
Misc/NEWS
Python/import.c

index 13e8cc3bd49e8a096534772d562cdbb1c4df2c7f..7318a9e7e049fd461cdaf82c75c35b3a68ca6e18 100644 (file)
@@ -5,6 +5,7 @@ import shutil
 import sys
 import py_compile
 import warnings
+import marshal
 from test.test_support import unlink, TESTFN, unload, run_unittest, check_warnings
 
 
@@ -231,6 +232,97 @@ class ImportTest(unittest.TestCase):
         else:
             self.fail("import by path didn't raise an exception")
 
+class TestPycRewriting(unittest.TestCase):
+    # Test that the `co_filename` attribute on code objects always points
+    # to the right file, even when various things happen (e.g. both the .py
+    # and the .pyc file are renamed).
+
+    module_name = "unlikely_module_name"
+    module_source = """
+import sys
+code_filename = sys._getframe().f_code.co_filename
+module_filename = __file__
+constant = 1
+def func():
+    pass
+func_filename = func.func_code.co_filename
+"""
+    dir_name = os.path.abspath(TESTFN)
+    file_name = os.path.join(dir_name, module_name) + os.extsep + "py"
+    compiled_name = file_name + ("c" if __debug__ else "o")
+
+    def setUp(self):
+        self.sys_path = sys.path[:]
+        self.orig_module = sys.modules.pop(self.module_name, None)
+        os.mkdir(self.dir_name)
+        with open(self.file_name, "w") as f:
+            f.write(self.module_source)
+        sys.path.insert(0, self.dir_name)
+
+    def tearDown(self):
+        sys.path[:] = self.sys_path
+        if self.orig_module is not None:
+            sys.modules[self.module_name] = self.orig_module
+        else:
+            del sys.modules[self.module_name]
+        for file_name in self.file_name, self.compiled_name:
+            if os.path.exists(file_name):
+                os.remove(file_name)
+        if os.path.exists(self.dir_name):
+            os.rmdir(self.dir_name)
+
+    def import_module(self):
+        ns = globals()
+        __import__(self.module_name, ns, ns)
+        return sys.modules[self.module_name]
+
+    def test_basics(self):
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.file_name)
+        self.assertEqual(mod.code_filename, self.file_name)
+        self.assertEqual(mod.func_filename, self.file_name)
+        del sys.modules[self.module_name]
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.compiled_name)
+        self.assertEqual(mod.code_filename, self.file_name)
+        self.assertEqual(mod.func_filename, self.file_name)
+
+    def test_incorrect_code_name(self):
+        py_compile.compile(self.file_name, dfile="another_module.py")
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.compiled_name)
+        self.assertEqual(mod.code_filename, self.file_name)
+        self.assertEqual(mod.func_filename, self.file_name)
+
+    def test_module_without_source(self):
+        target = "another_module.py"
+        py_compile.compile(self.file_name, dfile=target)
+        os.remove(self.file_name)
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.compiled_name)
+        self.assertEqual(mod.code_filename, target)
+        self.assertEqual(mod.func_filename, target)
+
+    def test_foreign_code(self):
+        py_compile.compile(self.file_name)
+        with open(self.compiled_name, "rb") as f:
+            header = f.read(8)
+            code = marshal.load(f)
+        constants = list(code.co_consts)
+        foreign_code = test_main.func_code
+        pos = constants.index(1)
+        constants[pos] = foreign_code
+        code = type(code)(code.co_argcount, code.co_nlocals, code.co_stacksize,
+                          code.co_flags, code.co_code, tuple(constants),
+                          code.co_names, code.co_varnames, code.co_filename,
+                          code.co_name, code.co_firstlineno, code.co_lnotab,
+                          code.co_freevars, code.co_cellvars)
+        with open(self.compiled_name, "wb") as f:
+            f.write(header)
+            marshal.dump(code, f)
+        mod = self.import_module()
+        self.assertEqual(mod.constant.co_filename, foreign_code.co_filename)
+
 class PathsTests(unittest.TestCase):
     path = TESTFN
 
@@ -297,7 +389,7 @@ class RelativeImport(unittest.TestCase):
         self.assertRaises(ValueError, check_relative)
 
 def test_main(verbose=None):
-    run_unittest(ImportTest, PathsTests, RelativeImport)
+    run_unittest(ImportTest, TestPycRewriting, PathsTests, RelativeImport)
 
 if __name__ == '__main__':
     # test needs to be a package, so we can do relative import
index 090eebd451613cda274c04b973b1892dd7da4951..cd6bec67c594e9c08d3c6d9ee0af4f4d5e47e4d7 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,12 @@ What's New in Python 2.7 alpha 1
 Core and Builtins
 -----------------
 
+- Issue #1180193: When importing a module from a .pyc (or .pyo) file with
+  an existing .py counterpart, override the co_filename attributes of all
+  code objects if the original filename is obsolete (which can happen if the
+  file has been renamed, moved, or if it is accessed through different paths).
+  Patch by Ziga Seilnacht and Jean-Paul Calderone.
+
 - Issue #4075: Use OutputDebugStringW in Py_FatalError.
 
 - Issue #4797: IOError.filename was not set when _fileio.FileIO failed to open
index 08024b21dbe42acfedaeeac8c428d4cc06dbaaba..e9ff922fd03c09172c33f5de671a77c6655d5af3 100644 (file)
@@ -909,6 +909,49 @@ write_compiled_module(PyCodeObject *co, char *cpathname, struct stat *srcstat)
                PySys_WriteStderr("# wrote %s\n", cpathname);
 }
 
+static void
+update_code_filenames(PyCodeObject *co, PyObject *oldname, PyObject *newname)
+{
+       PyObject *constants, *tmp;
+       Py_ssize_t i, n;
+
+       if (!_PyString_Eq(co->co_filename, oldname))
+               return;
+
+       tmp = co->co_filename;
+       co->co_filename = newname;
+       Py_INCREF(co->co_filename);
+       Py_DECREF(tmp);
+
+       constants = co->co_consts;
+       n = PyTuple_GET_SIZE(constants);
+       for (i = 0; i < n; i++) {
+               tmp = PyTuple_GET_ITEM(constants, i);
+               if (PyCode_Check(tmp))
+                       update_code_filenames((PyCodeObject *)tmp,
+                                             oldname, newname);
+       }
+}
+
+static int
+update_compiled_module(PyCodeObject *co, char *pathname)
+{
+       PyObject *oldname, *newname;
+
+       if (strcmp(PyString_AsString(co->co_filename), pathname) == 0)
+               return 0;
+
+       newname = PyString_FromString(pathname);
+       if (newname == NULL)
+               return -1;
+
+       oldname = co->co_filename;
+       Py_INCREF(oldname);
+       update_code_filenames(co, oldname, newname);
+       Py_DECREF(oldname);
+       Py_DECREF(newname);
+       return 1;
+}
 
 /* Load a source module from a given file and return its module
    object WITH INCREMENTED REFERENCE COUNT.  If there's a matching
@@ -949,6 +992,8 @@ load_source_module(char *name, char *pathname, FILE *fp)
                fclose(fpc);
                if (co == NULL)
                        return NULL;
+               if (update_compiled_module(co, pathname) < 0)
+                       return NULL;
                if (Py_VerboseFlag)
                        PySys_WriteStderr("import %s # precompiled from %s\n",
                                name, cpathname);