]> granicus.if.org Git - python/commitdiff
Improve the implementation of itertools.product()
authorRaymond Hettinger <python@rcn.com>
Sat, 23 Feb 2008 02:20:41 +0000 (02:20 +0000)
committerRaymond Hettinger <python@rcn.com>
Sat, 23 Feb 2008 02:20:41 +0000 (02:20 +0000)
* Fix-up issues pointed-out by Neal Norwitz.
* Add extensive comments.
* The lz->result variable is now a tuple instead of a list.
* Use fast macro getitem/setitem calls so most code is in-line.
* Re-use the result tuple if available (modify in-place instead of copy).

Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index e65bba7f92387290da27129c10fa877d260b863a..f5dd0691c803e0baf9274f65aa69351ea1e472da 100644 (file)
@@ -274,6 +274,9 @@ class TestBasicOps(unittest.TestCase):
             args = map(iter, args)
             self.assertEqual(len(list(product(*args))), n)
 
+        # Test implementation detail:  tuple re-use
+        self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
+        self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)
 
     def test_repeat(self):
         self.assertEqual(zip(xrange(3),repeat('a')),
index 89293098d508945eb22460b59d7ce311a44508a1..5a3b03fa6baff6879b0faa0757c8dccbfbef1698 100644 (file)
@@ -1796,7 +1796,7 @@ product_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
        lz = (productobject *)type->tp_alloc(type, 0);
        if (lz == NULL) {
                Py_DECREF(pools);
-               return NULL;
+               goto error;
        }
 
        lz->pools = pools;
@@ -1840,7 +1840,7 @@ product_next(productobject *lz)
 {
        PyObject *pool;
        PyObject *elem;
-       PyObject *tuple_result;
+       PyObject *oldelem;
        PyObject *pools = lz->pools;
        PyObject *result = lz->result;
        Py_ssize_t npools = PyTuple_GET_SIZE(pools);
@@ -1848,10 +1848,14 @@ product_next(productobject *lz)
 
        if (lz->stopped)
                return NULL;
+
        if (result == NULL) {
+                /* On the first pass, return an initial tuple filled with the 
+                   first element from each pool.  If any pool is empty, then 
+                   whole product is empty and we're already done */
                if (npools == 0)
                        goto empty;
-               result = PyList_New(npools);
+               result = PyTuple_New(npools);
                if (result == NULL)
                        goto empty;
                lz->result = result;
@@ -1861,34 +1865,61 @@ product_next(productobject *lz)
                                goto empty;
                        elem = PyTuple_GET_ITEM(pool, 0);
                        Py_INCREF(elem);
-                       PyList_SET_ITEM(result, i, elem);
+                       PyTuple_SET_ITEM(result, i, elem);
                }
        } else {
                Py_ssize_t *indices = lz->indices;
                Py_ssize_t *maxvec = lz->maxvec;
+
+               /* Copy the previous result tuple or re-use it if available */
+               if (Py_REFCNT(result) > 1) {
+                       PyObject *old_result = result;
+                       result = PyTuple_New(npools);
+                       if (result == NULL)
+                               goto empty;
+                       lz->result = result;
+                       for (i=0; i < npools; i++) {
+                               elem = PyTuple_GET_ITEM(old_result, i);
+                               Py_INCREF(elem);
+                               PyTuple_SET_ITEM(result, i, elem);
+                       }
+                       Py_DECREF(old_result);
+               }
+               /* Now, we've got the only copy so we can update it in-place */
+               assert (Py_REFCNT(result) == 1);
+
+                /* Update the pool indices right-to-left.  Only advance to the
+                   next pool when the previous one rolls-over */
                for (i=npools-1 ; i >= 0 ; i--) {
                        pool = PyTuple_GET_ITEM(pools, i);
                        indices[i]++;
                        if (indices[i] == maxvec[i]) {
+                               /* Roll-over and advance to next pool */
                                indices[i] = 0;
                                elem = PyTuple_GET_ITEM(pool, 0);
                                Py_INCREF(elem);
-                               PyList_SetItem(result, i, elem);
+                               oldelem = PyTuple_GET_ITEM(result, i);
+                               PyTuple_SET_ITEM(result, i, elem);
+                               Py_DECREF(oldelem);
                        } else {
+                               /* No rollover. Just increment and stop here. */
                                elem = PyTuple_GET_ITEM(pool, indices[i]);
                                Py_INCREF(elem);
-                               PyList_SetItem(result, i, elem);
+                               oldelem = PyTuple_GET_ITEM(result, i);
+                               PyTuple_SET_ITEM(result, i, elem);
+                               Py_DECREF(oldelem);
                                break;
                        }
                }
+
+               /* If i is negative, then the indices have all rolled-over
+                   and we're done. */
                if (i < 0)
-                       return NULL;
+                       goto empty;
        }
 
-       tuple_result = PySequence_Tuple(result);
-       if (tuple_result == NULL)
-               lz->stopped = 1;
-       return tuple_result;
+       Py_INCREF(result);
+       return result;
 
 empty:
        lz->stopped = 1;
@@ -1898,7 +1929,7 @@ empty:
 PyDoc_STRVAR(product_doc,
 "product(*iterables) --> product object\n\
 \n\
-Cartesian product of input interables.  Equivalent to nested for-loops.\n\n\
+Cartesian product of input iterables.  Equivalent to nested for-loops.\n\n\
 For example, product(A, B) returns the same as:  ((x,y) for x in A for y in B).\n\
 The leftmost iterators are in the outermost for-loop, so the output tuples\n\
 cycle in a manner similar to an odometer (with the rightmost element changing\n\