]> granicus.if.org Git - python/commitdiff
Handle the repeat keyword argument for itertools.product().
authorRaymond Hettinger <python@rcn.com>
Fri, 29 Feb 2008 02:21:48 +0000 (02:21 +0000)
committerRaymond Hettinger <python@rcn.com>
Fri, 29 Feb 2008 02:21:48 +0000 (02:21 +0000)
Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index 500afef01e5b31d6466c1cf1f20e05f47d7c638f..087570c93f1a46222fdf5b7b1cbffa32f6ba98c1 100644 (file)
@@ -296,6 +296,9 @@ class TestBasicOps(unittest.TestCase):
             ([range(2), range(3), range(0)], []),           # last iterable with zero length
             ]:
             self.assertEqual(list(product(*args)), result)
+            for r in range(4):
+                self.assertEqual(list(product(*(args*r))),
+                                 list(product(*args, **dict(repeat=r))))
         self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
         self.assertRaises(TypeError, product, range(6), None)
         argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3),
index f29077a3af1a2704b767cb0e5033da77d19ff1e4..e3d8bd8739f6027fee26b8b6bf4002bd008956ad 100644 (file)
@@ -1782,17 +1782,32 @@ static PyObject *
 product_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
        productobject *lz;
-       Py_ssize_t npools;
+       Py_ssize_t nargs, npools, repeat=1;
        PyObject *pools = NULL;
        Py_ssize_t *maxvec = NULL;
        Py_ssize_t *indices = NULL;
        Py_ssize_t i;
 
-       if (type == &product_type && !_PyArg_NoKeywords("product()", kwds))
-               return NULL;
+       if (kwds != NULL) {
+               char *kwlist[] = {"repeat", 0};
+               PyObject *tmpargs = PyTuple_New(0);
+               if (tmpargs == NULL)
+                       return NULL;
+               if (!PyArg_ParseTupleAndKeywords(tmpargs, kwds, "|n:product", kwlist, &repeat)) {
+                       Py_DECREF(tmpargs);
+                       return NULL;
+               }
+               Py_DECREF(tmpargs);
+               if (repeat < 0) {
+                       PyErr_SetString(PyExc_ValueError, 
+                                       "repeat argument cannot be negative");
+                       return NULL;
+               }
+       }
 
        assert(PyTuple_Check(args));
-       npools = PyTuple_GET_SIZE(args);
+       nargs = (repeat == 0) ? 0 : PyTuple_GET_SIZE(args);
+       npools = nargs * repeat;
 
        maxvec = PyMem_Malloc(npools * sizeof(Py_ssize_t));
        indices = PyMem_Malloc(npools * sizeof(Py_ssize_t));
@@ -1805,7 +1820,7 @@ product_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
        if (pools == NULL)
                goto error;
 
-       for (i=0; i < npools; ++i) {
+       for (i=0; i < nargs ; ++i) {
                PyObject *item = PyTuple_GET_ITEM(args, i);
                PyObject *pool = PySequence_Tuple(item);
                if (pool == NULL)
@@ -1815,6 +1830,13 @@ product_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
                maxvec[i] = PyTuple_GET_SIZE(pool);
                indices[i] = 0;
        }
+       for ( ; i < npools; ++i) {
+               PyObject *pool = PyTuple_GET_ITEM(pools, i - nargs);
+               Py_INCREF(pool);
+               PyTuple_SET_ITEM(pools, i, pool);
+               maxvec[i] = maxvec[i - nargs];
+               indices[i] = 0;
+       }
 
        /* create productobject structure */
        lz = (productobject *)type->tp_alloc(type, 0);