]> granicus.if.org Git - python/commitdiff
First draft for itertools.product(). Docs and other updates forthcoming.
authorRaymond Hettinger <python@rcn.com>
Fri, 22 Feb 2008 03:16:42 +0000 (03:16 +0000)
committerRaymond Hettinger <python@rcn.com>
Fri, 22 Feb 2008 03:16:42 +0000 (03:16 +0000)
Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index 9d1922823d30744bb49932491cc4a45a9ce30ee1..e65bba7f92387290da27129c10fa877d260b863a 100644 (file)
@@ -253,6 +253,28 @@ class TestBasicOps(unittest.TestCase):
         ids = map(id, list(izip_longest('abc', 'def')))
         self.assertEqual(len(dict.fromkeys(ids)), len(ids))
 
+    def test_product(self):
+        for args, result in [
+            ([], []),                       # zero iterables   ??? is this correct
+            (['ab'], [('a',), ('b',)]),     # one iterable
+            ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]),     # two iterables
+            ([range(0), range(2), range(3)], []),           # first iterable with zero length
+            ([range(2), range(0), range(3)], []),           # middle iterable with zero length
+            ([range(2), range(3), range(0)], []),           # last iterable with zero length
+            ]:
+            self.assertEqual(list(product(*args)), result)
+        self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
+        self.assertRaises(TypeError, product, range(6), None)
+        argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3),
+                    set('abcdefg'), range(11), tuple(range(13))]
+        for i in range(100):
+            args = [random.choice(argtypes) for j in range(random.randrange(5))]
+            n = reduce(operator.mul, map(len, args), 1) if args else 0
+            self.assertEqual(len(list(product(*args))), n)
+            args = map(iter, args)
+            self.assertEqual(len(list(product(*args))), n)
+
+
     def test_repeat(self):
         self.assertEqual(zip(xrange(3),repeat('a')),
                          [(0, 'a'), (1, 'a'), (2, 'a')])
@@ -623,6 +645,12 @@ class TestVariousIteratorArgs(unittest.TestCase):
             self.assertRaises(TypeError, list, chain(N(s)))
             self.assertRaises(ZeroDivisionError, list, chain(E(s)))
 
+    def test_product(self):
+        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
+            self.assertRaises(TypeError, product, X(s))
+            self.assertRaises(TypeError, product, N(s))
+            self.assertRaises(ZeroDivisionError, product, E(s))
+
     def test_cycle(self):
         for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
             for g in (G, I, Ig, S, L, R):
index 430313eaa63fd41e6114eaf904d8aaeb88224a32..89293098d508945eb22460b59d7ce311a44508a1 100644 (file)
@@ -1741,6 +1741,216 @@ static PyTypeObject chain_type = {
 };
 
 
+/* product object ************************************************************/
+
+typedef struct {
+       PyObject_HEAD
+       PyObject *pools;                /* tuple of pool tuples */
+       Py_ssize_t *maxvec;
+       Py_ssize_t *indices;
+       PyObject *result;
+       int stopped;
+} productobject;
+
+static PyTypeObject product_type;
+
+static PyObject *
+product_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+       productobject *lz;
+       Py_ssize_t npools;
+       PyObject *pools = NULL;
+       Py_ssize_t *maxvec = NULL;
+       Py_ssize_t *indices = NULL;
+       Py_ssize_t i;
+
+       if (type == &product_type && !_PyArg_NoKeywords("product()", kwds))
+               return NULL;
+
+       assert(PyTuple_Check(args));
+       npools = PyTuple_GET_SIZE(args);
+
+       maxvec = PyMem_Malloc(npools * sizeof(Py_ssize_t));
+       indices = PyMem_Malloc(npools * sizeof(Py_ssize_t));
+       if (maxvec == NULL || indices == NULL) {
+               PyErr_NoMemory();
+               goto error;
+       }
+
+       pools = PyTuple_New(npools);
+       if (pools == NULL)
+               goto error;
+
+       for (i=0; i < npools; ++i) {
+               PyObject *item = PyTuple_GET_ITEM(args, i);
+               PyObject *pool = PySequence_Tuple(item);
+               if (pool == NULL)
+                       goto error;
+
+               PyTuple_SET_ITEM(pools, i, pool);
+               maxvec[i] = PyTuple_GET_SIZE(pool);
+               indices[i] = 0;
+       }
+
+       /* create productobject structure */
+       lz = (productobject *)type->tp_alloc(type, 0);
+       if (lz == NULL) {
+               Py_DECREF(pools);
+               return NULL;
+       }
+
+       lz->pools = pools;
+       lz->maxvec = maxvec;
+       lz->indices = indices;
+       lz->result = NULL;
+       lz->stopped = 0;
+
+       return (PyObject *)lz;
+
+error:
+       if (maxvec != NULL)
+               PyMem_Free(maxvec);
+       if (indices != NULL)
+               PyMem_Free(indices);
+       Py_XDECREF(pools);
+       return NULL;
+}
+
+static void
+product_dealloc(productobject *lz)
+{
+       PyObject_GC_UnTrack(lz);
+       Py_XDECREF(lz->pools);
+       Py_XDECREF(lz->result);
+       PyMem_Free(lz->maxvec);
+       PyMem_Free(lz->indices);
+       Py_TYPE(lz)->tp_free(lz);
+}
+
+static int
+product_traverse(productobject *lz, visitproc visit, void *arg)
+{
+       Py_VISIT(lz->pools);
+       Py_VISIT(lz->result);
+       return 0;
+}
+
+static PyObject *
+product_next(productobject *lz)
+{
+       PyObject *pool;
+       PyObject *elem;
+       PyObject *tuple_result;
+       PyObject *pools = lz->pools;
+       PyObject *result = lz->result;
+       Py_ssize_t npools = PyTuple_GET_SIZE(pools);
+       Py_ssize_t i;
+
+       if (lz->stopped)
+               return NULL;
+       if (result == NULL) {
+               if (npools == 0)
+                       goto empty;
+               result = PyList_New(npools);
+               if (result == NULL)
+                       goto empty;
+               lz->result = result;
+               for (i=0; i < npools; i++) {
+                       pool = PyTuple_GET_ITEM(pools, i);
+                       if (PyTuple_GET_SIZE(pool) == 0)
+                               goto empty;
+                       elem = PyTuple_GET_ITEM(pool, 0);
+                       Py_INCREF(elem);
+                       PyList_SET_ITEM(result, i, elem);
+               }
+       } else {
+               Py_ssize_t *indices = lz->indices;
+               Py_ssize_t *maxvec = lz->maxvec;
+               for (i=npools-1 ; i >= 0 ; i--) {
+                       pool = PyTuple_GET_ITEM(pools, i);
+                       indices[i]++;
+                       if (indices[i] == maxvec[i]) {
+                               indices[i] = 0;
+                               elem = PyTuple_GET_ITEM(pool, 0);
+                               Py_INCREF(elem);
+                               PyList_SetItem(result, i, elem);
+                       } else {
+                               elem = PyTuple_GET_ITEM(pool, indices[i]);
+                               Py_INCREF(elem);
+                               PyList_SetItem(result, i, elem);
+                               break;
+                       }
+               }
+               if (i < 0)
+                       return NULL;
+       }
+
+       tuple_result = PySequence_Tuple(result);
+       if (tuple_result == NULL)
+               lz->stopped = 1;
+       return tuple_result;
+
+empty:
+       lz->stopped = 1;
+       return NULL;
+}
+
+PyDoc_STRVAR(product_doc,
+"product(*iterables) --> product object\n\
+\n\
+Cartesian product of input interables.  Equivalent to nested for-loops.\n\n\
+For example, product(A, B) returns the same as:  ((x,y) for x in A for y in B).\n\
+The leftmost iterators are in the outermost for-loop, so the output tuples\n\
+cycle in a manner similar to an odometer (with the rightmost element changing\n\
+on every iteration).\n\n\
+product('ab', range(3)) --> ('a',0) ('a',1) ('a',2) ('b',0) ('b',1) ('b',2)\n\
+product((0,1), (0,1), (0,1)) --> (0,0,0) (0,0,1) (0,1,0) (0,1,1) (1,0,0) ...");
+
+static PyTypeObject product_type = {
+       PyVarObject_HEAD_INIT(NULL, 0)
+       "itertools.product",            /* tp_name */
+       sizeof(productobject),  /* tp_basicsize */
+       0,                              /* tp_itemsize */
+       /* methods */
+       (destructor)product_dealloc,    /* tp_dealloc */
+       0,                              /* tp_print */
+       0,                              /* tp_getattr */
+       0,                              /* tp_setattr */
+       0,                              /* tp_compare */
+       0,                              /* tp_repr */
+       0,                              /* tp_as_number */
+       0,                              /* tp_as_sequence */
+       0,                              /* tp_as_mapping */
+       0,                              /* tp_hash */
+       0,                              /* tp_call */
+       0,                              /* tp_str */
+       PyObject_GenericGetAttr,        /* tp_getattro */
+       0,                              /* tp_setattro */
+       0,                              /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC |
+               Py_TPFLAGS_BASETYPE,    /* tp_flags */
+       product_doc,                    /* tp_doc */
+       (traverseproc)product_traverse, /* tp_traverse */
+       0,                              /* tp_clear */
+       0,                              /* tp_richcompare */
+       0,                              /* tp_weaklistoffset */
+       PyObject_SelfIter,              /* tp_iter */
+       (iternextfunc)product_next,     /* tp_iternext */
+       0,                              /* tp_methods */
+       0,                              /* tp_members */
+       0,                              /* tp_getset */
+       0,                              /* tp_base */
+       0,                              /* tp_dict */
+       0,                              /* tp_descr_get */
+       0,                              /* tp_descr_set */
+       0,                              /* tp_dictoffset */
+       0,                              /* tp_init */
+       0,                              /* tp_alloc */
+       product_new,                    /* tp_new */
+       PyObject_GC_Del,                /* tp_free */
+};
+
+
 /* ifilter object ************************************************************/
 
 typedef struct {
@@ -2796,7 +3006,8 @@ inititertools(void)
                &ifilterfalse_type,
                &count_type,
                &izip_type,
-               &iziplongest_type,                
+               &iziplongest_type,
+               &product_type,          
                &repeat_type,
                &groupby_type,
                NULL