]> granicus.if.org Git - python/commitdiff
Forward port r68941 adding itertools.compress().
authorRaymond Hettinger <python@rcn.com>
Mon, 26 Jan 2009 02:56:58 +0000 (02:56 +0000)
committerRaymond Hettinger <python@rcn.com>
Mon, 26 Jan 2009 02:56:58 +0000 (02:56 +0000)
Doc/library/collections.rst
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS
Modules/itertoolsmodule.c

index 803abab73ab5b9200ad5204a719ad1da68f86692..a8911d6c137aed9a1e78a76cfde3b515c5de94a0 100644 (file)
@@ -286,7 +286,7 @@ counts less than one::
       Section 4.6.3, Exercise 19*\.
 
     * To enumerate all distinct multisets of a given size over a given set of
-      elements, see the :func:`combinations_with_replacement` function in the
+      elements, see :func:`combinations_with_replacement` in the
       :ref:`itertools-recipes` for itertools::
 
           map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC
index 36254cd8f8d38bcfe99133be7385dfeca9ad9813..d28127801f5fbac71e53523983785a7da96d6a8f 100644 (file)
@@ -133,6 +133,20 @@ loops that truncate the stream.
    The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
    or zero when ``r > n``.
 
+.. function:: compress(data, selectors)
+
+   Make an iterator that filters elements from *data* returning only those that
+   have a corresponding element in *selectors* that evaluates to ``True``.
+   Stops when either the *data* or *selectors* iterables have been exhausted.
+   Equivalent to::
+
+       def compress(data, selectors):
+           # compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
+           return (d for d, s in zip(data, selectors) if s)
+
+   .. versionadded:: 3.1
+
+
 .. function:: count([n])
 
    Make an iterator that returns consecutive integers starting with *n*. If not
@@ -594,10 +608,6 @@ which incur interpreter overhead.
        s = list(iterable)
        return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
 
-   def compress(data, selectors):
-       "compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
-       return (d for d, s in zip(data, selectors) if s)
-
    def combinations_with_replacement(iterable, r):
        "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
        # number items returned:  (n+r-1)! / r! / (n-1)!
index 7023b293bf047a003825fb45f9883304ea36f0e1..16789d8f9f3775ef244dd65db1f56bdba8352238 100644 (file)
@@ -195,6 +195,21 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
         self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
 
+    def test_compress(self):
+        self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
+        self.assertEqual(list(compress('ABCDEF', [0,0,0,0,0,0])), list(''))
+        self.assertEqual(list(compress('ABCDEF', [1,1,1,1,1,1])), list('ABCDEF'))
+        self.assertEqual(list(compress('ABCDEF', [1,0,1])), list('AC'))
+        self.assertEqual(list(compress('ABC', [0,1,1,1,1,1])), list('BC'))
+        n = 10000
+        data = chain.from_iterable(repeat(range(6), n))
+        selectors = chain.from_iterable(repeat((0, 1)))
+        self.assertEqual(list(compress(data, selectors)), [1,3,5] * n)
+        self.assertRaises(TypeError, compress, None, range(6))      # 1st arg not iterable
+        self.assertRaises(TypeError, compress, range(6), None)      # 2nd arg not iterable
+        self.assertRaises(TypeError, compress, range(6))            # too few args
+        self.assertRaises(TypeError, compress, range(6), range(6), range(6))  # too many args
+
     def test_count(self):
         self.assertEqual(lzip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
         self.assertEqual(lzip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
@@ -715,6 +730,9 @@ class TestExamples(unittest.TestCase):
         self.assertEqual(list(combinations(range(4), 3)),
                          [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
 
+    def test_compress(self):
+        self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
+
     def test_count(self):
         self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14])
 
@@ -795,6 +813,10 @@ class TestGC(unittest.TestCase):
         a = []
         self.makecycle(combinations([1,2,a,3], 3), a)
 
+    def test_compress(self):
+        a = []
+        self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), a)
+
     def test_cycle(self):
         a = []
         self.makecycle(cycle([a]*2), a)
@@ -948,6 +970,15 @@ class TestVariousIteratorArgs(unittest.TestCase):
             self.assertRaises(TypeError, list, chain(N(s)))
             self.assertRaises(ZeroDivisionError, list, chain(E(s)))
 
+    def test_compress(self):
+        for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
+            n = len(s)
+            for g in (G, I, Ig, S, L, R):
+                self.assertEqual(list(compress(g(s), repeat(1))), list(g(s)))
+            self.assertRaises(TypeError, compress, X(s), repeat(1))
+            self.assertRaises(TypeError, compress, N(s), repeat(1))
+            self.assertRaises(ZeroDivisionError, list, compress(E(s), repeat(1)))
+
     def test_product(self):
         for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
             self.assertRaises(TypeError, product, X(s))
@@ -1144,7 +1175,7 @@ class SubclassWithKwargsTest(unittest.TestCase):
     def test_keywords_in_subclass(self):
         # count is not subclassable...
         for cls in (repeat, zip, filter, filterfalse, chain, map,
-                    starmap, islice, takewhile, dropwhile, cycle):
+                    starmap, islice, takewhile, dropwhile, cycle, compress):
             class Subclass(cls):
                 def __init__(self, newarg=None, *args):
                     cls.__init__(self, *args)
@@ -1281,10 +1312,6 @@ Samuele
 ...     s = list(iterable)
 ...     return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
 
->>> def compress(data, selectors):
-...     "compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
-...     return (d for d, s in zip(data, selectors) if s)
-
 >>> def combinations_with_replacement(iterable, r):
 ...     "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
 ...     pool = tuple(iterable)
@@ -1380,9 +1407,6 @@ perform as purported.
 >>> list(powerset([1,2,3]))
 [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
 
->>> list(compress('abcdef', [1,0,1,0,1,1]))
-['a', 'c', 'e', 'f']
-
 >>> list(combinations_with_replacement('abc', 2))
 [('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
 
index 3c6922842cf38625f8623c9f0b0b050682de6f66..c72e51a985cfe2124ee1ea3502d3f8e818a92ad6 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -150,6 +150,8 @@ Library
 
 - Issue #4863: distutils.mwerkscompiler has been removed.
 
+- Added a new function:  itertools.compress().
+
 - Fix and properly document the multiprocessing module's logging
   support, expose the internal levels and provide proper usage
   examples.
index dcf6aba0969cfc824fe39120f31846b93cd08f20..bee08de73b3ed6e8dcb98ed7427e2231cfe2e313 100644 (file)
@@ -2331,6 +2331,162 @@ static PyTypeObject permutations_type = {
 };
 
 
+/* compress object ************************************************************/
+
+/* Equivalent to:
+
+       def compress(data, selectors):
+               "compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F"
+               return (d for d, s in zip(data, selectors) if s)
+*/
+
+typedef struct {
+       PyObject_HEAD
+       PyObject *data;
+       PyObject *selectors;
+} compressobject;
+
+static PyTypeObject compress_type;
+
+static PyObject *
+compress_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       PyObject *seq1, *seq2;
+       PyObject *data=NULL, *selectors=NULL;
+       compressobject *lz;
+
+       if (type == &compress_type && !_PyArg_NoKeywords("compress()", kwds))
+               return NULL;
+
+       if (!PyArg_UnpackTuple(args, "compress", 2, 2, &seq1, &seq2))
+               return NULL;
+
+       data = PyObject_GetIter(seq1);
+       if (data == NULL)
+               goto fail;
+       selectors = PyObject_GetIter(seq2);
+       if (selectors == NULL)
+               goto fail;
+
+       /* create compressobject structure */
+       lz = (compressobject *)type->tp_alloc(type, 0);
+       if (lz == NULL)
+               goto fail;
+       lz->data = data;
+       lz->selectors = selectors;
+       return (PyObject *)lz;
+
+fail:
+       Py_XDECREF(data);
+       Py_XDECREF(selectors);
+       return NULL;
+}
+
+static void
+compress_dealloc(compressobject *lz)
+{
+       PyObject_GC_UnTrack(lz);
+       Py_XDECREF(lz->data);
+       Py_XDECREF(lz->selectors);
+       Py_TYPE(lz)->tp_free(lz);
+}
+
+static int
+compress_traverse(compressobject *lz, visitproc visit, void *arg)
+{
+       Py_VISIT(lz->data);
+       Py_VISIT(lz->selectors);
+       return 0;
+}
+
+static PyObject *
+compress_next(compressobject *lz)
+{
+       PyObject *data = lz->data, *selectors = lz->selectors;
+       PyObject *datum, *selector;
+       PyObject *(*datanext)(PyObject *) = *Py_TYPE(data)->tp_iternext;
+       PyObject *(*selectornext)(PyObject *) = *Py_TYPE(selectors)->tp_iternext;
+       int ok;
+
+       while (1) {
+               /* Steps:  get datum, get selector, evaluate selector.
+                  Order is important (to match the pure python version
+                  in terms of which input gets a chance to raise an
+                  exception first).
+               */
+
+               datum = datanext(data);
+               if (datum == NULL)
+                       return NULL;
+
+               selector = selectornext(selectors);
+               if (selector == NULL) {
+                       Py_DECREF(datum);
+                       return NULL;
+               }
+
+               ok = PyObject_IsTrue(selector);
+               Py_DECREF(selector);
+               if (ok == 1)
+                       return datum;
+               Py_DECREF(datum);
+               if (ok == -1)
+                       return NULL;
+       }
+}
+
+PyDoc_STRVAR(compress_doc,
+"compress(data sequence, selector sequence) --> iterator over selected data\n\
+\n\
+Return data elements corresponding to true selector elements.\n\
+Forms a shorter iterator from selected data elements using the\n\
+selectors to choose the data elements.");
+
+static PyTypeObject compress_type = {
+       PyVarObject_HEAD_INIT(NULL, 0)
+       "itertools.compress",           /* tp_name */
+       sizeof(compressobject),         /* tp_basicsize */
+       0,                                                      /* tp_itemsize */
+       /* methods */
+       (destructor)compress_dealloc,   /* tp_dealloc */
+       0,                                                              /* tp_print */
+       0,                                                              /* tp_getattr */
+       0,                                                              /* tp_setattr */
+       0,                                                              /* tp_compare */
+       0,                                                              /* tp_repr */
+       0,                                                              /* tp_as_number */
+       0,                                                              /* tp_as_sequence */
+       0,                                                              /* tp_as_mapping */
+       0,                                                              /* tp_hash */
+       0,                                                              /* tp_call */
+       0,                                                              /* tp_str */
+       PyObject_GenericGetAttr,                /* tp_getattro */
+       0,                                                              /* tp_setattro */
+       0,                                                              /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+               Py_TPFLAGS_BASETYPE,            /* tp_flags */
+       compress_doc,                                   /* tp_doc */
+       (traverseproc)compress_traverse,        /* tp_traverse */
+       0,                                                              /* tp_clear */
+       0,                                                              /* tp_richcompare */
+       0,                                                              /* tp_weaklistoffset */
+       PyObject_SelfIter,                              /* tp_iter */
+       (iternextfunc)compress_next,    /* tp_iternext */
+       0,                                                              /* tp_methods */
+       0,                                                              /* tp_members */
+       0,                                                              /* tp_getset */
+       0,                                                              /* tp_base */
+       0,                                                              /* tp_dict */
+       0,                                                              /* tp_descr_get */
+       0,                                                              /* tp_descr_set */
+       0,                                                              /* tp_dictoffset */
+       0,                                                              /* tp_init */
+       0,                                                              /* tp_alloc */
+       compress_new,                                   /* tp_new */
+       PyObject_GC_Del,                                /* tp_free */
+};
+
+
 /* filterfalse object ************************************************************/
 
 typedef struct {
@@ -3041,6 +3197,7 @@ PyInit_itertools(void)
                &islice_type,
                &starmap_type,
                &chain_type,
+               &compress_type,
                &filterfalse_type,
                &count_type,
                &ziplongest_type,