]> granicus.if.org Git - python/commitdiff
Simplify the signature for itertools.accumulate() to match numpy. Handle one item...
authorRaymond Hettinger <python@rcn.com>
Fri, 3 Dec 2010 02:09:34 +0000 (02:09 +0000)
committerRaymond Hettinger <python@rcn.com>
Fri, 3 Dec 2010 02:09:34 +0000 (02:09 +0000)
Doc/library/itertools.rst
Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index 56eb452b0b0f603e59bfbc1aa69a6bea81c86a37..befc6beb055bc415d740063efd34a8258730abd3 100644 (file)
@@ -90,13 +90,15 @@ loops that truncate the stream.
     parameter (which defaults to :const:`0`). Elements may be any addable type
     including :class:`Decimal` or :class:`Fraction`.  Equivalent to::
 
-        def accumulate(iterable, start=0):
+        def accumulate(iterable):
             'Return running totals'
-                # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
-                total = start
-                for element in iterable:
-                    total += element
-                    yield total
+            # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+            it = iter(iterable)
+            total = next(it)
+            yield total
+            for element in it:
+                total += element
+                yield total
 
     .. versionadded:: 3.2
 
index 8a67cff60ce9db32e7c531b17227bdb6d5bfd0b8..b8f6eecbbeb83168e7328138664c78c41a171dcb 100644 (file)
@@ -59,18 +59,18 @@ class TestBasicOps(unittest.TestCase):
 
     def test_accumulate(self):
         self.assertEqual(list(accumulate(range(10))),               # one positional arg
-                         [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
-        self.assertEqual(list(accumulate(range(10), 100)),          # two positional args
-            [100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
-        self.assertEqual(list(accumulate(iterable=range(10), start=100)),   # kw args
-            [100, 101, 103, 106, 110, 115, 121, 128, 136, 145])
+                          [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
+        self.assertEqual(list(accumulate(iterable=range(10))),      # kw arg
+                          [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])
         for typ in int, complex, Decimal, Fraction:                 # multiple types
-            self.assertEqual(list(accumulate(range(10), typ(0))),
+            self.assertEqual(
+                list(accumulate(map(typ, range(10)))),
                 list(map(typ, [0, 1, 3, 6, 10, 15, 21, 28, 36, 45])))
         self.assertEqual(list(accumulate([])), [])                  # empty iterable
-        self.assertRaises(TypeError, accumulate, range(10), 0, 5)   # too many args
+        self.assertEqual(list(accumulate([7])), [7])                # iterable of length one
+        self.assertRaises(TypeError, accumulate, range(10), 5)      # too many args
         self.assertRaises(TypeError, accumulate)                    # too few args
-        self.assertRaises(TypeError, accumulate, range(10), x=7)    # unexpected kwd args
+        self.assertRaises(TypeError, accumulate, x=range(10))       # unexpected kwd arg
         self.assertRaises(TypeError, list, accumulate([1, []]))     # args that don't add
 
     def test_chain(self):
index 04bfffc5b0db51095248ec9093ca005f5ea3a89e..b202e5262bab463e6abe240c0974b6548265b3d4 100644 (file)
@@ -2597,41 +2597,27 @@ static PyTypeObject accumulate_type;
 static PyObject *
 accumulate_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
-    static char *kwargs[] = {"iterable", "start", NULL};
+    static char *kwargs[] = {"iterable", NULL};
     PyObject *iterable;
     PyObject *it;
-    PyObject *start = NULL;
     accumulateobject *lz;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O:accumulate",
-                                                    kwargs, &iterable, &start))
-       return NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "O:accumulate", kwargs, &iterable))
+        return NULL;
 
     /* Get iterator. */
     it = PyObject_GetIter(iterable);
     if (it == NULL)
         return NULL;
 
-    /* Default start value */
-    if (start == NULL) {
-           start = PyLong_FromLong(0);
-           if (start == NULL) {
-               Py_DECREF(it);
-               return NULL;
-           }
-    } else {
-        Py_INCREF(start);
-    }
-
     /* create accumulateobject structure */
     lz = (accumulateobject *)type->tp_alloc(type, 0);
     if (lz == NULL) {
         Py_DECREF(it);
-           Py_DECREF(start);
-            return NULL;
+        return NULL;
     }
 
-    lz->total = start;
+    lz->total = NULL;
     lz->it = it;
     return (PyObject *)lz;
 }
@@ -2661,11 +2647,17 @@ accumulate_next(accumulateobject *lz)
     val = PyIter_Next(lz->it);
     if (val == NULL)
         return NULL;
-    
+    if (lz->total == NULL) {
+        Py_INCREF(val);
+        lz->total = val;
+        return lz->total;
+    }
+   
     newtotal = PyNumber_Add(lz->total, val);
-       Py_DECREF(val);
+    Py_DECREF(val);
     if (newtotal == NULL)
-           return NULL;
+        return NULL;
 
     oldtotal = lz->total;
     lz->total = newtotal;
@@ -2676,7 +2668,7 @@ accumulate_next(accumulateobject *lz)
 }
 
 PyDoc_STRVAR(accumulate_doc,
-"accumulate(iterable, start=0) --> accumulate object\n\
+"accumulate(iterable) --> accumulate object\n\
 \n\
 Return series of accumulated sums.");