Fix special-value handling for math.sum.
authorMark Dickinson <dickinsm@gmail.com>
Wed, 30 Jul 2008 12:01:41 +0000 (12:01 +0000)
committerMark Dickinson <dickinsm@gmail.com>
Wed, 30 Jul 2008 12:01:41 +0000 (12:01 +0000)
Also minor cleanups to the code: fix tabbing, remove
trailing whitespace, and reformat to fit into 80
columns.

Modules/mathmodule.c

index 5ff2f31a6a3920adf8d2595bbf07e6d5c8eb0757..e481d305e6b6abaec882fba41917ede9f77e1d30 100644 (file)
@@ -414,6 +414,7 @@ math_sum(PyObject *self, PyObject *seq)
        PyObject *item, *iter, *sum = NULL;
        Py_ssize_t i, j, n = 0, m = NUM_PARTIALS;
        double x, y, t, ps[NUM_PARTIALS], *p = ps;
+       double xsave, special_sum = 0.0, inf_sum = 0.0;
        volatile double hi, yr, lo;
 
        iter = PyObject_GetIter(seq);
@@ -438,10 +439,11 @@ math_sum(PyObject *self, PyObject *seq)
                if (PyErr_Occurred())
                        goto _sum_error;
 
+               xsave = x;
                for (i = j = 0; j < n; j++) {       /* for y in partials */
                        y = p[j];
                        if (fabs(x) < fabs(y)) {
-                                       t = x; x = y; y = t;
+                               t = x; x = y; y = t;
                        }
                        hi = x + y;
                        yr = hi - x;
@@ -450,54 +452,68 @@ math_sum(PyObject *self, PyObject *seq)
                                p[i++] = lo;
                        x = hi;
                }
-               
-               n = i;                              /* ps[i:] = [x] */                   
+
+               n = i;                              /* ps[i:] = [x] */
                if (x != 0.0) {
-                       /* If non-finite, reset partials, effectively
-                          adding subsequent items without roundoff
-                          and yielding correct non-finite results,
-                          provided IEEE 754 rules are observed */
-                       if (! Py_IS_FINITE(x))
+                       if (! Py_IS_FINITE(x)) {
+                               /* a nonfinite x could arise either as
+                                  a result of intermediate overflow, or
+                                  as a result of a nan or inf in the
+                                  summands */
+                               if (Py_IS_FINITE(xsave)) {
+                                       PyErr_SetString(PyExc_OverflowError,
+                                             "intermediate overflow in sum");
+                                       goto _sum_error;
+                               }
+                               if (Py_IS_INFINITY(xsave))
+                                       inf_sum += xsave;
+                               special_sum += xsave;
+                               /* reset partials */
                                n = 0;
+                       }
                        else if (n >= m && _sum_realloc(&p, n, ps, &m))
                                goto _sum_error;
-                       p[n++] = x;
+                       else
+                               p[n++] = x;
                }
        }
 
+       if (special_sum != 0.0) {
+               if (Py_IS_NAN(inf_sum))
+                       PyErr_SetString(PyExc_ValueError,
+                                       "-inf + inf in sum");
+               else
+                       sum = PyFloat_FromDouble(special_sum);
+               goto _sum_error;
+       }
+
        hi = 0.0;
        if (n > 0) {
                hi = p[--n];
-               if (Py_IS_FINITE(hi)) {
-                       /* sum_exact(ps, hi) from the top, stop when the sum becomes inexact. */
-                       while (n > 0) {
-                               x = hi;
-                               y = p[--n];
-                               assert(fabs(y) < fabs(x));
-                               hi = x + y;
-                               yr = hi - x;
-                               lo = y - yr;
-                               if (lo != 0.0)
-                                       break;
-                       }
-                       /* Make half-even rounding work across multiple partials.  Needed 
-                          so that sum([1e-16, 1, 1e16]) will round-up the last digit to 
-                          two instead of down to zero (the 1e-16 makes the 1 slightly 
-                          closer to two).  With a potential 1 ULP rounding error fixed-up,
-                          math.sum() can guarantee commutativity. */
-                       if (n > 0 && ((lo < 0.0 && p[n-1] < 0.0) ||
-                                     (lo > 0.0 && p[n-1] > 0.0))) {
-                               y = lo * 2.0;
-                               x = hi + y;
-                               yr = x - hi;
-                               if (y == yr)
-                                       hi = x;
-                       }
+               /* sum_exact(ps, hi) from the top, stop when the sum becomes
+                  inexact. */
+               while (n > 0) {
+                       x = hi;
+                       y = p[--n];
+                       assert(fabs(y) < fabs(x));
+                       hi = x + y;
+                       yr = hi - x;
+                       lo = y - yr;
+                       if (lo != 0.0)
+                               break;
                }
-               else {  /* raise exception corresponding to a special value */
-                       errno = Py_IS_NAN(hi) ? EDOM : ERANGE;
-                       if (is_error(hi))
-                               goto _sum_error;
+               /* Make half-even rounding work across multiple partials.
+                  Needed so that sum([1e-16, 1, 1e16]) will round-up the last
+                  digit to two instead of down to zero (the 1e-16 makes the 1
+                  slightly closer to two).  With a potential 1 ULP rounding
+                  error fixed-up, math.sum() can guarantee commutativity. */
+               if (n > 0 && ((lo < 0.0 && p[n-1] < 0.0) ||
+                             (lo > 0.0 && p[n-1] > 0.0))) {
+                       y = lo * 2.0;
+                       x = hi + y;
+                       yr = x - hi;
+                       if (y == yr)
+                               hi = x;
                }
        }
        sum = PyFloat_FromDouble(hi);