]> granicus.if.org Git - python/commitdiff
Added itertools.tee()
authorRaymond Hettinger <python@rcn.com>
Fri, 24 Oct 2003 08:45:23 +0000 (08:45 +0000)
committerRaymond Hettinger <python@rcn.com>
Fri, 24 Oct 2003 08:45:23 +0000 (08:45 +0000)
It works like the pure python verion except:
* it stops storing data after of the iterators gets deallocated
* the data queue is implemented with two stacks instead of one dictionary.

Doc/lib/libitertools.tex
Lib/test/test_itertools.py
Misc/NEWS
Modules/itertoolsmodule.c

index aec55cb4e9b63297e7826203d9cfad4b54ace8d5..eb6bc49b49203bb563d9f68bb9e44e4c3f12f612 100644 (file)
@@ -108,9 +108,8 @@ by functions or loops that truncate the stream.
                    yield element
   \end{verbatim}
 
-  Note, this is the only member of the toolkit that may require
-  significant auxiliary storage (depending on the length of the
-  iterable).
+  Note, this member of the toolkit may require significant
+  auxiliary storage (depending on the length of the iterable).
 \end{funcdesc}
 
 \begin{funcdesc}{dropwhile}{predicate, iterable}
@@ -282,6 +281,32 @@ by functions or loops that truncate the stream.
   \end{verbatim}
 \end{funcdesc}
 
+\begin{funcdesc}{tee}{iterable}
+  Return two independent iterators from a single iterable.
+  Equivalent to:
+
+  \begin{verbatim}
+     def tee(iterable):
+         def gen(next, data={}, cnt=[0]):
+             for i in count():
+                 if i == cnt[0]:
+                     item = data[i] = next()
+                     cnt[0] += 1
+                 else:
+                     item = data.pop(i)
+                 yield item
+         it = iter(iterable)
+         return (gen(it.next), gen(it.next))
+  \end{verbatim}
+
+  Note, this member of the toolkit may require significant auxiliary
+  storage (depending on how much temporary data needs to be stored).
+  In general, if one iterator is going use most or all of the data before
+  the other iterator, it is faster to use \function{list()} instead of
+  \function{tee()}.
+  \versionadded{2.4}
+\end{funcdesc}
+
 
 \subsection{Examples \label{itertools-example}}
 
@@ -369,6 +394,17 @@ def ncycles(seq, n):
 def dotproduct(vec1, vec2):
     return sum(imap(operator.mul, vec1, vec2))
 
+def flatten(listOfLists):
+    return list(chain(*listOfLists))
+
+def repeatfunc(func, times=None, *args):
+    "Repeat calls to func with specified arguments."
+    "Example:  repeatfunc(random.random)"
+    if times is None:
+        return starmap(func, repeat(args))
+    else:
+        return starmap(func, repeat(args, times))
+
 def window(seq, n=2):
     "Returns a sliding window (of width n) over data from the iterable"
     "   s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ...                   "
@@ -380,18 +416,4 @@ def window(seq, n=2):
         result = result[1:] + (elem,)
         yield result
 
-def tee(iterable):
-    "Return two independent iterators from a single iterable"
-    def gen(next, data={}, cnt=[0]):
-        dpop = data.pop
-        for i in count():
-            if i == cnt[0]:
-                item = data[i] = next()
-                cnt[0] += 1
-            else:
-                item = dpop(i)
-            yield item
-    next = iter(iterable).next
-    return (gen(next), gen(next))
-
 \end{verbatim}
index 0880be3fed47908ff033657a14d4b344ed4098ec..ce03b1aefb6b5e708074f59762c06fdeea37c3d0 100644 (file)
@@ -3,6 +3,7 @@ from test import test_support
 from itertools import *
 import sys
 import operator
+import random
 
 def onearg(x):
     'Test function of one argument'
@@ -198,6 +199,50 @@ class TestBasicOps(unittest.TestCase):
         self.assertRaises(TypeError, dropwhile(10, [(4,5)]).next)
         self.assertRaises(ValueError, dropwhile(errfunc, [(4,5)]).next)
 
+    def test_tee(self):
+        n = 100
+        def irange(n):
+            for i in xrange(n):
+                yield i
+
+        a, b = tee([])        # test empty iterator
+        self.assertEqual(list(a), [])
+        self.assertEqual(list(b), [])
+
+        a, b = tee(irange(n)) # test 100% interleaved
+        self.assertEqual(zip(a,b), zip(range(n),range(n)))
+
+        a, b = tee(irange(n)) # test 0% interleaved
+        self.assertEqual(list(a), range(n))
+        self.assertEqual(list(b), range(n))
+
+        a, b = tee(irange(n)) # test dealloc of leading iterator
+        self.assertEqual(a.next(), 0)
+        self.assertEqual(a.next(), 1)
+        del a
+        self.assertEqual(list(b), range(n))
+
+        a, b = tee(irange(n)) # test dealloc of trailing iterator
+        self.assertEqual(a.next(), 0)
+        self.assertEqual(a.next(), 1)
+        del b
+        self.assertEqual(list(a), range(2, n))
+
+        for j in xrange(5):   # test randomly interleaved
+            order = [0]*n + [1]*n
+            random.shuffle(order)
+            lists = ([], [])
+            its = tee(irange(n))
+            for i in order:
+                value = its[i].next()
+                lists[i].append(value)
+            self.assertEqual(lists[0], range(n))
+            self.assertEqual(lists[1], range(n))
+
+        self.assertRaises(TypeError, tee)
+        self.assertRaises(TypeError, tee, 3)
+        self.assertRaises(TypeError, tee, [1,2], 'x')
+
     def test_StopIteration(self):
         self.assertRaises(StopIteration, izip().next)
 
@@ -208,12 +253,65 @@ class TestBasicOps(unittest.TestCase):
         self.assertRaises(StopIteration, islice([], None).next)
         self.assertRaises(StopIteration, islice(StopNow(), None).next)
 
+        p, q = tee([])
+        self.assertRaises(StopIteration, p.next)
+        self.assertRaises(StopIteration, q.next)
+        p, q = tee(StopNow())
+        self.assertRaises(StopIteration, p.next)
+        self.assertRaises(StopIteration, q.next)
+
         self.assertRaises(StopIteration, repeat(None, 0).next)
 
         for f in (ifilter, ifilterfalse, imap, takewhile, dropwhile, starmap):
             self.assertRaises(StopIteration, f(lambda x:x, []).next)
             self.assertRaises(StopIteration, f(lambda x:x, StopNow()).next)
 
+class TestGC(unittest.TestCase):
+
+    def makecycle(self, iterator, container):
+        container.append(iterator)
+        iterator.next()
+        del container, iterator
+
+    def test_chain(self):
+        a = []
+        self.makecycle(chain(a), a)
+
+    def test_cycle(self):
+        a = []
+        self.makecycle(cycle([a]*2), a)
+
+    def test_ifilter(self):
+        a = []
+        self.makecycle(ifilter(lambda x:True, [a]*2), a)
+
+    def test_ifilterfalse(self):
+        a = []
+        self.makecycle(ifilterfalse(lambda x:False, a), a)
+
+    def test_izip(self):
+        a = []
+        self.makecycle(izip([a]*2, [a]*3), a)
+
+    def test_imap(self):
+        a = []
+        self.makecycle(imap(lambda x:x, [a]*2), a)
+
+    def test_islice(self):
+        a = []
+        self.makecycle(islice([a]*2, None), a)
+
+    def test_starmap(self):
+        a = []
+        self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
+
+    def test_tee(self):
+        a = []
+        p, q = t = tee([a]*2)
+        a += [a, p, q, t]
+        p.next()
+        del a, p, q, t
+
 def R(seqn):
     'Regular generator'
     for i in seqn:
@@ -290,45 +388,6 @@ def L(seqn):
     'Test multiple tiers of iterators'
     return chain(imap(lambda x:x, R(Ig(G(seqn)))))
 
-class TestGC(unittest.TestCase):
-
-    def makecycle(self, iterator, container):
-        container.append(iterator)
-        iterator.next()
-        del container, iterator
-
-    def test_chain(self):
-        a = []
-        self.makecycle(chain(a), a)
-
-    def test_cycle(self):
-        a = []
-        self.makecycle(cycle([a]*2), a)
-
-    def test_ifilter(self):
-        a = []
-        self.makecycle(ifilter(lambda x:True, [a]*2), a)
-
-    def test_ifilterfalse(self):
-        a = []
-        self.makecycle(ifilterfalse(lambda x:False, a), a)
-
-    def test_izip(self):
-        a = []
-        self.makecycle(izip([a]*2, [a]*3), a)
-
-    def test_imap(self):
-        a = []
-        self.makecycle(imap(lambda x:x, [a]*2), a)
-
-    def test_islice(self):
-        a = []
-        self.makecycle(islice([a]*2, None), a)
-
-    def test_starmap(self):
-        a = []
-        self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
-
 
 class TestVariousIteratorArgs(unittest.TestCase):
 
@@ -427,6 +486,16 @@ class TestVariousIteratorArgs(unittest.TestCase):
             self.assertRaises(TypeError, list, dropwhile(isOdd, N(s)))
             self.assertRaises(ZeroDivisionError, list, dropwhile(isOdd, E(s)))
 
+    def test_tee(self):
+        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
+            for g in (G, I, Ig, S, L, R):
+                it1, it2 = tee(g(s))
+                self.assertEqual(list(it1), list(g(s)))
+                self.assertEqual(list(it2), list(g(s)))
+            self.assertRaises(TypeError, tee, X(s))
+            self.assertRaises(TypeError, list, tee(N(s))[0])
+            self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
+
 class RegressionTests(unittest.TestCase):
 
     def test_sf_793826(self):
@@ -531,6 +600,17 @@ Samuele
 >>> def dotproduct(vec1, vec2):
 ...     return sum(imap(operator.mul, vec1, vec2))
 
+>>> def flatten(listOfLists):
+...     return list(chain(*listOfLists))
+
+>>> def repeatfunc(func, times=None, *args):
+...     "Repeat calls to func with specified arguments."
+...     "   Example:  repeatfunc(random.random)"
+...     if times is None:
+...         return starmap(func, repeat(args))
+...     else:
+...         return starmap(func, repeat(args, times))
+
 >>> def window(seq, n=2):
 ...     "Returns a sliding window (of width n) over data from the iterable"
 ...     "   s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ...                   "
@@ -542,20 +622,6 @@ Samuele
 ...         result = result[1:] + (elem,)
 ...         yield result
 
->>> def tee(iterable):
-...     "Return two independent iterators from a single iterable"
-...     def gen(next, data={}, cnt=[0]):
-...         dpop = data.pop
-...         for i in count():
-...             if i == cnt[0]:
-...                 item = data[i] = next()
-...                 cnt[0] += 1
-...             else:
-...                 item = dpop(i)
-...             yield item
-...     next = iter(iterable).next
-...     return (gen(next), gen(next))
-
 This is not part of the examples but it tests to make sure the definitions
 perform as purported.
 
@@ -592,6 +658,17 @@ False
 >>> quantify(xrange(99), lambda x: x%2==0)
 50
 
+>>> a = [[1, 2, 3], [4, 5, 6]]
+>>> flatten(a)
+[1, 2, 3, 4, 5, 6]
+
+>>> list(repeatfunc(pow, 5, 2, 3))
+[8, 8, 8, 8, 8]
+
+>>> import random
+>>> take(5, imap(int, repeatfunc(random.random)))
+[0, 0, 0, 0, 0]
+
 >>> list(window('abc'))
 [('a', 'b'), ('b', 'c')]
 
@@ -607,14 +684,6 @@ False
 >>> dotproduct([1,2,3], [4,5,6])
 32
 
->>> x, y = tee(chain(xrange(2,10)))
->>> list(x), list(y)
-([2, 3, 4, 5, 6, 7, 8, 9], [2, 3, 4, 5, 6, 7, 8, 9])
-
->>> x, y = tee(chain(xrange(2,10)))
->>> zip(x, y)
-[(2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)]
-
 """
 
 __test__ = {'libreftest' : libreftest}
index a53fdaf5a482e5f39b55724e6477b62d7854a487..08568dd51805a37bdf2f269ad74a20ade9702942 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -73,6 +73,24 @@ Extension modules
 
 - Implemented (?(id/name)yes|no) support in SRE (#572936).
 
+- random.seed() with no arguments or None uses time.time() as a default
+  seed.  Modified to match Py2.2 behavior and use fractional seconds so
+  that successive runs are more likely to produce different sequences.
+
+- random.Random has a new method, getrandbits(k), which returns an int
+  with k random bits.  This method is now an optional part of the API
+  for user defined generators.  Any generator that defines genrandbits()
+  can now use randrange() for ranges with a length >= 2**53.  Formerly,
+  randrange would return only even numbers for ranges that large (see
+  SF bug #812202).  Generators that do not define genrandbits() now
+  issue a warning when randrange() is called with a range that large.
+
+- itertools now has a new function, tee() which produces two independent
+  iterators from a single iterable.
+
+- itertools.izip() with no arguments now returns an empty iterator instead
+  of raising a TypeError exception.
+
 Library
 -------
 
@@ -108,21 +126,6 @@ Library
   allow any iterable.  Also the Set.update() has been deprecated because
   it duplicates Set.union_update().
 
-- random.seed() with no arguments or None uses time.time() as a default
-  seed.  Modified to match Py2.2 behavior and use fractional seconds so
-  that successive runs are more likely to produce different sequences.
-
-- random.Random has a new method, getrandbits(k), which returns an int
-  with k random bits.  This method is now an optional part of the API
-  for user defined generators.  Any generator that defines genrandbits()
-  can now use randrange() for ranges with a length >= 2**53.  Formerly,
-  randrange would return only even numbers for ranges that large (see
-  SF bug #812202).  Generators that do not define genrandbits() now
-  issue a warning when randrange() is called with a range that large.
-
-- itertools.izip() with no arguments now returns an empty iterator instead
-  of raising a TypeError exception.
-
 - _strptime.py now has a behind-the-scenes caching mechanism for the most
   recent TimeRE instance used along with the last five unique directive
   patterns.  The overall module was also made more thread-safe.
index 68e176f23d4e71c0078d6e6f76887dadcc87d80f..42440df0613e5803ef0bfd317dcbb99801753d81 100644 (file)
@@ -7,6 +7,264 @@
    All rights reserved.
 */
 
+/* independent iterator object supporting the tee object ***************/
+
+/* The tee object maintains a queue of data seen by the leading iterator
+   but not seen by the trailing iterator.  When the leading iterator
+   gets data from PyIter_Next() it appends a copy to the inbasket stack.
+   When the trailing iterator needs data, it is popped from the outbasket
+   stack.  If the outbasket stack is empty, then it is filled from the
+   inbasket (i.e. the queue is implemented using two stacks so that only
+   O(n) operations like append() and pop() are used to access data and
+   calls to reverse() never move any data element more than once).
+
+   If one of the independent iterators gets deallocated, it sets tee's
+   save_mode to zero so that future calls to PyIter_Next() stop getting
+   saved to the queue (because there is no longer a second iterator that
+   may need the data).
+*/
+
+typedef struct {
+       PyObject_HEAD
+       PyObject *it;
+       PyObject *inbasket;
+       PyObject *outbasket;
+       int save_mode;
+       int num_seen;
+} teeobject;
+
+typedef struct {
+       PyObject_HEAD
+       teeobject *tee;
+       int num_seen;
+} iiobject;
+
+static PyTypeObject ii_type;
+
+static PyObject *
+ii_next(iiobject *lz)
+{
+       teeobject *to = lz->tee;
+       PyObject *result, *tmp;
+
+       if (lz->num_seen == to->num_seen) { 
+               /* This instance is leading, use iter to get more data */
+               result = PyIter_Next(to->it);
+               if (result == NULL)
+                       return NULL;
+               if (to->save_mode)
+                       PyList_Append(to->inbasket, result);
+               to->num_seen++;
+               lz->num_seen++;
+               return result;
+       }
+
+       /* This instance is trailing, get data from the queue */
+       if (PyList_GET_SIZE(to->outbasket) == 0) {
+               /* outbasket is empty, so refill from the inbasket */
+               tmp = to->outbasket;
+               to->outbasket = to->inbasket;
+               to->inbasket = tmp;
+               PyList_Reverse(to->outbasket);
+               assert(PyList_GET_SIZE(to->outbasket) > 0);
+       }
+
+       lz->num_seen++;
+       return PyObject_CallMethod(to->outbasket, "pop", NULL);
+}
+
+static void
+ii_dealloc(iiobject *ii)
+{
+       PyObject_GC_UnTrack(ii);
+       ii->tee->save_mode = 0;  /* Stop saving data */
+       Py_XDECREF(ii->tee);
+       PyObject_GC_Del(ii);
+}
+
+static int
+ii_traverse(iiobject *ii, visitproc visit, void *arg)
+{
+       if (ii->tee)
+               return visit((PyObject *)(ii->tee), arg);
+       return 0;
+}
+
+PyDoc_STRVAR(ii_doc, "Independent iterators linked to a tee() object.");
+
+static PyTypeObject ii_type = {
+       PyObject_HEAD_INIT(&PyType_Type)
+       0,                                      /* ob_size */
+       "itertools.independent_iterator",       /* tp_name */
+       sizeof(iiobject),                       /* tp_basicsize */
+       0,                                      /* tp_itemsize */
+       /* methods */
+       (destructor)ii_dealloc,                 /* tp_dealloc */
+       0,                                      /* tp_print */
+       0,                                      /* tp_getattr */
+       0,                                      /* tp_setattr */
+       0,                                      /* tp_compare */
+       0,                                      /* tp_repr */
+       0,                                      /* tp_as_number */
+       0,                                      /* tp_as_sequence */
+       0,                                      /* tp_as_mapping */
+       0,                                      /* tp_hash */
+       0,                                      /* tp_call */
+       0,                                      /* tp_str */
+       PyObject_GenericGetAttr,                /* tp_getattro */
+       0,                                      /* tp_setattro */
+       0,                                      /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,/* tp_flags */
+       ii_doc,                                 /* tp_doc */
+       (traverseproc)ii_traverse,              /* tp_traverse */
+       0,                                      /* tp_clear */
+       0,                                      /* tp_richcompare */    
+       0,                                      /* tp_weaklistoffset */
+       PyObject_SelfIter,                      /* tp_iter */
+       (iternextfunc)ii_next,                  /* tp_iternext */
+       0,                                      /* tp_methods */
+};
+
+/* tee object **********************************************************/
+
+static PyTypeObject tee_type;
+
+static PyObject *
+tee_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       PyObject *it = NULL;
+       PyObject *iterable;
+       PyObject *inbasket = NULL, *outbasket = NULL, *result = NULL;
+       teeobject *to = NULL;
+       int i;
+
+       if (!PyArg_UnpackTuple(args, "tee", 1, 1, &iterable))
+               return NULL;
+
+       it = PyObject_GetIter(iterable);
+       if (it == NULL) goto fail;
+
+       inbasket = PyList_New(0);
+       if (inbasket == NULL) goto fail;
+
+       outbasket = PyList_New(0);
+       if (outbasket == NULL) goto fail;
+
+       to = (teeobject *)type->tp_alloc(type, 0);
+       if (to == NULL)  goto fail;
+
+       to->it = it;
+       to->inbasket = inbasket;
+       to->outbasket = outbasket;
+       to->save_mode = 1;
+       to->num_seen = 0;
+
+       /* create independent iterators */
+       result = PyTuple_New(2);
+       if (result == NULL)  goto fail;
+       for (i=0 ; i<2 ; i++) {
+               iiobject *indep_it = PyObject_GC_New(iiobject, &ii_type);
+               if (indep_it == NULL) goto fail;
+               Py_INCREF(to);
+               indep_it->tee = to;
+               indep_it->num_seen = 0;
+               PyObject_GC_Track(indep_it);
+               PyTuple_SET_ITEM(result, i, (PyObject *)indep_it);
+       }
+       goto succeed;
+fail:
+       Py_XDECREF(it);
+       Py_XDECREF(inbasket);
+       Py_XDECREF(outbasket);
+       Py_XDECREF(result);
+succeed:
+       Py_XDECREF(to);
+       return result;
+}
+
+static void
+tee_dealloc(teeobject *to)
+{
+       PyObject_GC_UnTrack(to);
+       Py_XDECREF(to->inbasket);
+       Py_XDECREF(to->outbasket);
+       Py_XDECREF(to->it);
+       to->ob_type->tp_free(to);
+}
+
+static int
+tee_traverse(teeobject *to, visitproc visit, void *arg)
+{
+       int err;
+
+       if (to->it) {
+               err = visit(to->it, arg);
+               if (err)
+                       return err;
+       }
+       if (to->inbasket) {
+               err = visit(to->inbasket, arg);
+               if (err)
+                       return err;
+       }
+       if (to->outbasket) {
+               err = visit(to->outbasket, arg);
+               if (err)
+                       return err;
+       }
+       return 0;
+}
+
+PyDoc_STRVAR(tee_doc,
+"tee(iterable) --> (it1, it2)\n\
+\n\
+Split the iterable into to independent iterables.");
+
+static PyTypeObject tee_type = {
+       PyObject_HEAD_INIT(NULL)
+       0,                              /* ob_size */
+       "itertools.tee",                /* tp_name */
+       sizeof(teeobject),              /* tp_basicsize */
+       0,                              /* tp_itemsize */
+       /* methods */
+       (destructor)tee_dealloc,        /* tp_dealloc */
+       0,                              /* tp_print */
+       0,                              /* tp_getattr */
+       0,                              /* tp_setattr */
+       0,                              /* tp_compare */
+       0,                              /* tp_repr */
+       0,                              /* tp_as_number */
+       0,                              /* tp_as_sequence */
+       0,                              /* tp_as_mapping */
+       0,                              /* tp_hash */
+       0,                              /* tp_call */
+       0,                              /* tp_str */
+       PyObject_GenericGetAttr,        /* tp_getattro */
+       0,                              /* tp_setattro */
+       0,                              /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+               Py_TPFLAGS_BASETYPE,    /* tp_flags */
+       tee_doc,                        /* tp_doc */
+       (traverseproc)tee_traverse,     /* tp_traverse */
+       0,                              /* tp_clear */
+       0,                              /* tp_richcompare */
+       0,                              /* tp_weaklistoffset */
+       0,                              /* tp_iter */
+       0,                              /* tp_iternext */
+       0,                              /* tp_methods */
+       0,                              /* tp_members */
+       0,                              /* tp_getset */
+       0,                              /* tp_base */
+       0,                              /* tp_dict */
+       0,                              /* tp_descr_get */
+       0,                              /* tp_descr_set */
+       0,                              /* tp_dictoffset */
+       0,                              /* tp_init */
+       0,                              /* tp_alloc */
+       tee_new,                        /* tp_new */
+       PyObject_GC_Del,                /* tp_free */
+};
+
 /* cycle object **********************************************************/
 
 typedef struct {
@@ -1824,6 +2082,7 @@ inititertools(void)
        PyObject *m;
        char *name;
        PyTypeObject *typelist[] = {
+               &tee_type,
                &cycle_type,
                &dropwhile_type,
                &takewhile_type,