]> granicus.if.org Git - python/commitdiff
Forward port r69001: itertools.combinations_with_replacement().
authorRaymond Hettinger <python@rcn.com>
Tue, 27 Jan 2009 04:20:44 +0000 (04:20 +0000)
committerRaymond Hettinger <python@rcn.com>
Tue, 27 Jan 2009 04:20:44 +0000 (04:20 +0000)
Doc/library/collections.rst
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS
Modules/itertoolsmodule.c

index 169293f6c370bcbfa5a976b3919fe1ca4bf849f3..0503ad29c97215bb3e81227a3fbaab29bb623f62 100644 (file)
@@ -279,8 +279,7 @@ counts less than one::
       Section 4.6.3, Exercise 19*\.
 
     * To enumerate all distinct multisets of a given size over a given set of
-      elements, see :func:`combinations_with_replacement` in the
-      :ref:`itertools-recipes` for itertools::
+      elements, see :func:`itertools.combinations_with_replacement`::
 
           map(Counter, combinations_with_replacement('ABC', 2)) --> AA AB AC BB BC CC
 
index d28127801f5fbac71e53523983785a7da96d6a8f..aba1e25f091678a445ddc33acf40fbe0f8c65abe 100644 (file)
@@ -133,6 +133,53 @@ loops that truncate the stream.
    The number of items returned is ``n! / r! / (n-r)!`` when ``0 <= r <= n``
    or zero when ``r > n``.
 
+.. function:: combinations_with_replacement(iterable, r)
+
+   Return *r* length subsequences of elements from the input *iterable*
+   allowing individual elements to be repeated more than once.
+
+   Combinations are emitted in lexicographic sort order.  So, if the
+   input *iterable* is sorted, the combination tuples will be produced
+   in sorted order.
+
+   Elements are treated as unique based on their position, not on their
+   value.  So if the input elements are unique, the generated combinations
+   will also be unique.
+
+   Equivalent to::
+
+        def combinations_with_replacement(iterable, r):
+            # combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC
+            pool = tuple(iterable)
+            n = len(pool)
+            if not n and r:
+                return
+            indices = [0] * r
+            yield tuple(pool[i] for i in indices)
+            while 1:
+                for i in reversed(range(r)):
+                    if indices[i] != n - 1:
+                        break
+                else:
+                    return
+                indices[i:] = [indices[i] + 1] * (r - i)
+                yield tuple(pool[i] for i in indices)
+
+   The code for :func:`combinations_with_replacement` can also be expressed as
+   a subsequence of :func:`product` after filtering entries where the elements
+   are not in sorted order (according to their position in the input pool)::
+
+        def combinations_with_replacement(iterable, r):
+            pool = tuple(iterable)
+            n = len(pool)
+            for indices in product(range(n), repeat=r):
+                if sorted(indices) == list(indices):
+                    yield tuple(pool[i] for i in indices)
+
+   The number of items returned is ``(n+r-1)! / r! / (n-1)!`` when ``n > 0``.
+
+   .. versionadded:: 2.7
+
 .. function:: compress(data, selectors)
 
    Make an iterator that filters elements from *data* returning only those that
@@ -608,22 +655,6 @@ which incur interpreter overhead.
        s = list(iterable)
        return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
 
-   def combinations_with_replacement(iterable, r):
-       "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
-       # number items returned:  (n+r-1)! / r! / (n-1)!
-       pool = tuple(iterable)
-       n = len(pool)
-       indices = [0] * r
-       yield tuple(pool[i] for i in indices)
-       while True:
-           for i in reversed(range(r)):
-               if indices[i] != n - 1:
-                   break
-           else:
-               return
-           indices[i:] = [indices[i] + 1] * (r - i)
-           yield tuple(pool[i] for i in indices)
-
     def unique_everseen(iterable, key=None):
         "List unique elements, preserving order. Remember all elements ever seen."
         # unique_everseen('AAAABBBCCDAABBB') --> A B C D
index 16789d8f9f3775ef244dd65db1f56bdba8352238..b1ba8c025caed6613e4924fb39cf4867531d920d 100644 (file)
@@ -131,6 +131,76 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
         self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
 
+    def test_combinations_with_replacement(self):
+        cwr = combinations_with_replacement
+        self.assertRaises(TypeError, cwr, 'abc')       # missing r argument
+        self.assertRaises(TypeError, cwr, 'abc', 2, 1) # too many arguments
+        self.assertRaises(TypeError, cwr, None)        # pool is not iterable
+        self.assertRaises(ValueError, cwr, 'abc', -2)  # r is negative
+        self.assertEqual(list(cwr('ABC', 2)),
+                         [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
+
+        def cwr1(iterable, r):
+            'Pure python version shown in the docs'
+            # number items returned:  (n+r-1)! / r! / (n-1)! when n>0
+            pool = tuple(iterable)
+            n = len(pool)
+            if not n and r:
+                return
+            indices = [0] * r
+            yield tuple(pool[i] for i in indices)
+            while 1:
+                for i in reversed(range(r)):
+                    if indices[i] != n - 1:
+                        break
+                else:
+                    return
+                indices[i:] = [indices[i] + 1] * (r - i)
+                yield tuple(pool[i] for i in indices)
+
+        def cwr2(iterable, r):
+            'Pure python version shown in the docs'
+            pool = tuple(iterable)
+            n = len(pool)
+            for indices in product(range(n), repeat=r):
+                if sorted(indices) == list(indices):
+                    yield tuple(pool[i] for i in indices)
+
+        def numcombs(n, r):
+            if not n:
+                return 0 if r else 1
+            return fact(n+r-1) / fact(r)/ fact(n-1)
+
+        for n in range(7):
+            values = [5*x-12 for x in range(n)]
+            for r in range(n+2):
+                result = list(cwr(values, r))
+
+                self.assertEqual(len(result), numcombs(n, r))           # right number of combs
+                self.assertEqual(len(result), len(set(result)))         # no repeats
+                self.assertEqual(result, sorted(result))                # lexicographic order
+
+                regular_combs = list(combinations(values, r))           # compare to combs without replacement
+                if n == 0 or r <= 1:
+                    self.assertEquals(result, regular_combs)            # cases that should be identical
+                else:
+                    self.assert_(set(result) >= set(regular_combs))     # rest should be supersets of regular combs
+
+                for c in result:
+                    self.assertEqual(len(c), r)                         # r-length combinations
+                    noruns = [k for k,v in groupby(c)]                  # combo without consecutive repeats
+                    self.assertEqual(len(noruns), len(set(noruns)))     # no repeats other than consecutive
+                    self.assertEqual(list(c), sorted(c))                # keep original ordering
+                    self.assert_(all(e in values for e in c))           # elements taken from input iterable
+                    self.assertEqual(noruns,
+                                     [e for e in values if e in c])     # comb is a subsequence of the input iterable
+                self.assertEqual(result, list(cwr1(values, r)))         # matches first pure python version
+                self.assertEqual(result, list(cwr2(values, r)))         # matches second pure python version
+
+        # Test implementation detail:  tuple re-use
+        self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
+        self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
+
     def test_permutations(self):
         self.assertRaises(TypeError, permutations)              # too few arguments
         self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
@@ -730,6 +800,10 @@ class TestExamples(unittest.TestCase):
         self.assertEqual(list(combinations(range(4), 3)),
                          [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
 
+    def test_combinations_with_replacement(self):
+        self.assertEqual(list(combinations_with_replacement('ABC', 2)),
+                         [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
+
     def test_compress(self):
         self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
 
@@ -813,6 +887,10 @@ class TestGC(unittest.TestCase):
         a = []
         self.makecycle(combinations([1,2,a,3], 3), a)
 
+    def test_combinations_with_replacement(self):
+        a = []
+        self.makecycle(combinations_with_replacement([1,2,a,3], 3), a)
+
     def test_compress(self):
         a = []
         self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), a)
@@ -1312,21 +1390,6 @@ Samuele
 ...     s = list(iterable)
 ...     return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
 
->>> def combinations_with_replacement(iterable, r):
-...     "combinations_with_replacement('ABC', 3) --> AA AB AC BB BC CC"
-...     pool = tuple(iterable)
-...     n = len(pool)
-...     indices = [0] * r
-...     yield tuple(pool[i] for i in indices)
-...     while 1:
-...         for i in reversed(range(r)):
-...             if indices[i] != n - 1:
-...                 break
-...         else:
-...             return
-...         indices[i:] = [indices[i] + 1] * (r - i)
-...         yield tuple(pool[i] for i in indices)
-
 >>> def unique_everseen(iterable, key=None):
 ...     "List unique elements, preserving order. Remember all elements ever seen."
 ...     # unique_everseen('AAAABBBCCDAABBB') --> A B C D
@@ -1407,29 +1470,6 @@ perform as purported.
 >>> list(powerset([1,2,3]))
 [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
 
->>> list(combinations_with_replacement('abc', 2))
-[('a', 'a'), ('a', 'b'), ('a', 'c'), ('b', 'b'), ('b', 'c'), ('c', 'c')]
-
->>> list(combinations_with_replacement('01', 3))
-[('0', '0', '0'), ('0', '0', '1'), ('0', '1', '1'), ('1', '1', '1')]
-
->>> def combinations_with_replacement2(iterable, r):
-...     'Alternate version that filters from product()'
-...     pool = tuple(iterable)
-...     n = len(pool)
-...     for indices in product(range(n), repeat=r):
-...         if sorted(indices) == list(indices):
-...             yield tuple(pool[i] for i in indices)
-
->>> list(combinations_with_replacement('abc', 2)) == list(combinations_with_replacement2('abc', 2))
-True
-
->>> list(combinations_with_replacement('01', 3)) == list(combinations_with_replacement2('01', 3))
-True
-
->>> list(combinations_with_replacement('2310', 6)) == list(combinations_with_replacement2('2310', 6))
-True
-
 >>> list(unique_everseen('AAAABBBCCDAABBB'))
 ['A', 'B', 'C', 'D']
 
index 0284e1d833fe4bd119a98e4d94ad39eabc441611..ee2e2628b83aa2e1a4fbe080905bcea6dc9eb969 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -153,7 +153,8 @@ Library
 
 - Issue #4863: distutils.mwerkscompiler has been removed.
 
-- Added a new function:  itertools.compress().
+- Added new itertools functions:  combinations_with_replacement()
+  and compress().
 
 - Fix and properly document the multiprocessing module's logging
   support, expose the internal levels and provide proper usage
index bee08de73b3ed6e8dcb98ed7427e2231cfe2e313..f9d2ee8a8e94455d543c3377ea6e110c8d2f580d 100644 (file)
@@ -1683,7 +1683,8 @@ product_dealloc(productobject *lz)
        PyObject_GC_UnTrack(lz);
        Py_XDECREF(lz->pools);
        Py_XDECREF(lz->result);
-       PyMem_Free(lz->indices);
+       if (lz->indices != NULL)
+               PyMem_Free(lz->indices);
        Py_TYPE(lz)->tp_free(lz);
 }
 
@@ -1911,7 +1912,8 @@ combinations_dealloc(combinationsobject *co)
        PyObject_GC_UnTrack(co);
        Py_XDECREF(co->pool);
        Py_XDECREF(co->result);
-       PyMem_Free(co->indices);
+       if (co->indices != NULL)
+               PyMem_Free(co->indices);
        Py_TYPE(co)->tp_free(co);
 }
 
@@ -2060,6 +2062,252 @@ static PyTypeObject combinations_type = {
 };
 
 
+/* combinations with replacement object *******************************************/
+
+/* Equivalent to:
+
+               def combinations_with_replacement(iterable, r):
+                       "combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC"
+                       # number items returned:  (n+r-1)! / r! / (n-1)!
+                       pool = tuple(iterable)
+                       n = len(pool)
+                       indices = [0] * r
+                       yield tuple(pool[i] for i in indices)   
+                       while 1:
+                               for i in reversed(range(r)):
+                                       if indices[i] != n - 1:
+                                               break
+                               else:
+                                       return
+                               indices[i:] = [indices[i] + 1] * (r - i)
+                               yield tuple(pool[i] for i in indices)
+
+               def combinations_with_replacement2(iterable, r):
+                       'Alternate version that filters from product()'
+                       pool = tuple(iterable)
+                       n = len(pool)
+                       for indices in product(range(n), repeat=r):
+                               if sorted(indices) == list(indices):
+                                       yield tuple(pool[i] for i in indices)
+*/
+typedef struct {
+       PyObject_HEAD
+       PyObject *pool;                 /* input converted to a tuple */
+       Py_ssize_t *indices;    /* one index per result element */
+       PyObject *result;       /* most recently returned result tuple */
+       Py_ssize_t r;                   /* size of result tuple */
+       int stopped;                    /* set to 1 when the cwr iterator is exhausted */
+} cwrobject;
+
+static PyTypeObject cwr_type;
+
+static PyObject *
+cwr_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       cwrobject *co;
+       Py_ssize_t n;
+       Py_ssize_t r;
+       PyObject *pool = NULL;
+       PyObject *iterable = NULL;
+       Py_ssize_t *indices = NULL;
+       Py_ssize_t i;
+       static char *kwargs[] = {"iterable", "r", NULL};
+       if (!PyArg_ParseTupleAndKeywords(args, kwds, "On:combinations_with_replacement", kwargs, 
+                                        &iterable, &r))
+               return NULL;
+
+       pool = PySequence_Tuple(iterable);
+       if (pool == NULL)
+               goto error;
+       n = PyTuple_GET_SIZE(pool);
+       if (r < 0) {
+               PyErr_SetString(PyExc_ValueError, "r must be non-negative");
+               goto error;
+       }
+
+       indices = PyMem_Malloc(r * sizeof(Py_ssize_t));
+       if (indices == NULL) {
+               PyErr_NoMemory();
+               goto error;
+       }
+
+       for (i=0 ; i<r ; i++)
+               indices[i] = 0;
+
+       /* create cwrobject structure */
+       co = (cwrobject *)type->tp_alloc(type, 0);
+       if (co == NULL)
+               goto error;
+
+       co->pool = pool;
+       co->indices = indices;
+       co->result = NULL;
+       co->r = r;
+       co->stopped = !n && r;
+
+       return (PyObject *)co;
+
+error:
+       if (indices != NULL)
+               PyMem_Free(indices);
+       Py_XDECREF(pool);
+       return NULL;
+}
+
+static void
+cwr_dealloc(cwrobject *co)
+{
+       PyObject_GC_UnTrack(co);
+       Py_XDECREF(co->pool);
+       Py_XDECREF(co->result);
+       if (co->indices != NULL)
+               PyMem_Free(co->indices);
+       Py_TYPE(co)->tp_free(co);
+}
+
+static int
+cwr_traverse(cwrobject *co, visitproc visit, void *arg)
+{
+       Py_VISIT(co->pool);
+       Py_VISIT(co->result);
+       return 0;
+}
+
+static PyObject *
+cwr_next(cwrobject *co)
+{
+       PyObject *elem;
+       PyObject *oldelem;
+       PyObject *pool = co->pool;
+       Py_ssize_t *indices = co->indices;
+       PyObject *result = co->result;
+       Py_ssize_t n = PyTuple_GET_SIZE(pool);
+       Py_ssize_t r = co->r;
+       Py_ssize_t i, j, index;
+
+       if (co->stopped)
+               return NULL;
+
+       if (result == NULL) {
+                /* On the first pass, initialize result tuple using the indices */
+               result = PyTuple_New(r);
+               if (result == NULL)
+                       goto empty;
+               co->result = result;
+               for (i=0; i<r ; i++) {
+                       index = indices[i];
+                       elem = PyTuple_GET_ITEM(pool, index);
+                       Py_INCREF(elem);
+                       PyTuple_SET_ITEM(result, i, elem);
+               }
+       } else {
+               /* Copy the previous result tuple or re-use it if available */
+               if (Py_REFCNT(result) > 1) {
+                       PyObject *old_result = result;
+                       result = PyTuple_New(r);
+                       if (result == NULL)
+                               goto empty;
+                       co->result = result;
+                       for (i=0; i<r ; i++) {
+                               elem = PyTuple_GET_ITEM(old_result, i);
+                               Py_INCREF(elem);
+                               PyTuple_SET_ITEM(result, i, elem);
+                       }
+                       Py_DECREF(old_result);
+               }
+               /* Now, we've got the only copy so we can update it in-place.
+                  CPython's empty tuple is a singleton and cached in PyTuple's freelist. */
+               assert(r == 0 || Py_REFCNT(result) == 1);
+
+        /* Scan indices right-to-left until finding one that is not
+         * at its maximum (n-1). */
+               for (i=r-1 ; i >= 0 && indices[i] == n-1; i--)
+                       ;
+
+               /* If i is negative, then the indices are all at
+           their maximum value and we're done. */
+               if (i < 0)
+                       goto empty;
+
+               /* Increment the current index which we know is not at its
+           maximum.  Then set all to the right to the same value. */
+               indices[i]++;
+               for (j=i+1 ; j<r ; j++)
+                       indices[j] = indices[j-1];
+
+               /* Update the result tuple for the new indices
+                  starting with i, the leftmost index that changed */
+               for ( ; i<r ; i++) {
+                       index = indices[i];
+                       elem = PyTuple_GET_ITEM(pool, index);
+                       Py_INCREF(elem);
+                       oldelem = PyTuple_GET_ITEM(result, i);
+                       PyTuple_SET_ITEM(result, i, elem);
+                       Py_DECREF(oldelem);
+               }
+       }
+
+       Py_INCREF(result);
+       return result;
+
+empty:
+       co->stopped = 1;
+       return NULL;
+}
+
+PyDoc_STRVAR(cwr_doc,
+"combinations_with_replacement(iterable[, r]) --> combinations_with_replacement object\n\
+\n\
+Return successive r-length combinations of elements in the iterable\n\
+allowing individual elements to have successive repeats.\n\
+combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC");
+
+static PyTypeObject cwr_type = {
+       PyVarObject_HEAD_INIT(NULL, 0)
+       "itertools.combinations_with_replacement",              /* tp_name */
+       sizeof(cwrobject),              /* tp_basicsize */
+       0,                                              /* tp_itemsize */
+       /* methods */
+       (destructor)cwr_dealloc,        /* tp_dealloc */
+       0,                                              /* tp_print */
+       0,                                              /* tp_getattr */
+       0,                                              /* tp_setattr */
+       0,                                              /* tp_compare */
+       0,                                              /* tp_repr */
+       0,                                              /* tp_as_number */
+       0,                                              /* tp_as_sequence */
+       0,                                              /* tp_as_mapping */
+       0,                                              /* tp_hash */
+       0,                                              /* tp_call */
+       0,                                              /* tp_str */
+       PyObject_GenericGetAttr,        /* tp_getattro */
+       0,                                              /* tp_setattro */
+       0,                                              /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+               Py_TPFLAGS_BASETYPE,    /* tp_flags */
+       cwr_doc,                                /* tp_doc */
+       (traverseproc)cwr_traverse,     /* tp_traverse */
+       0,                                              /* tp_clear */
+       0,                                              /* tp_richcompare */
+       0,                                              /* tp_weaklistoffset */
+       PyObject_SelfIter,              /* tp_iter */
+       (iternextfunc)cwr_next, /* tp_iternext */
+       0,                                              /* tp_methods */
+       0,                                              /* tp_members */
+       0,                                              /* tp_getset */
+       0,                                              /* tp_base */
+       0,                                              /* tp_dict */
+       0,                                              /* tp_descr_get */
+       0,                                              /* tp_descr_set */
+       0,                                              /* tp_dictoffset */
+       0,                                              /* tp_init */
+       0,                                              /* tp_alloc */
+       cwr_new,                                /* tp_new */
+       PyObject_GC_Del,                /* tp_free */
+};
+
+
 /* permutations object ************************************************************
   
 def permutations(iterable, r=None):
@@ -3191,6 +3439,7 @@ PyInit_itertools(void)
        char *name;
        PyTypeObject *typelist[] = {
                &combinations_type,
+               &cwr_type,
                &cycle_type,
                &dropwhile_type,
                &takewhile_type,