]> granicus.if.org Git - python/commitdiff
Patch #1739468: Directories and zipfiles containing __main__.py are now executable
authorNick Coghlan <ncoghlan@gmail.com>
Sun, 18 Nov 2007 11:56:28 +0000 (11:56 +0000)
committerNick Coghlan <ncoghlan@gmail.com>
Sun, 18 Nov 2007 11:56:28 +0000 (11:56 +0000)
Include/import.h
Lib/test/test_cmd_line.py
Lib/test/test_cmd_line_script.py [new file with mode: 0644]
Misc/NEWS
Modules/main.c
Python/import.c

index 414e059a3ecfe8a9146eef6ce1c49457b19dc965..af9f3394b13d51a73092acd5d53a258cdc280566 100644 (file)
@@ -24,6 +24,7 @@ PyAPI_FUNC(PyObject *) PyImport_ImportModuleEx(
 #define PyImport_ImportModuleEx(n, g, l, f) \
        PyImport_ImportModuleLevel(n, g, l, f, -1)
 
+PyAPI_FUNC(PyObject *) PyImport_GetImporter(PyObject *path);
 PyAPI_FUNC(PyObject *) PyImport_Import(PyObject *name);
 PyAPI_FUNC(PyObject *) PyImport_ReloadModule(PyObject *m);
 PyAPI_FUNC(void) PyImport_Cleanup(void);
@@ -42,6 +43,7 @@ struct _inittab {
     void (*initfunc)(void);
 };
 
+PyAPI_DATA(PyTypeObject) PyNullImporter_Type;
 PyAPI_DATA(struct _inittab *) PyImport_Inittab;
 
 PyAPI_FUNC(int) PyImport_AppendInittab(char *name, void (*initfunc)(void));
index 2232f4d9c8095536ad186cda8187eec315c4e806..efef74f09c9caa9de16a9b982249227f63487d9d 100644 (file)
@@ -1,3 +1,6 @@
+# Tests invocation of the interpreter with various command line arguments
+# All tests are executed with environment variables ignored
+# See test_cmd_line_script.py for testing of script execution
 
 import test.test_support, unittest
 import sys
diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py
new file mode 100644 (file)
index 0000000..eac6385
--- /dev/null
@@ -0,0 +1,145 @@
+# Tests command line execution of scripts
+from __future__ import with_statement
+
+import unittest
+import os
+import os.path
+import sys
+import test
+import tempfile
+import subprocess
+import py_compile
+import contextlib
+import shutil
+import zipfile
+
+verbose = test.test_support.verbose
+
+# XXX ncoghlan: Should we consider moving these to test_support?
+from test_cmd_line import _spawn_python, _kill_python
+
+def _run_python(*args):
+    if __debug__:
+        p = _spawn_python(*args)
+    else:
+        p = _spawn_python('-O', *args)
+    stdout_data = _kill_python(p)
+    return p.wait(), stdout_data
+
+@contextlib.contextmanager
+def temp_dir():
+    dirname = tempfile.mkdtemp()
+    try:
+        yield dirname
+    finally:
+        shutil.rmtree(dirname)
+
+test_source = ("""\
+# Script may be run with optimisation enabled, so don't rely on assert
+# statements being executed
+def assertEqual(lhs, rhs):
+    if lhs != rhs:
+        raise AssertionError("%r != %r" % (lhs, rhs))
+def assertIdentical(lhs, rhs):
+    if lhs is not rhs:
+        raise AssertionError("%r is not %r" % (lhs, rhs))
+# Check basic code execution
+result = ['Top level assignment']
+def f():
+    result.append('Lower level reference')
+f()
+assertEqual(result, ['Top level assignment', 'Lower level reference'])
+# Check population of magic variables
+assertEqual(__name__, '__main__')
+print '__file__==%r' % __file__
+# Check the sys module
+import sys
+assertIdentical(globals(), sys.modules[__name__].__dict__)
+print 'sys.argv[0]==%r' % sys.argv[0]
+""")
+
+def _make_test_script(script_dir, script_basename):
+    script_filename = script_basename+os.extsep+"py"
+    script_name = os.path.join(script_dir, script_filename)
+    script_file = open(script_name, "w")
+    script_file.write(test_source)
+    script_file.close()
+    return script_name
+
+def _compile_test_script(script_name):
+    py_compile.compile(script_name, doraise=True)
+    if __debug__:
+        compiled_name = script_name + 'c'
+    else:
+        compiled_name = script_name + 'o'
+    return compiled_name
+
+def _make_test_zip(zip_dir, zip_basename, script_name):
+    zip_filename = zip_basename+os.extsep+"zip"
+    zip_name = os.path.join(zip_dir, zip_filename)
+    zip_file = zipfile.ZipFile(zip_name, 'w')
+    zip_file.write(script_name, os.path.basename(script_name))
+    zip_file.close()
+    # if verbose:
+    #    zip_file = zipfile.ZipFile(zip_name, 'r')
+    #    print "Contents of %r:" % zip_name
+    #    zip_file.printdir()
+    #    zip_file.close()
+    return zip_name
+
+class CmdLineTest(unittest.TestCase):
+    def _check_script(self, script_name, expected_file, expected_argv0):
+        exit_code, data = _run_python(script_name)
+        # if verbose:
+        #    print "Output from test script %r:" % script_name
+        #    print data
+        self.assertEqual(exit_code, 0)
+        printed_file = '__file__==%r' % expected_file
+        printed_argv0 = 'sys.argv[0]==%r' % expected_argv0
+        self.assert_(printed_file in data)
+        self.assert_(printed_argv0 in data)
+
+    def test_basic_script(self):
+        with temp_dir() as script_dir:
+            script_name = _make_test_script(script_dir, "script")
+            self._check_script(script_name, script_name, script_name)
+
+    def test_script_compiled(self):
+        with temp_dir() as script_dir:
+            script_name = _make_test_script(script_dir, "script")
+            compiled_name = _compile_test_script(script_name)
+            os.remove(script_name)
+            self._check_script(compiled_name, compiled_name, compiled_name)
+
+    def test_directory(self):
+        with temp_dir() as script_dir:
+            script_name = _make_test_script(script_dir, "__main__")
+            self._check_script(script_dir, script_name, script_dir)
+
+    def test_directory_compiled(self):
+        with temp_dir() as script_dir:
+            script_name = _make_test_script(script_dir, "__main__")
+            compiled_name = _compile_test_script(script_name)
+            os.remove(script_name)
+            self._check_script(script_dir, compiled_name, script_dir)
+
+    def test_zipfile(self):
+        with temp_dir() as script_dir:
+            script_name = _make_test_script(script_dir, "__main__")
+            zip_name = _make_test_zip(script_dir, "test_zip", script_name)
+            self._check_script(zip_name, None, zip_name)
+
+    def test_zipfile_compiled(self):
+        with temp_dir() as script_dir:
+            script_name = _make_test_script(script_dir, "__main__")
+            compiled_name = _compile_test_script(script_name)
+            zip_name = _make_test_zip(script_dir, "test_zip", compiled_name)
+            self._check_script(zip_name, None, zip_name)
+
+
+def test_main():
+    test.test_support.run_unittest(CmdLineTest)
+    test.test_support.reap_children()
+
+if __name__ == "__main__":
+    test_main()
index 9da2ba885c89d2343cf48176cabc5f34607d4c39..c474c50f0bf5c530a9465456a3ddf8bfd9e9c67a 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,10 @@ What's New in Python 2.6 alpha 1?
 Core and builtins
 -----------------
 
+- Patch #1739468: Directories and zipfiles containing a __main__.py file can
+  now be directly executed by passing their name to the interpreter. The
+  directory/zipfile is automatically inserted as the first entry in sys.path.
+
 - Issue #1265: Fix a problem with sys.settrace, if the tracing function uses a
   generator expression when at the same time the executed code is closing a
   paused generator.
index 417b3f2bf39c5b8c40813d5d5dfa9a037e3c0b3f..4b06acb3c52f41960d51bcf6df171b82a8e9d6e1 100644 (file)
@@ -141,7 +141,7 @@ static void RunStartupFile(PyCompilerFlags *cf)
 }
 
 
-static int RunModule(char *module)
+static int RunModule(char *module, int set_argv0)
 {
        PyObject *runpy, *runmodule, *runargs, *result;
        runpy = PyImport_ImportModule("runpy");
@@ -155,7 +155,7 @@ static int RunModule(char *module)
                Py_DECREF(runpy);
                return -1;
        }
-       runargs = Py_BuildValue("(s)", module);
+       runargs = Py_BuildValue("(si)", module, set_argv0);
        if (runargs == NULL) {
                fprintf(stderr,
                        "Could not create arguments for runpy._run_module_as_main\n");
@@ -177,6 +177,35 @@ static int RunModule(char *module)
        return 0;
 }
 
+static int RunMainFromImporter(char *filename)
+{
+       PyObject *argv0 = NULL, *importer = NULL;
+
+       if (
+               (argv0 = PyString_FromString(filename)) && 
+               (importer = PyImport_GetImporter(argv0)) &&
+               (importer->ob_type != &PyNullImporter_Type))
+       {
+                /* argv0 is usable as an import source, so
+                       put it in sys.path[0] and import __main__ */
+               PyObject *sys_path = NULL;
+               if (
+                       (sys_path = PySys_GetObject("path")) &&
+                       !PyList_SetItem(sys_path, 0, argv0)
+               ) {
+                       Py_INCREF(argv0);
+                       Py_CLEAR(importer);
+                       sys_path = NULL;
+                       return RunModule("__main__", 0) != 0;
+               }
+       }
+       PyErr_Clear();
+       Py_CLEAR(argv0);
+       Py_CLEAR(importer);
+       return -1;
+}
+
+
 /* Wait until threading._shutdown completes, provided
    the threading module was imported in the first place.
    The shutdown routine will wait until all non-daemon
@@ -388,39 +417,6 @@ Py_Main(int argc, char **argv)
 #else
                filename = argv[_PyOS_optind];
 #endif
-               if (filename != NULL) {
-                       if ((fp = fopen(filename, "r")) == NULL) {
-#ifdef HAVE_STRERROR
-                               fprintf(stderr, "%s: can't open file '%s': [Errno %d] %s\n",
-                                       argv[0], filename, errno, strerror(errno));
-#else
-                               fprintf(stderr, "%s: can't open file '%s': Errno %d\n",
-                                       argv[0], filename, errno);
-#endif
-                               return 2;
-                       }
-                       else if (skipfirstline) {
-                               int ch;
-                               /* Push back first newline so line numbers
-                                  remain the same */
-                               while ((ch = getc(fp)) != EOF) {
-                                       if (ch == '\n') {
-                                               (void)ungetc(ch, fp);
-                                               break;
-                                       }
-                               }
-                       }
-                       {
-                               /* XXX: does this work on Win/Win64? (see posix_fstat) */
-                               struct stat sb;
-                               if (fstat(fileno(fp), &sb) == 0 &&
-                                   S_ISDIR(sb.st_mode)) {
-                                       fprintf(stderr, "%s: '%s' is a directory, cannot continue\n", argv[0], filename);
-                                       fclose(fp);
-                                       return 1;
-                               }
-                       }
-               }
        }
 
        stdin_is_interactive = Py_FdIsInteractive(stdin, (char *)0);
@@ -515,19 +511,63 @@ Py_Main(int argc, char **argv)
                sts = PyRun_SimpleStringFlags(command, &cf) != 0;
                free(command);
        } else if (module) {
-               sts = RunModule(module);
+               sts = RunModule(module, 1);
                free(module);
        }
        else {
+
                if (filename == NULL && stdin_is_interactive) {
                        Py_InspectFlag = 0; /* do exit on SystemExit */
                        RunStartupFile(&cf);
                }
                /* XXX */
-               sts = PyRun_AnyFileExFlags(
-                       fp,
-                       filename == NULL ? "<stdin>" : filename,
-                       filename != NULL, &cf) != 0;
+
+               sts = -1;       /* keep track of whether we've already run __main__ */
+
+               if (filename != NULL) {
+                       sts = RunMainFromImporter(filename);
+               }
+
+               if (sts==-1 && filename!=NULL) {
+                       if ((fp = fopen(filename, "r")) == NULL) {
+#ifdef HAVE_STRERROR
+                               fprintf(stderr, "%s: can't open file '%s': [Errno %d] %s\n",
+                                       argv[0], filename, errno, strerror(errno));
+#else
+                               fprintf(stderr, "%s: can't open file '%s': Errno %d\n",
+                                       argv[0], filename, errno);
+#endif
+                               return 2;
+                       }
+                       else if (skipfirstline) {
+                               int ch;
+                               /* Push back first newline so line numbers
+                                  remain the same */
+                               while ((ch = getc(fp)) != EOF) {
+                                       if (ch == '\n') {
+                                               (void)ungetc(ch, fp);
+                                               break;
+                                       }
+                               }
+                       }
+                       {
+                               /* XXX: does this work on Win/Win64? (see posix_fstat) */
+                               struct stat sb;
+                               if (fstat(fileno(fp), &sb) == 0 &&
+                                   S_ISDIR(sb.st_mode)) {
+                                       fprintf(stderr, "%s: '%s' is a directory, cannot continue\n", argv[0], filename);
+                                       return 1;
+                               }
+                       }
+               }
+
+               if (sts==-1) {
+                       sts = PyRun_AnyFileExFlags(
+                               fp,
+                               filename == NULL ? "<stdin>" : filename,
+                               filename != NULL, &cf) != 0;
+               }
+               
        }
 
        /* Check this environment variable at the end, to give programs the
index 6a4d22f6e645ef4e26a37c48bc23edf50bd62f7f..59a51bc233867b8d1312e1357a2afad7e0293097 100644 (file)
@@ -104,7 +104,6 @@ static const struct filedescr _PyImport_StandardFiletab[] = {
 };
 #endif
 
-static PyTypeObject NullImporterType;  /* Forward reference */
 
 /* Initialize things */
 
@@ -167,7 +166,7 @@ _PyImportHooks_Init(void)
 
        /* adding sys.path_hooks and sys.path_importer_cache, setting up
           zipimport */
-       if (PyType_Ready(&NullImporterType) < 0)
+       if (PyType_Ready(&PyNullImporter_Type) < 0)
                goto error;
 
        if (Py_VerboseFlag)
@@ -1088,7 +1087,7 @@ get_path_importer(PyObject *path_importer_cache, PyObject *path_hooks,
        }
        if (importer == NULL) {
                importer = PyObject_CallFunctionObjArgs(
-                       (PyObject *)&NullImporterType, p, NULL
+                       (PyObject *)&PyNullImporter_Type, p, NULL
                );
                if (importer == NULL) {
                        if (PyErr_ExceptionMatches(PyExc_ImportError)) {
@@ -1106,6 +1105,20 @@ get_path_importer(PyObject *path_importer_cache, PyObject *path_hooks,
        return importer;
 }
 
+PyAPI_FUNC(PyObject *)
+PyImport_GetImporter(PyObject *path) {
+        PyObject *importer=NULL, *path_importer_cache=NULL, *path_hooks=NULL;
+
+       if ((path_importer_cache = PySys_GetObject("path_importer_cache"))) {
+               if ((path_hooks = PySys_GetObject("path_hooks"))) {
+                       importer = get_path_importer(path_importer_cache,
+                                                    path_hooks, path);
+               }
+       }
+       Py_XINCREF(importer); /* get_path_importer returns a borrowed reference */
+       return importer;
+}
+
 /* Search the path (default sys.path) for a module.  Return the
    corresponding filedescr struct, and (via return arguments) the
    pathname and an open file.  Return NULL if the module is not found. */
@@ -3049,7 +3062,7 @@ static PyMethodDef NullImporter_methods[] = {
 };
 
 
-static PyTypeObject NullImporterType = {
+PyTypeObject PyNullImporter_Type = {
        PyVarObject_HEAD_INIT(NULL, 0)
        "imp.NullImporter",        /*tp_name*/
        sizeof(NullImporter),      /*tp_basicsize*/
@@ -3096,7 +3109,7 @@ initimp(void)
 {
        PyObject *m, *d;
 
-       if (PyType_Ready(&NullImporterType) < 0)
+       if (PyType_Ready(&PyNullImporter_Type) < 0)
                goto failure;
 
        m = Py_InitModule4("imp", imp_methods, doc_imp,
@@ -3118,8 +3131,8 @@ initimp(void)
        if (setint(d, "PY_CODERESOURCE", PY_CODERESOURCE) < 0) goto failure;
        if (setint(d, "IMP_HOOK", IMP_HOOK) < 0) goto failure;
 
-       Py_INCREF(&NullImporterType);
-       PyModule_AddObject(m, "NullImporter", (PyObject *)&NullImporterType);
+       Py_INCREF(&PyNullImporter_Type);
+       PyModule_AddObject(m, "NullImporter", (PyObject *)&PyNullImporter_Type);
   failure:
        ;
 }