]> granicus.if.org Git - python/commitdiff
Have itertools.chain() consume its inputs lazily instead of building a tuple of itera...
authorRaymond Hettinger <python@rcn.com>
Thu, 28 Feb 2008 22:30:42 +0000 (22:30 +0000)
committerRaymond Hettinger <python@rcn.com>
Thu, 28 Feb 2008 22:30:42 +0000 (22:30 +0000)
Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index 79c4b3a4d9fb83dfbb92bad7fcd7ff7f16604a99..41e93627f5170b0e4e80b0e8648f96c7aefb5ce0 100644 (file)
@@ -50,7 +50,7 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(list(chain('abc')), list('abc'))
         self.assertEqual(list(chain('')), [])
         self.assertEqual(take(4, chain('abc', 'def')), list('abcd'))
-        self.assertRaises(TypeError, chain, 2, 3)
+        self.assertRaises(TypeError, list,chain(2, 3))
 
     def test_combinations(self):
         self.assertRaises(TypeError, combinations, 'abc')   # missing r argument
@@ -670,7 +670,7 @@ class TestVariousIteratorArgs(unittest.TestCase):
             for g in (G, I, Ig, S, L, R):
                 self.assertEqual(list(chain(g(s))), list(g(s)))
                 self.assertEqual(list(chain(g(s), g(s))), list(g(s))+list(g(s)))
-            self.assertRaises(TypeError, chain, X(s))
+            self.assertRaises(TypeError, list, chain(X(s)))
             self.assertRaises(TypeError, list, chain(N(s)))
             self.assertRaises(ZeroDivisionError, list, chain(E(s)))
 
index 2ee947dcffa211cdb60ad9bb7d35ade485c0c0ae..3b8339cc56fc01bdb86d37825514c4fe91b3d23f 100644 (file)
@@ -1601,92 +1601,92 @@ static PyTypeObject imap_type = {
 
 typedef struct {
        PyObject_HEAD
-       Py_ssize_t tuplesize;
-       Py_ssize_t iternum;             /* which iterator is active */
-       PyObject *ittuple;              /* tuple of iterators */
+       PyObject *source;               /* Iterator over input iterables */
+       PyObject *active;               /* Currently running input iterator */
 } chainobject;
 
 static PyTypeObject chain_type;
 
-static PyObject *
-chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+static PyObject * 
+chain_new_internal(PyTypeObject *type, PyObject *source)
 {
        chainobject *lz;
-       Py_ssize_t tuplesize = PySequence_Length(args);
-       Py_ssize_t i;
-       PyObject *ittuple;
-
-       if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
-               return NULL;
-
-       /* obtain iterators */
-       assert(PyTuple_Check(args));
-       ittuple = PyTuple_New(tuplesize);
-       if (ittuple == NULL)
-               return NULL;
-       for (i=0; i < tuplesize; ++i) {
-               PyObject *item = PyTuple_GET_ITEM(args, i);
-               PyObject *it = PyObject_GetIter(item);
-               if (it == NULL) {
-                       if (PyErr_ExceptionMatches(PyExc_TypeError))
-                               PyErr_Format(PyExc_TypeError,
-                                   "chain argument #%zd must support iteration",
-                                   i+1);
-                       Py_DECREF(ittuple);
-                       return NULL;
-               }
-               PyTuple_SET_ITEM(ittuple, i, it);
-       }
 
-       /* create chainobject structure */
        lz = (chainobject *)type->tp_alloc(type, 0);
        if (lz == NULL) {
-               Py_DECREF(ittuple);
+               Py_DECREF(source);
                return NULL;
        }
+       
+       lz->source = source;
+       lz->active = NULL;
+       return (PyObject *)lz;
+}
 
-       lz->ittuple = ittuple;
-       lz->iternum = 0;
-       lz->tuplesize = tuplesize;
+static PyObject *
+chain_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       PyObject *source;
 
-       return (PyObject *)lz;
+       if (type == &chain_type && !_PyArg_NoKeywords("chain()", kwds))
+               return NULL;
+       
+       source = PyObject_GetIter(args);
+       if (source == NULL)
+               return NULL;
+
+       return chain_new_internal(type, source);
 }
 
 static void
 chain_dealloc(chainobject *lz)
 {
        PyObject_GC_UnTrack(lz);
-       Py_XDECREF(lz->ittuple);
+       Py_XDECREF(lz->active);
+       Py_XDECREF(lz->source);
        Py_TYPE(lz)->tp_free(lz);
 }
 
 static int
 chain_traverse(chainobject *lz, visitproc visit, void *arg)
 {
-       Py_VISIT(lz->ittuple);
+       Py_VISIT(lz->source);
+       Py_VISIT(lz->active);
        return 0;
 }
 
 static PyObject *
 chain_next(chainobject *lz)
 {
-       PyObject *it;
        PyObject *item;
 
-       while (lz->iternum < lz->tuplesize) {
-               it = PyTuple_GET_ITEM(lz->ittuple, lz->iternum);
-               item = PyIter_Next(it);
-               if (item != NULL)
-                       return item;
-               if (PyErr_Occurred()) {
-                       if (PyErr_ExceptionMatches(PyExc_StopIteration))
-                               PyErr_Clear();
-                       else
-                               return NULL;
+       if (lz->source == NULL)
+               return NULL;                            /* already stopped */
+
+       if (lz->active == NULL) {
+               PyObject *iterable = PyIter_Next(lz->source);
+               if (iterable == NULL) {
+                       Py_CLEAR(lz->source);
+                       return NULL;                    /* no more input sources */
+               }
+               lz->active = PyObject_GetIter(iterable);
+               if (lz->active == NULL) {
+                       Py_DECREF(iterable);
+                       Py_CLEAR(lz->source);
+                       return NULL;                    /* input not iterable */
                }
-               lz->iternum++;
        }
-       return NULL;
+       item = PyIter_Next(lz->active);
+       if (item != NULL)
+               return item;
+       if (PyErr_Occurred()) {
+               if (PyErr_ExceptionMatches(PyExc_StopIteration))
+                       PyErr_Clear();
+               else
+                       return NULL;                    /* input raised an exception */
+       }
+       Py_CLEAR(lz->active);
+       return chain_next(lz);                  /* recurse and use next active */
 }
 
 PyDoc_STRVAR(chain_doc,