]> granicus.if.org Git - python/commitdiff
Add raw_input() back, named input(). Revive the old unittests too.
authorGuido van Rossum <guido@python.org>
Mon, 26 Feb 2007 16:59:55 +0000 (16:59 +0000)
committerGuido van Rossum <guido@python.org>
Mon, 26 Feb 2007 16:59:55 +0000 (16:59 +0000)
Lib/test/test_builtin.py
Python/bltinmodule.c

index e22e73a7336250ba59aa9346735990070943efd5..d2f70ff14329d9c00f384627adaefb0c0610dba1 100644 (file)
@@ -664,6 +664,8 @@ class BuiltinTest(unittest.TestCase):
         id([0,1,2,3])
         id({'spam': 1, 'eggs': 2, 'ham': 3})
 
+    # Test input() later, alphabetized as if it were raw_input
+
     def test_int(self):
         self.assertEqual(int(314), 314)
         self.assertEqual(int(3.14), 3)
@@ -1256,6 +1258,7 @@ class BuiltinTest(unittest.TestCase):
         self.assertRaises(TypeError, oct, ())
 
     def write_testfile(self):
+        # NB the first 4 lines are also used to test input, below
         fp = open(TESTFN, 'w')
         try:
             fp.write('1+1\n')
@@ -1417,6 +1420,43 @@ class BuiltinTest(unittest.TestCase):
         self.assertRaises(OverflowError, range, -sys.maxint, sys.maxint)
         self.assertRaises(OverflowError, range, 0, 2*sys.maxint)
 
+    def test_input(self):
+        self.write_testfile()
+        fp = open(TESTFN, 'r')
+        savestdin = sys.stdin
+        savestdout = sys.stdout # Eats the echo
+        try:
+            sys.stdin = fp
+            sys.stdout = BitBucket()
+            self.assertEqual(input(), "1+1")
+            self.assertEqual(input('testing\n'), "1+1")
+            self.assertEqual(input(), 'The quick brown fox jumps over the lazy dog.')
+            self.assertEqual(input('testing\n'), 'Dear John')
+
+            # SF 1535165: don't segfault on closed stdin
+            # sys.stdout must be a regular file for triggering
+            sys.stdout = savestdout
+            sys.stdin.close()
+            self.assertRaises(ValueError, input)
+
+            sys.stdout = BitBucket()
+            sys.stdin = cStringIO.StringIO("NULL\0")
+            self.assertRaises(TypeError, input, 42, 42)
+            sys.stdin = cStringIO.StringIO("    'whitespace'")
+            self.assertEqual(input(), "    'whitespace'")
+            sys.stdin = cStringIO.StringIO()
+            self.assertRaises(EOFError, input)
+
+            del sys.stdout
+            self.assertRaises(RuntimeError, input, 'prompt')
+            del sys.stdin
+            self.assertRaises(RuntimeError, input, 'prompt')
+        finally:
+            sys.stdin = savestdin
+            sys.stdout = savestdout
+            fp.close()
+            unlink(TESTFN)
+
     def test_reload(self):
         import marshal
         reload(marshal)
index 78aeeb7937e144f1897b1e28b58f609ddc781d58..3c149e4b495790d60e4756dc0d35cb9bda8a1f55 100644 (file)
@@ -1753,6 +1753,83 @@ When step is given, it specifies the increment (or decrement).\n\
 For example, range(4) returns [0, 1, 2, 3].  The end point is omitted!\n\
 These are exactly the valid indices for a list of 4 elements.");
 
+static PyObject *
+builtin_input(PyObject *self, PyObject *args)
+{
+       PyObject *v = NULL;
+       PyObject *fin = PySys_GetObject("stdin");
+       PyObject *fout = PySys_GetObject("stdout");
+
+       if (!PyArg_UnpackTuple(args, "input", 0, 1, &v))
+               return NULL;
+
+       if (fin == NULL) {
+               PyErr_SetString(PyExc_RuntimeError, "input: lost sys.stdin");
+               return NULL;
+       }
+       if (fout == NULL) {
+               PyErr_SetString(PyExc_RuntimeError, "input: lost sys.stdout");
+               return NULL;
+       }
+       if (PyFile_AsFile(fin) && PyFile_AsFile(fout)
+            && isatty(fileno(PyFile_AsFile(fin)))
+            && isatty(fileno(PyFile_AsFile(fout)))) {
+               PyObject *po;
+               char *prompt;
+               char *s;
+               PyObject *result;
+               if (v != NULL) {
+                       po = PyObject_Str(v);
+                       if (po == NULL)
+                               return NULL;
+                       prompt = PyString_AsString(po);
+                       if (prompt == NULL)
+                               return NULL;
+               }
+               else {
+                       po = NULL;
+                       prompt = "";
+               }
+               s = PyOS_Readline(PyFile_AsFile(fin), PyFile_AsFile(fout),
+                                  prompt);
+               Py_XDECREF(po);
+               if (s == NULL) {
+                       if (!PyErr_Occurred())
+                               PyErr_SetNone(PyExc_KeyboardInterrupt);
+                       return NULL;
+               }
+               if (*s == '\0') {
+                       PyErr_SetNone(PyExc_EOFError);
+                       result = NULL;
+               }
+               else { /* strip trailing '\n' */
+                       size_t len = strlen(s);
+                       if (len > PY_SSIZE_T_MAX) {
+                               PyErr_SetString(PyExc_OverflowError,
+                                               "input: input too long");
+                               result = NULL;
+                       }
+                       else {
+                               result = PyString_FromStringAndSize(s, len-1);
+                       }
+               }
+               PyMem_FREE(s);
+               return result;
+       }
+       if (v != NULL) {
+               if (PyFile_WriteObject(v, fout, Py_PRINT_RAW) != 0)
+                       return NULL;
+       }
+       return PyFile_GetLine(fin, -1);
+}
+
+PyDoc_STRVAR(input_doc,
+"input([prompt]) -> string\n\
+\n\
+Read a string from standard input.  The trailing newline is stripped.\n\
+If the user hits EOF (Unix: Ctl-D, Windows: Ctl-Z+Return), raise EOFError.\n\
+On Unix, GNU readline is used if enabled.  The prompt string, if given,\n\
+is printed without a trailing newline before reading.");
 
 static PyObject *
 builtin_reload(PyObject *self, PyObject *v)
@@ -2046,6 +2123,7 @@ static PyMethodDef builtin_methods[] = {
        {"hash",        builtin_hash,       METH_O, hash_doc},
        {"hex",         builtin_hex,        METH_O, hex_doc},
        {"id",          builtin_id,         METH_O, id_doc},
+       {"input",       builtin_input,      METH_VARARGS, input_doc},
        {"isinstance",  builtin_isinstance, METH_VARARGS, isinstance_doc},
        {"issubclass",  builtin_issubclass, METH_VARARGS, issubclass_doc},
        {"iter",        builtin_iter,       METH_VARARGS, iter_doc},