]> granicus.if.org Git - python/commitdiff
Generalize file.writelines() to allow iterable objects.
authorTim Peters <tim.peters@gmail.com>
Sun, 23 Sep 2001 04:06:05 +0000 (04:06 +0000)
committerTim Peters <tim.peters@gmail.com>
Sun, 23 Sep 2001 04:06:05 +0000 (04:06 +0000)
Doc/lib/libstdtypes.tex
Lib/test/test_iter.py
Misc/NEWS
Objects/fileobject.c
Tools/scripts/ndiff.py

index a79e14294547d6c8d76400cd97673f4540c6a377..916a1cfa4996aa01047055bbe49ba88d1251a829 100644 (file)
@@ -1312,8 +1312,10 @@ Files have the following methods:
   the \method{flush()} or \method{close()} method is called.
 \end{methoddesc}
 
-\begin{methoddesc}[file]{writelines}{list}
-  Write a list of strings to the file.  There is no return value.
+\begin{methoddesc}[file]{writelines}{sequence}
+  Write a sequence of strings to the file.  The sequence can be any
+  iterable object producing strings, typically a list of strings.
+  There is no return value.
   (The name is intended to match \method{readlines()};
   \method{writelines()} does not add line separators.)
 \end{methoddesc}
index f6084cb9a33caf9d543024c16809cbceafa9c527..257b61d8931e94ea5cfc4d5ec446c5e82a16f5ee 100644 (file)
@@ -641,6 +641,59 @@ class TestCase(unittest.TestCase):
             self.assertEqual(indexOf(iclass, i), i)
         self.assertRaises(ValueError, indexOf, iclass, -1)
 
+    # Test iterators with file.writelines().
+    def test_writelines(self):
+        f = file(TESTFN, "w")
+
+        try:
+            self.assertRaises(TypeError, f.writelines, None)
+            self.assertRaises(TypeError, f.writelines, 42)
+    
+            f.writelines(["1\n", "2\n"])
+            f.writelines(("3\n", "4\n"))
+            f.writelines({'5\n': None})
+            f.writelines({})
+
+            # Try a big chunk too.
+            class Iterator:
+                def __init__(self, start, finish):
+                    self.start = start
+                    self.finish = finish
+                    self.i = self.start
+
+                def next(self):
+                    if self.i >= self.finish:
+                        raise StopIteration
+                    result = str(self.i) + '\n'
+                    self.i += 1
+                    return result
+
+                def __iter__(self):
+                    return self
+
+            class Whatever:
+                def __init__(self, start, finish):
+                    self.start = start
+                    self.finish = finish
+
+                def __iter__(self):
+                    return Iterator(self.start, self.finish)
+                    
+            f.writelines(Whatever(6, 6+2000))            
+            f.close()
+
+            f = file(TESTFN)
+            expected = [str(i) + "\n" for i in range(1, 2006)]
+            self.assertEqual(list(f), expected)
+            
+        finally:
+            f.close()
+            try:
+                unlink(TESTFN)
+            except OSError:
+                pass
+
+
     # Test iterators on RHS of unpacking assignments.
     def test_unpack_iter(self):
         a, b = 1, 2
index cf8e3fc218b2803e77d29027e0ab5311ce1f12af..9a58c384c0f9912e550e37042bbb752b7be58a08 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -3,6 +3,8 @@ What's New in Python 2.2a4?
 
 Core
 
+- file.writelines() now accepts any iterable object producing strings.
+
 - PyUnicode_FromEncodedObject() now works very much like
   PyObject_Str(obj) in that it tries to use __str__/tp_str
   on the object if the object is not a string or buffer. This
index b6c039ccbcbf03783374a16ea3e345ceae061638..d3309245021f3a395b82d5b313eea668557c028d 100644 (file)
@@ -1164,55 +1164,54 @@ file_write(PyFileObject *f, PyObject *args)
 }
 
 static PyObject *
-file_writelines(PyFileObject *f, PyObject *args)
+file_writelines(PyFileObject *f, PyObject *seq)
 {
 #define CHUNKSIZE 1000
        PyObject *list, *line;
+       PyObject *it;   /* iter(seq) */
        PyObject *result;
        int i, j, index, len, nwritten, islist;
 
+       assert(seq != NULL);
        if (f->f_fp == NULL)
                return err_closed();
-       if (args == NULL || !PySequence_Check(args)) {
-               PyErr_SetString(PyExc_TypeError,
-                          "writelines() argument must be a sequence of strings");
-               return NULL;
-       }
-       islist = PyList_Check(args);
 
-       /* Strategy: slurp CHUNKSIZE lines into a private list,
-          checking that they are all strings, then write that list
-          without holding the interpreter lock, then come back for more. */
-       index = 0;
-       if (islist)
-               list = NULL;
+       result = NULL;
+       list = NULL;
+       islist = PyList_Check(seq);
+       if  (islist)
+               it = NULL;
        else {
+               it = PyObject_GetIter(seq);
+               if (it == NULL) {
+                       PyErr_SetString(PyExc_TypeError,
+                               "writelines() requires an iterable argument");
+                       return NULL;
+               }
+               /* From here on, fail by going to error, to reclaim "it". */
                list = PyList_New(CHUNKSIZE);
                if (list == NULL)
-                       return NULL;
+                       goto error;
        }
-       result = NULL;
 
-       for (;;) {
+       /* Strategy: slurp CHUNKSIZE lines into a private list,
+          checking that they are all strings, then write that list
+          without holding the interpreter lock, then come back for more. */
+       for (index = 0; ; index += CHUNKSIZE) {
                if (islist) {
                        Py_XDECREF(list);
-                       list = PyList_GetSlice(args, index, index+CHUNKSIZE);
+                       list = PyList_GetSlice(seq, index, index+CHUNKSIZE);
                        if (list == NULL)
-                               return NULL;
+                               goto error;
                        j = PyList_GET_SIZE(list);
                }
                else {
                        for (j = 0; j < CHUNKSIZE; j++) {
-                               line = PySequence_GetItem(args, index+j);
+                               line = PyIter_Next(it);
                                if (line == NULL) {
-                                       if (PyErr_ExceptionMatches(
-                                               PyExc_IndexError)) {
-                                               PyErr_Clear();
-                                               break;
-                                       }
-                                       /* Some other error occurred.
-                                          XXX We may lose some output. */
-                                       goto error;
+                                       if (PyErr_Occurred())
+                                               goto error;
+                                       break;
                                }
                                PyList_SetItem(list, j, line);
                        }
@@ -1271,14 +1270,15 @@ file_writelines(PyFileObject *f, PyObject *args)
 
                if (j < CHUNKSIZE)
                        break;
-               index += CHUNKSIZE;
        }
 
        Py_INCREF(Py_None);
        result = Py_None;
   error:
        Py_XDECREF(list);
+       Py_XDECREF(it);
        return result;
+#undef CHUNKSIZE
 }
 
 static char readline_doc[] =
@@ -1342,10 +1342,10 @@ static char xreadlines_doc[] =
 "often quicker, due to reading ahead internally.";
 
 static char writelines_doc[] =
-"writelines(list of strings) -> None.  Write the strings to the file.\n"
+"writelines(sequence_of_strings) -> None.  Write the strings to the file.\n"
 "\n"
-"Note that newlines are not added.  This is equivalent to calling write()\n"
-"for each string in the list.";
+"Note that newlines are not added.  The sequence can be any iterable object\n"
+"producing strings. This is equivalent to calling write() for each string.";
 
 static char flush_doc[] =
 "flush() -> None.  Flush the internal I/O buffer.";
index 7ceccc5711a938dc50df4b988f5c94b0199984fa..6f0f9a9cf34fddec1c834d9898b58c209cc6f8af 100755 (executable)
@@ -118,8 +118,7 @@ def main(args):
 
 def restore(which):
     restored = difflib.restore(sys.stdin.readlines(), which)
-    for line in restored:
-        print line,
+    sys.stdout.writelines(restored)
 
 if __name__ == '__main__':
     args = sys.argv[1:]