]> granicus.if.org Git - python/commitdiff
Issue #5816:
authorMark Dickinson <dickinsm@gmail.com>
Fri, 24 Apr 2009 12:46:53 +0000 (12:46 +0000)
committerMark Dickinson <dickinsm@gmail.com>
Fri, 24 Apr 2009 12:46:53 +0000 (12:46 +0000)
 - simplify parsing and printing of complex numbers
 - make complex(repr(z)) round-tripping work for complex
   numbers involving nans, infs, or negative zeros
 - don't accept some of the stranger complex strings
   that were previously allowed---e.g., complex('1..1j')

Lib/test/test_complex.py
Misc/NEWS
Objects/complexobject.c
Python/pystrtod.c

index 9a3310198cba04ab46233c017e3aca94ddaf5b4d..1ff710fb8747ead171cd17726fc4891a58222206 100644 (file)
@@ -9,7 +9,7 @@ warnings.filterwarnings(
 )
 
 from random import random
-from math import atan2
+from math import atan2, isnan, copysign
 
 INF = float("inf")
 NAN = float("nan")
@@ -44,6 +44,29 @@ class ComplexTest(unittest.TestCase):
         # check that relative difference < eps
         self.assert_(abs((x-y)/y) < eps)
 
+    def assertFloatsAreIdentical(self, x, y):
+        """assert that floats x and y are identical, in the sense that:
+        (1) both x and y are nans, or
+        (2) both x and y are infinities, with the same sign, or
+        (3) both x and y are zeros, with the same sign, or
+        (4) x and y are both finite and nonzero, and x == y
+
+        """
+        msg = 'floats {!r} and {!r} are not identical'
+
+        if isnan(x) or isnan(y):
+            if isnan(x) and isnan(y):
+                return
+        elif x == y:
+            if x != 0.0:
+                return
+            # both zero; check that signs match
+            elif copysign(1.0, x) == copysign(1.0, y):
+                return
+            else:
+                msg += ': zeros have different signs'
+        self.fail(msg.format(x, y))
+
     def assertClose(self, x, y, eps=1e-9):
         """Return true iff complexes x and y "are close\""""
         self.assertCloseAbs(x.real, y.real, eps)
@@ -220,6 +243,17 @@ class ComplexTest(unittest.TestCase):
         self.assertAlmostEqual(complex("+1"), +1)
         self.assertAlmostEqual(complex("(1+2j)"), 1+2j)
         self.assertAlmostEqual(complex("(1.3+2.2j)"), 1.3+2.2j)
+        self.assertAlmostEqual(complex("3.14+1J"), 3.14+1j)
+        self.assertAlmostEqual(complex(" ( +3.14-6J )"), 3.14-6j)
+        self.assertAlmostEqual(complex(" ( +3.14-J )"), 3.14-1j)
+        self.assertAlmostEqual(complex(" ( +3.14+j )"), 3.14+1j)
+        self.assertAlmostEqual(complex("J"), 1j)
+        self.assertAlmostEqual(complex("( j )"), 1j)
+        self.assertAlmostEqual(complex("+J"), 1j)
+        self.assertAlmostEqual(complex("( -j)"), -1j)
+        self.assertAlmostEqual(complex('1e-500'), 0.0 + 0.0j)
+        self.assertAlmostEqual(complex('-1e-500j'), 0.0 - 0.0j)
+        self.assertAlmostEqual(complex('-1e-500+1e-500j'), -0.0 + 0.0j)
 
         class complex2(complex): pass
         self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j)
@@ -247,7 +281,6 @@ class ComplexTest(unittest.TestCase):
         self.assertRaises(TypeError, complex, "1", "1")
         self.assertRaises(TypeError, complex, 1, "1")
 
-        self.assertEqual(complex("  3.14+J  "), 3.14+1j)
         if test_support.have_unicode:
             self.assertEqual(complex(unicode("  3.14+J  ")), 3.14+1j)
 
@@ -275,6 +308,14 @@ class ComplexTest(unittest.TestCase):
         if test_support.have_unicode:
             self.assertRaises(ValueError, complex, unicode("1"*500))
             self.assertRaises(ValueError, complex, unicode("x"))
+        self.assertRaises(ValueError, complex, "1j+2")
+        self.assertRaises(ValueError, complex, "1e1ej")
+        self.assertRaises(ValueError, complex, "1e++1ej")
+        self.assertRaises(ValueError, complex, ")1+2j(")
+        # the following three are accepted by Python 2.6
+        self.assertRaises(ValueError, complex, "1..1j")
+        self.assertRaises(ValueError, complex, "1.11.1j")
+        self.assertRaises(ValueError, complex, "1e1.1j")
 
         class EvilExc(Exception):
             pass
@@ -339,17 +380,17 @@ class ComplexTest(unittest.TestCase):
         self.assertEqual(-6j,complex(repr(-6j)))
         self.assertEqual(6j,complex(repr(6j)))
 
-        self.assertEqual(repr(complex(1., INF)), "(1+inf*j)")
-        self.assertEqual(repr(complex(1., -INF)), "(1-inf*j)")
+        self.assertEqual(repr(complex(1., INF)), "(1+infj)")
+        self.assertEqual(repr(complex(1., -INF)), "(1-infj)")
         self.assertEqual(repr(complex(INF, 1)), "(inf+1j)")
-        self.assertEqual(repr(complex(-INF, INF)), "(-inf+inf*j)")
+        self.assertEqual(repr(complex(-INF, INF)), "(-inf+infj)")
         self.assertEqual(repr(complex(NAN, 1)), "(nan+1j)")
-        self.assertEqual(repr(complex(1, NAN)), "(1+nan*j)")
-        self.assertEqual(repr(complex(NAN, NAN)), "(nan+nan*j)")
+        self.assertEqual(repr(complex(1, NAN)), "(1+nanj)")
+        self.assertEqual(repr(complex(NAN, NAN)), "(nan+nanj)")
 
-        self.assertEqual(repr(complex(0, INF)), "inf*j")
-        self.assertEqual(repr(complex(0, -INF)), "-inf*j")
-        self.assertEqual(repr(complex(0, NAN)), "nan*j")
+        self.assertEqual(repr(complex(0, INF)), "infj")
+        self.assertEqual(repr(complex(0, -INF)), "-infj")
+        self.assertEqual(repr(complex(0, NAN)), "nanj")
 
     def test_neg(self):
         self.assertEqual(-(1+6j), -1-6j)
@@ -388,6 +429,21 @@ class ComplexTest(unittest.TestCase):
             self.assertEquals(atan2(z1.imag, -1.), atan2(0., -1.))
             self.assertEquals(atan2(z2.imag, -1.), atan2(-0., -1.))
 
+    @unittest.skipUnless(float.__getformat__("double").startswith("IEEE"),
+                         "test requires IEEE 754 doubles")
+    def test_repr_roundtrip(self):
+        # complex(repr(z)) should recover z exactly, even for complex numbers
+        # involving an infinity, nan, or negative zero
+        vals = [0.0, 1e-200, 0.0123, 3.1415, 1e50, INF, NAN]
+        vals += [-v for v in vals]
+        for x in vals:
+            for y in vals:
+                z = complex(x, y)
+                roundtrip = complex(repr(z))
+                self.assertFloatsAreIdentical(z.real, roundtrip.real)
+                self.assertFloatsAreIdentical(z.imag, roundtrip.imag)
+
+
 def test_main():
     test_support.run_unittest(ComplexTest)
 
index 5515af9ff27797c92575a289dd02a6ce0a3cf766..082aa2f27fcb37e9c0f439e56cc5c1abcfe68831 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,9 @@ What's New in Python 2.7 alpha 1
 Core and Builtins
 -----------------
 
+- Issue #5816: complex(repr(z)) now recovers z exactly, even when
+  z involves nans, infs or negative zeros.
+
 - Implement PEP 378, Format Specifier for Thousands Separator, for
   floats, ints, and longs.
 
index 9943d0d5004af2abdc59dbd74b4f017064d90d5f..894f7966a4b9c3691cbe9dc33de4a7533e0e586c 100644 (file)
@@ -353,83 +353,95 @@ complex_dealloc(PyObject *op)
 }
 
 
-static void
-complex_to_buf(char *buf, int bufsz, PyComplexObject *v, int precision)
+static PyObject *
+complex_format(PyComplexObject *v, char format_code)
 {
-       char format[32];
-       if (v->cval.real == 0.) {
-               if (!Py_IS_FINITE(v->cval.imag)) {
-                       if (Py_IS_NAN(v->cval.imag))
-                               strncpy(buf, "nan*j", 6);
-                       else if (copysign(1, v->cval.imag) == 1)
-                               strncpy(buf, "inf*j", 6);
-                       else
-                               strncpy(buf, "-inf*j", 7);
-               }
-               else {
-                       PyOS_snprintf(format, sizeof(format), "%%.%ig", precision);
-                       PyOS_ascii_formatd(buf, bufsz - 1, format, v->cval.imag);
-                       strncat(buf, "j", 1);
+       PyObject *result = NULL;
+       Py_ssize_t len;
+
+       /* If these are non-NULL, they'll need to be freed. */
+       char *pre = NULL;
+       char *im = NULL;
+       char *buf = NULL;
+
+       /* These do not need to be freed. re is either an alias
+          for pre or a pointer to a constant.  lead and tail
+          are pointers to constants. */
+       char *re = NULL;
+       char *lead = "";
+       char *tail = "";
+
+       if (v->cval.real == 0. && copysign(1.0, v->cval.real)==1.0) {
+               re = "";
+               im = PyOS_double_to_string(v->cval.imag, format_code,
+                                          0, 0, NULL);
+               if (!im) {
+                       PyErr_NoMemory();
+                       goto done;
                }
        } else {
-               char re[64], im[64];
                /* Format imaginary part with sign, real part without */
-               if (!Py_IS_FINITE(v->cval.real)) {
-                       if (Py_IS_NAN(v->cval.real))
-                               strncpy(re, "nan", 4);
-                       /* else if (copysign(1, v->cval.real) == 1) */
-                       else if (v->cval.real > 0)
-                               strncpy(re, "inf", 4);
-                       else
-                               strncpy(re, "-inf", 5);
-               }
-               else {
-                       PyOS_snprintf(format, sizeof(format), "%%.%ig", precision);
-                       PyOS_ascii_formatd(re, sizeof(re), format, v->cval.real);
-               }
-               if (!Py_IS_FINITE(v->cval.imag)) {
-                       if (Py_IS_NAN(v->cval.imag))
-                               strncpy(im, "+nan*", 6);
-                       /* else if (copysign(1, v->cval.imag) == 1) */
-                       else if (v->cval.imag > 0)
-                               strncpy(im, "+inf*", 6);
-                       else
-                               strncpy(im, "-inf*", 6);
+               pre = PyOS_double_to_string(v->cval.real, format_code,
+                                           0, 0, NULL);
+               if (!pre) {
+                       PyErr_NoMemory();
+                       goto done;
                }
-               else {
-                       PyOS_snprintf(format, sizeof(format), "%%+.%ig", precision);
-                       PyOS_ascii_formatd(im, sizeof(im), format, v->cval.imag);
+               re = pre;
+
+               im = PyOS_double_to_string(v->cval.imag, format_code,
+                                          0, Py_DTSF_SIGN, NULL);
+               if (!im) {
+                       PyErr_NoMemory();
+                       goto done;
                }
-               PyOS_snprintf(buf, bufsz, "(%s%sj)", re, im);
+               lead = "(";
+               tail = ")";
+       }
+       /* Alloc the final buffer. Add one for the "j" in the format string,
+          and one for the trailing zero. */
+       len = strlen(lead) + strlen(re) + strlen(im) + strlen(tail) + 2;
+       buf = PyMem_Malloc(len);
+       if (!buf) {
+               PyErr_NoMemory();
+               goto done;
        }
+       PyOS_snprintf(buf, len, "%s%s%sj%s", lead, re, im, tail);
+       result = PyString_FromString(buf);
+  done:
+       PyMem_Free(im);
+       PyMem_Free(pre);
+       PyMem_Free(buf);
+
+       return result;
 }
 
 static int
 complex_print(PyComplexObject *v, FILE *fp, int flags)
 {
-       char buf[100];
-       complex_to_buf(buf, sizeof(buf), v,
-                      (flags & Py_PRINT_RAW) ? PREC_STR : PREC_REPR);
+       PyObject *formatv;
+       char *buf;
+       formatv = complex_format(v, (flags & Py_PRINT_RAW) ? 's' : 'r');
+       if (formatv == NULL)
+               return -1;
+       buf = PyString_AS_STRING(formatv);
        Py_BEGIN_ALLOW_THREADS
        fputs(buf, fp);
        Py_END_ALLOW_THREADS
+       Py_DECREF(formatv);
        return 0;
 }
 
 static PyObject *
 complex_repr(PyComplexObject *v)
 {
-       char buf[100];
-       complex_to_buf(buf, sizeof(buf), v, PREC_REPR);
-       return PyString_FromString(buf);
+       return complex_format(v, 'r');
 }
 
 static PyObject *
 complex_str(PyComplexObject *v)
 {
-       char buf[100];
-       complex_to_buf(buf, sizeof(buf), v, PREC_STR);
-       return PyString_FromString(buf);
+       return complex_format(v, 's');
 }
 
 static long
@@ -867,11 +879,7 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v)
        const char *s, *start;
        char *end;
        double x=0.0, y=0.0, z;
-       int got_re=0, got_im=0, got_bracket=0, done=0;
-       int digit_or_dot;
-       int sw_error=0;
-       int sign;
-       char buffer[256]; /* For errors */
+       int got_bracket=0;
 #ifdef Py_USING_UNICODE
        char s_buffer[256];
 #endif
@@ -903,16 +911,13 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v)
                return NULL;
        }
 
+       errno = 0;
+
        /* position on first nonblank */
        start = s;
        while (*s && isspace(Py_CHARMASK(*s)))
                s++;
-       if (s[0] == '\0') {
-               PyErr_SetString(PyExc_ValueError,
-                               "complex() arg is an empty string");
-               return NULL;
-       }
-       if (s[0] == '(') {
+       if (*s == '(') {
                /* Skip over possible bracket from repr(). */
                got_bracket = 1;
                s++;
@@ -920,120 +925,109 @@ complex_subtype_from_string(PyTypeObject *type, PyObject *v)
                        s++;
        }
 
-       z = -1.0;
-       sign = 1;
-       do {
+       /* a valid complex string usually takes one of the three forms:
 
-               switch (*s) {
+            <float>                  - real part only
+            <float>j                 - imaginary part only
+            <float><signed-float>j   - real and imaginary parts
 
-               case '\0':
-                       if (s-start != len) {
-                               PyErr_SetString(
-                                       PyExc_ValueError,
-                                       "complex() arg contains a null byte");
-                               return NULL;
-                       }
-                       if(!done) sw_error=1;
-                       break;
+          where <float> represents any numeric string that's accepted by the
+          float constructor (including 'nan', 'inf', 'infinity', etc.), and
+          <signed-float> is any string of the form <float> whose first
+          character is '+' or '-'.
+
+          For backwards compatibility, the extra forms
+
+            <float><sign>j
+            <sign>j
+            j
 
-               case ')':
-                       if (!got_bracket || !(got_re || got_im)) {
-                               sw_error=1;
-                               break;
+          are also accepted, though support for these forms may be removed from
+          a future version of Python.
+       */
+
+       /* first look for forms starting with <float> */
+       z = PyOS_ascii_strtod(s, &end);
+       if (end == s && errno == ENOMEM)
+               return PyErr_NoMemory();
+       if (errno == ERANGE && fabs(z) >= 1.0)
+               goto overflow;
+
+       if (end != s) {
+               /* all 4 forms starting with <float> land here */
+               s = end;
+               if (*s == '+' || *s == '-') {
+                       /* <float><signed-float>j | <float><sign>j */
+                       x = z;
+                       y = PyOS_ascii_strtod(s, &end);
+                       if (end == s && errno == ENOMEM)
+                               return PyErr_NoMemory();
+                       if (errno == ERANGE && fabs(z) >= 1.0)
+                               goto overflow;
+                       if (end != s)
+                               /* <float><signed-float>j */
+                               s = end;
+                       else {
+                               /* <float><sign>j */
+                               y = *s == '+' ? 1.0 : -1.0;
+                               s++;
                        }
-                       got_bracket=0;
-                       done=1;
+                       if (!(*s == 'j' || *s == 'J'))
+                               goto parse_error;
                        s++;
-                       while (*s && isspace(Py_CHARMASK(*s)))
-                               s++;
-                       if (*s) sw_error=1;
-                       break;
-
-               case '-':
-                       sign = -1;
-                               /* Fallthrough */
-               case '+':
-                       if (done)  sw_error=1;
+               }
+               else if (*s == 'j' || *s == 'J') {
+                       /* <float>j */
                        s++;
-                       if  (  *s=='\0'||*s=='+'||*s=='-'||*s==')'||
-                              isspace(Py_CHARMASK(*s))  )  sw_error=1;
-                       break;
-
-               case 'J':
-               case 'j':
-                       if (got_im || done) {
-                               sw_error = 1;
-                               break;
-                       }
-                       if  (z<0.0) {
-                               y=sign;
-                       }
-                       else{
-                               y=sign*z;
-                       }
-                       got_im=1;
+                       y = z;
+               }
+               else
+                       /* <float> */
+                       x = z;
+       }
+       else {
+               /* not starting with <float>; must be <sign>j or j */
+               if (*s == '+' || *s == '-') {
+                       /* <sign>j */
+                       y = *s == '+' ? 1.0 : -1.0;
                        s++;
-                       if  (*s!='+' && *s!='-' )
-                               done=1;
-                       break;
-
-               default:
-                       if (isspace(Py_CHARMASK(*s))) {
-                               while (*s && isspace(Py_CHARMASK(*s)))
-                                       s++;
-                               if (*s && *s != ')')
-                                       sw_error=1;
-                               else
-                                       done = 1;
-                               break;
-                       }
-                       digit_or_dot =
-                               (*s=='.' || isdigit(Py_CHARMASK(*s)));
-                       if  (done||!digit_or_dot) {
-                               sw_error=1;
-                               break;
-                       }
-                       errno = 0;
-                       PyFPE_START_PROTECT("strtod", return 0)
-                               z = PyOS_ascii_strtod(s, &end) ;
-                       PyFPE_END_PROTECT(z)
-                               if (errno != 0) {
-                                       PyOS_snprintf(buffer, sizeof(buffer),
-                                         "float() out of range: %.150s", s);
-                                       PyErr_SetString(
-                                               PyExc_ValueError,
-                                               buffer);
-                                       return NULL;
-                               }
-                       s=end;
-                       if  (*s=='J' || *s=='j') {
-
-                               break;
-                       }
-                       if  (got_re) {
-                               sw_error=1;
-                               break;
-                       }
+               }
+               else
+                       /* j */
+                       y = 1.0;
+               if (!(*s == 'j' || *s == 'J'))
+                       goto parse_error;
+               s++;
+       }
 
-                               /* accept a real part */
-                       x=sign*z;
-                       got_re=1;
-                       if  (got_im)  done=1;
-                       z = -1.0;
-                       sign = 1;
-                       break;
+       /* trailing whitespace and closing bracket */
+       while (*s && isspace(Py_CHARMASK(*s)))
+               s++;
+       if (got_bracket) {
+               /* if there was an opening parenthesis, then the corresponding
+                  closing parenthesis should be right here */
+               if (*s != ')')
+                       goto parse_error;
+               s++;
+               while (*s && isspace(Py_CHARMASK(*s)))
+                       s++;
+       }
 
-               }  /* end of switch  */
+       /* we should now be at the end of the string */
+       if (s-start != len)
+               goto parse_error;
 
-       } while (s - start < len && !sw_error);
+       return complex_subtype_from_doubles(type, x, y);
 
-       if (sw_error || got_bracket) {
-               PyErr_SetString(PyExc_ValueError,
-                               "complex() arg is a malformed string");
-               return NULL;
-       }
+  parse_error:
+       PyErr_SetString(PyExc_ValueError,
+                       "complex() arg is a malformed string");
+       return NULL;
+
+  overflow:
+       PyErr_SetString(PyExc_OverflowError,
+                       "complex() arg overflow");
 
-       return complex_subtype_from_doubles(type, x, y);
 }
 
 static PyObject *
index ce2e3825a779e78e53402dc785a33fb777bf9116..703ae64563f0c50d2cfdfdaeb427d2cf62c7813c 100644 (file)
@@ -544,8 +544,9 @@ PyAPI_FUNC(char *) PyOS_double_to_string(double val,
        }
        p = result;
 
-       /* Never add sign for nan/inf, even if asked. */
-       if (flags & Py_DTSF_SIGN && buf[0] != '-' && t == Py_DTST_FINITE)
+       /* Add sign when requested.  It's convenient (esp. when formatting
+        complex numbers) to include a sign even for inf and nan. */
+       if (flags & Py_DTSF_SIGN && buf[0] != '-')
                *p++ = '+';
 
        strcpy(p, buf);