]> granicus.if.org Git - python/commitdiff
C implementation of itertools.permutations().
authorRaymond Hettinger <python@rcn.com>
Wed, 5 Mar 2008 20:59:58 +0000 (20:59 +0000)
committerRaymond Hettinger <python@rcn.com>
Wed, 5 Mar 2008 20:59:58 +0000 (20:59 +0000)
Lib/test/test_itertools.py
Misc/NEWS
Modules/itertoolsmodule.c

index 4197989888ff314adfc981915b5fa8aa396b48cb..069274745180f894a7ff79554ac969d798582024 100644 (file)
@@ -47,15 +47,6 @@ def fact(n):
     'Factorial'
     return prod(range(1, n+1))
 
-def permutations(iterable, r=None):
-    # XXX use this until real permutations code is added
-    pool = tuple(iterable)
-    n = len(pool)
-    r = n if r is None else r
-    for indices in product(range(n), repeat=r):
-        if len(set(indices)) == r:
-            yield tuple(pool[i] for i in indices)
-
 class TestBasicOps(unittest.TestCase):
     def test_chain(self):
         self.assertEqual(list(chain('abc', 'def')), list('abcdef'))
@@ -117,6 +108,8 @@ class TestBasicOps(unittest.TestCase):
                     self.assertEqual(len(set(c)), r)                    # no duplicate elements
                     self.assertEqual(list(c), sorted(c))                # keep original ordering
                     self.assert_(all(e in values for e in c))           # elements taken from input iterable
+                    self.assertEqual(list(c),
+                                     [e for e in values if e in c])      # comb is a subsequence of the input iterable
                 self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
                 self.assertEqual(result, list(combinations2(values, r))) # matches first pure python version
 
@@ -127,9 +120,10 @@ class TestBasicOps(unittest.TestCase):
     def test_permutations(self):
         self.assertRaises(TypeError, permutations)              # too few arguments
         self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
-##        self.assertRaises(TypeError, permutations, None)        # pool is not iterable
-##        self.assertRaises(ValueError, permutations, 'abc', -2)  # r is negative
-##        self.assertRaises(ValueError, permutations, 'abc', 32)  # r is too big
+        self.assertRaises(TypeError, permutations, None)        # pool is not iterable
+        self.assertRaises(ValueError, permutations, 'abc', -2)  # r is negative
+        self.assertRaises(ValueError, permutations, 'abc', 32)  # r is too big
+        self.assertRaises(TypeError, permutations, 'abc', 's')  # r is not an int or None
         self.assertEqual(list(permutations(range(3), 2)),
                                            [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])
 
@@ -182,7 +176,7 @@ class TestBasicOps(unittest.TestCase):
                     self.assertEqual(result, list(permutations(values)))       # test default r
 
         # Test implementation detail:  tuple re-use
-##        self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
+        self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
         self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
 
     def test_count(self):
@@ -407,12 +401,23 @@ class TestBasicOps(unittest.TestCase):
                                  list(product(*args, **dict(repeat=r))))
         self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
         self.assertRaises(TypeError, product, range(6), None)
+
+        def product2(*args, **kwds):
+            'Pure python version used in docs'
+            pools = map(tuple, args) * kwds.get('repeat', 1)
+            result = [[]]
+            for pool in pools:
+                result = [x+[y] for x in result for y in pool]
+            for prod in result:
+                yield tuple(prod)
+
         argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3),
                     set('abcdefg'), range(11), tuple(range(13))]
         for i in range(100):
             args = [random.choice(argtypes) for j in range(random.randrange(5))]
             expected_len = prod(map(len, args))
             self.assertEqual(len(list(product(*args))), expected_len)
+            self.assertEqual(list(product(*args)), list(product2(*args)))
             args = map(iter, args)
             self.assertEqual(len(list(product(*args))), expected_len)
 
index 75d3cc104c0125ff3a32c1747ba6b014d92f8124..ff974734cceb6a3defad7be40ae6b3cc7e12cd42 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -699,7 +699,7 @@ Library
 - Added itertools.product() which forms the Cartesian product of
   the input iterables.
 
-- Added itertools.combinations().
+- Added itertools.combinations() and itertools.permutations().
 
 - Patch #1541463: optimize performance of cgi.FieldStorage operations.
 
index d8e14d0dcee58f152e4b638a86714dc9c882eae6..fe06ef48e38f2b6e5a1f6ff084c68fccddd36804 100644 (file)
@@ -2238,6 +2238,279 @@ static PyTypeObject combinations_type = {
 };
 
 
+/* permutations object ************************************************************
+  
+def permutations(iterable, r=None):
+    'permutations(range(3), 2) --> (0,1) (0,2) (1,0) (1,2) (2,0) (2,1)'
+    pool = tuple(iterable)
+    n = len(pool)
+    r = n if r is None else r
+    indices = range(n)
+    cycles = range(n-r+1, n+1)[::-1]
+    yield tuple(pool[i] for i in indices[:r])
+    while n:
+        for i in reversed(range(r)):
+            cycles[i] -= 1
+            if cycles[i] == 0:
+                indices[i:] = indices[i+1:] + indices[i:i+1]
+                cycles[i] = n - i
+            else:
+                j = cycles[i]
+                indices[i], indices[-j] = indices[-j], indices[i]
+                yield tuple(pool[i] for i in indices[:r])
+                break
+        else:
+            return
+*/
+
+typedef struct {
+       PyObject_HEAD
+       PyObject *pool;                 /* input converted to a tuple */
+       Py_ssize_t *indices;            /* one index per element in the pool */
+       Py_ssize_t *cycles;             /* one rollover counter per element in the result */
+       PyObject *result;               /* most recently returned result tuple */
+       Py_ssize_t r;                   /* size of result tuple */
+       int stopped;                    /* set to 1 when the permutations iterator is exhausted */
+} permutationsobject;
+
+static PyTypeObject permutations_type;
+
+static PyObject *
+permutations_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       permutationsobject *po;
+       Py_ssize_t n;
+       Py_ssize_t r;
+       PyObject *robj = Py_None;
+       PyObject *pool = NULL;
+       PyObject *iterable = NULL;
+       Py_ssize_t *indices = NULL;
+       Py_ssize_t *cycles = NULL;
+       Py_ssize_t i;
+       static char *kwargs[] = {"iterable", "r", NULL};
+       if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:permutations", kwargs, 
+                                        &iterable, &robj))
+               return NULL;
+
+       pool = PySequence_Tuple(iterable);
+       if (pool == NULL)
+               goto error;
+       n = PyTuple_GET_SIZE(pool);
+
+       r = n;
+       if (robj != Py_None) {
+               r = PyInt_AsSsize_t(robj);
+               if (r == -1 && PyErr_Occurred())
+                       goto error;
+       }
+       if (r < 0) {
+               PyErr_SetString(PyExc_ValueError, "r must be non-negative");
+               goto error;
+       }
+       if (r > n) {
+               PyErr_SetString(PyExc_ValueError, "r cannot be bigger than the iterable");
+               goto error;
+       }
+
+       indices = PyMem_Malloc(n * sizeof(Py_ssize_t));
+       cycles = PyMem_Malloc(r * sizeof(Py_ssize_t));
+       if (indices == NULL || cycles == NULL) {
+               PyErr_NoMemory();
+               goto error;
+       }
+
+       for (i=0 ; i<n ; i++)
+               indices[i] = i;
+       for (i=0 ; i<r ; i++)
+               cycles[i] = n - i;
+
+       /* create permutationsobject structure */
+       po = (permutationsobject *)type->tp_alloc(type, 0);
+       if (po == NULL)
+               goto error;
+
+       po->pool = pool;
+       po->indices = indices;
+       po->cycles = cycles;
+       po->result = NULL;
+       po->r = r;
+       po->stopped = 0;
+
+       return (PyObject *)po;
+
+error:
+       if (indices != NULL)
+               PyMem_Free(indices);
+       if (cycles != NULL)
+               PyMem_Free(cycles);
+       Py_XDECREF(pool);
+       return NULL;
+}
+
+static void
+permutations_dealloc(permutationsobject *po)
+{
+       PyObject_GC_UnTrack(po);
+       Py_XDECREF(po->pool);
+       Py_XDECREF(po->result);
+       PyMem_Free(po->indices);
+       PyMem_Free(po->cycles);
+       Py_TYPE(po)->tp_free(po);
+}
+
+static int
+permutations_traverse(permutationsobject *po, visitproc visit, void *arg)
+{
+       if (po->pool != NULL)
+               Py_VISIT(po->pool);
+       if (po->result != NULL)
+               Py_VISIT(po->result);
+       return 0;
+}
+
+static PyObject *
+permutations_next(permutationsobject *po)
+{
+       PyObject *elem;
+       PyObject *oldelem;
+       PyObject *pool = po->pool;
+       Py_ssize_t *indices = po->indices;
+       Py_ssize_t *cycles = po->cycles;
+       PyObject *result = po->result;
+       Py_ssize_t n = PyTuple_GET_SIZE(pool);
+       Py_ssize_t r = po->r;
+       Py_ssize_t i, j, k, index;
+
+       if (po->stopped)
+               return NULL;
+
+       if (result == NULL) {
+                /* On the first pass, initialize result tuple using the indices */
+               result = PyTuple_New(r);
+               if (result == NULL)
+                       goto empty;
+               po->result = result;
+               for (i=0; i<r ; i++) {
+                       index = indices[i];
+                       elem = PyTuple_GET_ITEM(pool, index);
+                       Py_INCREF(elem);
+                       PyTuple_SET_ITEM(result, i, elem);
+               }
+       } else {
+               if (n == 0)
+                       goto empty;
+
+               /* Copy the previous result tuple or re-use it if available */
+               if (Py_REFCNT(result) > 1) {
+                       PyObject *old_result = result;
+                       result = PyTuple_New(r);
+                       if (result == NULL)
+                               goto empty;
+                       po->result = result;
+                       for (i=0; i<r ; i++) {
+                               elem = PyTuple_GET_ITEM(old_result, i);
+                               Py_INCREF(elem);
+                               PyTuple_SET_ITEM(result, i, elem);
+                       }
+                       Py_DECREF(old_result);
+               }
+               /* Now, we've got the only copy so we can update it in-place */
+               assert(r == 0 || Py_REFCNT(result) == 1);
+
+                /* Decrement rightmost cycle, moving leftward upon zero rollover */
+               for (i=r-1 ; i>=0 ; i--) {
+                       cycles[i] -= 1;
+                       if (cycles[i] == 0) {
+                               /* rotatation: indices[i:] = indices[i+1:] + indices[i:i+1] */
+                               index = indices[i];
+                               for (j=i ; j<n-1 ; j++)
+                                       indices[j] = indices[j+1];
+                               indices[n-1] = index;
+                               cycles[i] = n - i;
+                       } else {
+                               j = cycles[i];
+                               index = indices[i];
+                               indices[i] = indices[n-j];
+                               indices[n-j] = index;
+
+                               for (k=i; k<r ; k++) {
+                                       /* start with i, the leftmost element that changed */
+                                       /* yield tuple(pool[k] for k in indices[:r]) */
+                                       index = indices[k];
+                                       elem = PyTuple_GET_ITEM(pool, index);
+                                       Py_INCREF(elem);
+                                       oldelem = PyTuple_GET_ITEM(result, k);
+                                       PyTuple_SET_ITEM(result, k, elem);
+                                       Py_DECREF(oldelem);
+                               }
+                               break;
+                       }
+               }
+               /* If i is negative, then the cycles have all
+                   rolled-over and we're done. */
+               if (i < 0)
+                       goto empty;
+       }
+       Py_INCREF(result);
+       return result;
+
+empty:
+       po->stopped = 1;
+       return NULL;
+}
+
+PyDoc_STRVAR(permutations_doc,
+"permutations(iterables[, r]) --> permutations object\n\
+\n\
+Return successive r-length permutations of elements in the iterable.\n\n\
+permutations(range(4), 3) --> (0,1,2), (0,1,3), (0,2,3), (1,2,3)");
+
+static PyTypeObject permutations_type = {
+       PyVarObject_HEAD_INIT(NULL, 0)
+       "itertools.permutations",               /* tp_name */
+       sizeof(permutationsobject),     /* tp_basicsize */
+       0,                              /* tp_itemsize */
+       /* methods */
+       (destructor)permutations_dealloc,       /* tp_dealloc */
+       0,                              /* tp_print */
+       0,                              /* tp_getattr */
+       0,                              /* tp_setattr */
+       0,                              /* tp_compare */
+       0,                              /* tp_repr */
+       0,                              /* tp_as_number */
+       0,                              /* tp_as_sequence */
+       0,                              /* tp_as_mapping */
+       0,                              /* tp_hash */
+       0,                              /* tp_call */
+       0,                              /* tp_str */
+       PyObject_GenericGetAttr,        /* tp_getattro */
+       0,                              /* tp_setattro */
+       0,                              /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+               Py_TPFLAGS_BASETYPE,    /* tp_flags */
+       permutations_doc,                       /* tp_doc */
+       (traverseproc)permutations_traverse,    /* tp_traverse */
+       0,                              /* tp_clear */
+       0,                              /* tp_richcompare */
+       0,                              /* tp_weaklistoffset */
+       PyObject_SelfIter,              /* tp_iter */
+       (iternextfunc)permutations_next,        /* tp_iternext */
+       0,                              /* tp_methods */
+       0,                              /* tp_members */
+       0,                              /* tp_getset */
+       0,                              /* tp_base */
+       0,                              /* tp_dict */
+       0,                              /* tp_descr_get */
+       0,                              /* tp_descr_set */
+       0,                              /* tp_dictoffset */
+       0,                              /* tp_init */
+       0,                              /* tp_alloc */
+       permutations_new,                       /* tp_new */
+       PyObject_GC_Del,                /* tp_free */
+};
+
+
 /* ifilter object ************************************************************/
 
 typedef struct {
@@ -3295,6 +3568,7 @@ inititertools(void)
                &count_type,
                &izip_type,
                &iziplongest_type,
+               &permutations_type,
                &product_type,         
                &repeat_type,
                &groupby_type,