]> granicus.if.org Git - python/commitdiff
bpo-34659: Adds initial kwarg to itertools.accumulate() (GH-9345)
authorLisa Roach <lisaroach14@gmail.com>
Mon, 24 Sep 2018 00:34:59 +0000 (17:34 -0700)
committerGitHub <noreply@github.com>
Mon, 24 Sep 2018 00:34:59 +0000 (17:34 -0700)
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst [new file with mode: 0644]
Modules/clinic/itertoolsmodule.c.h
Modules/itertoolsmodule.c

index 959424ff914390e8151089da5dfaa24a8aa9433e..b1513cd8b1b909d3febc5a59cbb76e29fc9ccfa1 100644 (file)
@@ -86,29 +86,38 @@ The following module functions all construct and return iterators. Some provide
 streams of infinite length, so they should only be accessed by functions or
 loops that truncate the stream.
 
-.. function:: accumulate(iterable[, func])
+.. function:: accumulate(iterable[, func, *, initial=None])
 
     Make an iterator that returns accumulated sums, or accumulated
     results of other binary functions (specified via the optional
-    *func* argument).  If *func* is supplied, it should be a function
+    *func* argument).
+
+    If *func* is supplied, it should be a function
     of two arguments. Elements of the input *iterable* may be any type
     that can be accepted as arguments to *func*. (For example, with
     the default operation of addition, elements may be any addable
     type including :class:`~decimal.Decimal` or
-    :class:`~fractions.Fraction`.) If the input iterable is empty, the
-    output iterable will also be empty.
+    :class:`~fractions.Fraction`.)
+
+    Usually, the number of elements output matches the input iterable.
+    However, if the keyword argument *initial* is provided, the
+    accumulation leads off with the *initial* value so that the output
+    has one more element than the input iterable.
 
     Roughly equivalent to::
 
-        def accumulate(iterable, func=operator.add):
+        def accumulate(iterable, func=operator.add, *, initial=None):
             'Return running totals'
             # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+            # accumulate([1,2,3,4,5], initial=100) --> 100 101 103 106 110 115
             # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
             it = iter(iterable)
-            try:
-                total = next(it)
-            except StopIteration:
-                return
+            total = initial
+            if initial is None:
+                try:
+                    total = next(it)
+                except StopIteration:
+                    return
             yield total
             for element in it:
                 total = func(total, element)
@@ -152,6 +161,9 @@ loops that truncate the stream.
     .. versionchanged:: 3.3
        Added the optional *func* parameter.
 
+    .. versionchanged:: 3.8
+       Added the optional *initial* parameter.
+
 .. function:: chain(*iterables)
 
    Make an iterator that returns elements from the first iterable until it is
index cbbb4c4f71d3b8d2e8fad564078ee14b0c467291..ea060a98a5eef497e4cc129544c24eadb38842b3 100644 (file)
@@ -147,6 +147,12 @@ class TestBasicOps(unittest.TestCase):
             list(accumulate(s, chr))                                # unary-operation
         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             self.pickletest(proto, accumulate(range(10)))           # test pickling
+            self.pickletest(proto, accumulate(range(10), initial=7))
+        self.assertEqual(list(accumulate([10, 5, 1], initial=None)), [10, 15, 16])
+        self.assertEqual(list(accumulate([10, 5, 1], initial=100)), [100, 110, 115, 116])
+        self.assertEqual(list(accumulate([], initial=100)), [100])
+        with self.assertRaises(TypeError):
+            list(accumulate([10, 20], 100))
 
     def test_chain(self):
 
diff --git a/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst b/Misc/NEWS.d/next/Library/2018-09-16-17-04-16.bpo-34659.CWemzH.rst
new file mode 100644 (file)
index 0000000..3b7925a
--- /dev/null
@@ -0,0 +1 @@
+Add an optional *initial* argument to itertools.accumulate().
index 94df96c0b7e44744cad1ea328cef7601af889df3..476adc1f5c5d7ad28a2577111647b18f5ea5ad1f 100644 (file)
@@ -382,29 +382,30 @@ exit:
 }
 
 PyDoc_STRVAR(itertools_accumulate__doc__,
-"accumulate(iterable, func=None)\n"
+"accumulate(iterable, func=None, *, initial=None)\n"
 "--\n"
 "\n"
 "Return series of accumulated sums (or other binary function results).");
 
 static PyObject *
 itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
-                          PyObject *binop);
+                          PyObject *binop, PyObject *initial);
 
 static PyObject *
 itertools_accumulate(PyTypeObject *type, PyObject *args, PyObject *kwargs)
 {
     PyObject *return_value = NULL;
-    static const char * const _keywords[] = {"iterable", "func", NULL};
-    static _PyArg_Parser _parser = {"O|O:accumulate", _keywords, 0};
+    static const char * const _keywords[] = {"iterable", "func", "initial", NULL};
+    static _PyArg_Parser _parser = {"O|O$O:accumulate", _keywords, 0};
     PyObject *iterable;
     PyObject *binop = Py_None;
+    PyObject *initial = Py_None;
 
     if (!_PyArg_ParseTupleAndKeywordsFast(args, kwargs, &_parser,
-        &iterable, &binop)) {
+        &iterable, &binop, &initial)) {
         goto exit;
     }
-    return_value = itertools_accumulate_impl(type, iterable, binop);
+    return_value = itertools_accumulate_impl(type, iterable, binop, initial);
 
 exit:
     return return_value;
@@ -509,4 +510,4 @@ itertools_count(PyTypeObject *type, PyObject *args, PyObject *kwargs)
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=d9eb9601bd3296ef input=a9049054013a1b77]*/
+/*[clinic end generated code: output=c8c47b766deeffc3 input=a9049054013a1b77]*/
index ec8f0ae14206cbbaf8f607f3bc3dcd1962aabe85..89c0280c9d35c5b1ff484783b3d729fb4d16c8ea 100644 (file)
@@ -3475,6 +3475,7 @@ typedef struct {
     PyObject *total;
     PyObject *it;
     PyObject *binop;
+    PyObject *initial;
 } accumulateobject;
 
 static PyTypeObject accumulate_type;
@@ -3484,18 +3485,19 @@ static PyTypeObject accumulate_type;
 itertools.accumulate.__new__
     iterable: object
     func as binop: object = None
+    *
+    initial: object = None
 Return series of accumulated sums (or other binary function results).
 [clinic start generated code]*/
 
 static PyObject *
 itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
-                          PyObject *binop)
-/*[clinic end generated code: output=514d0fb30ba14d55 input=6d9d16aaa1d3cbfc]*/
+                          PyObject *binop, PyObject *initial)
+/*[clinic end generated code: output=66da2650627128f8 input=c4ce20ac59bf7ffd]*/
 {
     PyObject *it;
     accumulateobject *lz;
 
-
     /* Get iterator. */
     it = PyObject_GetIter(iterable);
     if (it == NULL)
@@ -3514,6 +3516,8 @@ itertools_accumulate_impl(PyTypeObject *type, PyObject *iterable,
     }
     lz->total = NULL;
     lz->it = it;
+    Py_XINCREF(initial);
+    lz->initial = initial;
     return (PyObject *)lz;
 }
 
@@ -3524,6 +3528,7 @@ accumulate_dealloc(accumulateobject *lz)
     Py_XDECREF(lz->binop);
     Py_XDECREF(lz->total);
     Py_XDECREF(lz->it);
+    Py_XDECREF(lz->initial);
     Py_TYPE(lz)->tp_free(lz);
 }
 
@@ -3533,6 +3538,7 @@ accumulate_traverse(accumulateobject *lz, visitproc visit, void *arg)
     Py_VISIT(lz->binop);
     Py_VISIT(lz->it);
     Py_VISIT(lz->total);
+    Py_VISIT(lz->initial);
     return 0;
 }
 
@@ -3541,6 +3547,13 @@ accumulate_next(accumulateobject *lz)
 {
     PyObject *val, *newtotal;
 
+    if (lz->initial != Py_None) {
+        lz->total = lz->initial;
+        Py_INCREF(Py_None);
+        lz->initial = Py_None;
+        Py_INCREF(lz->total);
+        return lz->total;
+    }
     val = (*Py_TYPE(lz->it)->tp_iternext)(lz->it);
     if (val == NULL)
         return NULL;
@@ -3567,6 +3580,19 @@ accumulate_next(accumulateobject *lz)
 static PyObject *
 accumulate_reduce(accumulateobject *lz, PyObject *Py_UNUSED(ignored))
 {
+    if (lz->initial != Py_None) {
+        PyObject *it;
+
+        assert(lz->total == NULL);
+        if (PyType_Ready(&chain_type) < 0)
+            return NULL;
+        it = PyObject_CallFunction((PyObject *)&chain_type, "(O)O",
+                                   lz->initial, lz->it);
+        if (it == NULL)
+            return NULL;
+        return Py_BuildValue("O(NO)O", Py_TYPE(lz),
+                            it, lz->binop?lz->binop:Py_None, Py_None);
+    }
     if (lz->total == Py_None) {
         PyObject *it;