]> granicus.if.org Git - python/commitdiff
Issue #2819: Add math.sum, a function that sums a sequence of floats
authorMark Dickinson <dickinsm@gmail.com>
Fri, 23 May 2008 01:35:30 +0000 (01:35 +0000)
committerMark Dickinson <dickinsm@gmail.com>
Fri, 23 May 2008 01:35:30 +0000 (01:35 +0000)
efficiently but with no intermediate loss of precision.  Based on
Raymond Hettinger's ASPN recipe.  Thanks Jean Brouwers for the patch.

Misc/NEWS
Modules/mathmodule.c

index b7b8685faf39f3922890ed69e18b35b241661cc5..9479125f706f3e257941dfd940862ae1f330552b 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -36,6 +36,9 @@ Core and Builtins
 Extension Modules
 -----------------
 
+- Issue #2819: add full-precision summation function to math module,
+  based on Hettinger's ASPN Python Cookbook recipe.
+
 - Issue #2592: delegate nb_index and the floor/truediv slots in
   weakref.proxy.
 
index c4ac69adad16cdef3236a6b62133dff44ef2929b..19d6f4340393af8cf15301823624faad80b3b7f2 100644 (file)
@@ -307,6 +307,228 @@ FUNC1(tan, tan, 0,
 FUNC1(tanh, tanh, 0,
       "tanh(x)\n\nReturn the hyperbolic tangent of x.")
 
+/* Precision summation function as msum() by Raymond Hettinger in
+   <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/393090>,
+   enhanced with the exact partials sum and roundoff from Mark
+   Dickinson's post at <http://bugs.python.org/file10357/msum4.py>.
+
+   See both of those for more details, proofs and other references.
+
+   Note 1: IEEE 754 floating point format and semantics are assumed, but not
+   explicitly maintained.  The following rules may not apply:
+
+   1. if the summands include a NaN, return a NaN,
+
+   2. if the summands include infinities of both signs, raise ValueError,
+
+   3. if the summands include infinities of only one sign, return infinity
+      with that sign,
+
+   4. otherwise (all summands are finite) if the result is infinite, raise
+      OverflowError.  The result can never be a NaN if all summands are
+      finite.
+
+   Note 2: the implementation below not include the intermediate overflow
+   handling from Mark Dickinson's msum().  Therefore, sum([1e+308, 1e-308,
+   1e+308]) returns result 1e+308, however sum([1e+308, 1e+308, 1e-308])
+   raises an OverflowError due to intermediate overflow of the first
+   partial sum.
+
+   Note 3: aggressively optimizing compilers may eliminate the roundoff
+   expressions critical for accurate summation.  For example, the compiler
+   may optimize the following expressions
+
+       hi = x + y;
+       lo = y - (hi - x);
+   to
+       hi = x + y;
+       lo = 0.0;
+
+   defeating the whole purpose.  Using volatile variables and/or explicit
+   assignment of critical subexpressions to a volatile variable should
+   remedy the problem
+
+       volatile double v;  // Deter compiler from algebraically optimizing
+                           // this critical, intermediate value away
+       hi = x + y;
+       v = hi - x;
+       lo = y - v;
+
+   by forcing the compiler to compute the value for v.  This may also help
+   when subexpression are not computed with the full double precision.
+
+   Note 4. the same summation functions may be in ./cmathmodule.c.  Make
+   sure to update both when making changes.
+*/
+
+#define NUM_PARTIALS  32  /* initial partials array size, on stack */
+
+/* Extend the partials array p[] by doubling its size.
+ */
+static int  /* non-zero on error */
+_sum_realloc(double **p_ptr, Py_ssize_t  n,
+             double  *ps,    Py_ssize_t *m_ptr)
+{
+       void *v = NULL;
+       Py_ssize_t m = *m_ptr;
+
+       m += m;  /* double */
+       if (n < m && m < (PY_SSIZE_T_MAX / sizeof(double))) {
+               double *p = *p_ptr;
+               if (p == ps) {
+                       v = PyMem_Malloc(sizeof(double) * m);
+                       if (v != NULL)
+                               memcpy(v, ps, sizeof(double) * n);
+               }
+               else
+                       v = PyMem_Realloc(p, sizeof(double) * m);
+       }
+       if (v == NULL) {  /* size overflow or no memory */
+               PyErr_SetString(PyExc_MemoryError, "math sum partials");
+               return 1;
+       }
+       *p_ptr = (double*) v;
+       *m_ptr = m;
+       return 0;
+}
+
+/* Full precision summation of a sequence of floats.
+
+   def msum(iterable):
+       partials = []  # sorted, non-overlapping partial sums
+       for x in iterable:
+           i = 0
+           for y in partials:
+               if abs(x) < abs(y):
+                   x, y = y, x
+               hi = x + y
+               lo = y - (hi - x)
+               if lo:
+                   partials[i] = lo
+                   i += 1
+               x = hi
+           partials[i:] = [x]
+       return sum_exact(partials)
+
+   Rounded x+y stored in hi with the roundoff stored in lo.  Together hi+lo
+   are exactly equal to x+y.  The inner loop applies hi/lo summation to each
+   partial so that the list of partial sums remains exact.
+
+   Sum_exact() adds the partial sums exactly and correctly rounds the final
+   result (using the round-half-to-even rule).  The items in partials remain
+   non-zero, non-special, non-overlapping and strictly increasing in
+   magnitude, but possibly not all having the same sign.
+
+   Depends on IEEE 754 arithmetic guarantees.
+ */
+static PyObject*
+math_sum(PyObject *self, PyObject *seq)
+{
+       PyObject *item, *iter, *sum = NULL;
+       Py_ssize_t i, j, n = 0, m = NUM_PARTIALS;
+       double x, y, hi, lo=0.0, ps[NUM_PARTIALS], *p = ps;
+
+       iter = PyObject_GetIter(seq);
+       if (iter == NULL)
+               return NULL;
+
+       PyFPE_START_PROTECT("sum", Py_DECREF(iter); return NULL)
+
+       for(;;) {  /* for x in iterable */
+               /* some invariants */
+               assert(0 <= n && n <= m);
+               assert((m == NUM_PARTIALS && p == ps) ||
+                      (m >  NUM_PARTIALS && p != NULL));
+
+               item = PyIter_Next(iter);
+               if (item == NULL) {
+                       if (PyErr_Occurred())
+                               goto _sum_error;
+                       else
+                               break;
+               }
+               x = PyFloat_AsDouble(item);
+               Py_DECREF(item);
+               if (PyErr_Occurred())
+                       goto _sum_error;
+
+               for (i = j = 0; j < n; j++) {  /* for y in partials */
+                       y = p[j];
+                       hi = x + y;
+                       lo = fabs(x) < fabs(y)
+                          ? x - (hi - y)   /* volatile */
+                          : y - (hi - x);  /* volatile */
+                       if (lo != 0.0)
+                               p[i++] = lo;
+                       x = hi;
+               }
+               /* ps[i:] = [x] */
+               n = i;
+               if (x != 0.0) {
+                       /* if non-finite, reset partials, effectively
+                          adding subsequent items without roundoff
+                          and yielding correct non-finite results,
+                          provided IEEE 754 rules are observed */
+                       if (! Py_IS_FINITE(x))
+                               n = 0;
+                       else if (n >= m && _sum_realloc(&p, n, ps, &m))
+                               goto _sum_error;
+                       p[n++] = x;
+               }
+       }
+       assert(n <= m);
+
+       if (n > 0) {
+               hi = p[--n];
+               if (Py_IS_FINITE(hi)) {
+                       /* sum_exact(ps, hi) from the top, stop
+                          as soon as the sum becomes inexact */
+                       while (n > 0) {
+                               x = p[--n];
+                               y = hi;
+                               hi = x + y;
+                               assert(fabs(x) < fabs(y));
+                               lo = x - (hi - y);  /* volatile */
+                               if (lo != 0.0)
+                                       break;
+                       }
+                       /* round correctly if necessary */
+                       if (n > 0 && ((lo < 0.0 && p[n-1] < 0.0) ||
+                                     (lo > 0.0 && p[n-1] > 0.0))) {
+                               y = lo * 2.0;
+                               x = hi + y;  /* volatile */
+                               if (y == (x - hi))
+                                       hi = x;
+                       }
+               }
+               else {  /* raise corresponding error */
+                       errno = Py_IS_NAN(hi) ? EDOM : ERANGE;
+                       if (is_error(hi))
+                               goto _sum_error;
+               }
+       }
+       else  /* default */
+               hi = 0.0;
+       sum = PyFloat_FromDouble(hi);
+
+_sum_error:
+       PyFPE_END_PROTECT(hi)
+
+       Py_DECREF(iter);
+       if (p != ps)
+               PyMem_Free(p);
+       return sum;
+}
+
+#undef NUM_PARTIALS
+
+PyDoc_STRVAR(math_sum_doc,
+"sum(sequence)\n\n\
+Return the full precision sum of a sequence of numbers.\n\
+When the sequence is empty, return zero.\n\n\
+For accurate results, IEEE 754 floating point format\n\
+and semantics and floating point radix 2 are required.");
+
 static PyObject *
 math_trunc(PyObject *self, PyObject *number)
 {
@@ -760,6 +982,7 @@ static PyMethodDef math_methods[] = {
        {"sin",         math_sin,       METH_O,         math_sin_doc},
        {"sinh",        math_sinh,      METH_O,         math_sinh_doc},
        {"sqrt",        math_sqrt,      METH_O,         math_sqrt_doc},
+       {"sum",         math_sum,       METH_O,         math_sum_doc},
        {"tan",         math_tan,       METH_O,         math_tan_doc},
        {"tanh",        math_tanh,      METH_O,         math_tanh_doc},
        {"trunc",       math_trunc,     METH_O,         math_trunc_doc},