]> granicus.if.org Git - python/commitdiff
Finish the work on __round__ and __trunc__.
authorGuido van Rossum <guido@python.org>
Thu, 23 Aug 2007 22:07:24 +0000 (22:07 +0000)
committerGuido van Rossum <guido@python.org>
Thu, 23 Aug 2007 22:07:24 +0000 (22:07 +0000)
With Alex Martelli and Keir Mierle.

Lib/test/test_builtin.py
Objects/floatobject.c
Objects/longobject.c
Python/bltinmodule.c

index 37ea8ba88520aeee189b656abb9e8c7a6195f75e..f77cf78707d3f7f7e705cc9f3673f22589c64db5 100644 (file)
@@ -1440,6 +1440,7 @@ class BuiltinTest(unittest.TestCase):
 
     def test_round(self):
         self.assertEqual(round(0.0), 0.0)
+        self.assertEqual(type(round(0.0)), int)
         self.assertEqual(round(1.0), 1.0)
         self.assertEqual(round(10.0), 10.0)
         self.assertEqual(round(1000000000.0), 1000000000.0)
@@ -1468,6 +1469,25 @@ class BuiltinTest(unittest.TestCase):
         self.assertEqual(round(-999999999.9), -1000000000.0)
 
         self.assertEqual(round(-8.0, -1), -10.0)
+        self.assertEqual(type(round(-8.0, -1)), float)
+
+        self.assertEqual(type(round(-8.0, 0)), float)
+        self.assertEqual(type(round(-8.0, 1)), float)
+
+        # Check even / odd rounding behaviour
+        self.assertEqual(round(5.5), 6)
+        self.assertEqual(round(6.5), 6)
+        self.assertEqual(round(-5.5), -6)
+        self.assertEqual(round(-6.5), -6)
+
+        # Check behavior on ints
+        self.assertEqual(round(0), 0)
+        self.assertEqual(round(8), 8)
+        self.assertEqual(round(-8), -8)
+        self.assertEqual(type(round(0)), int)
+        self.assertEqual(type(round(-8, -1)), float)
+        self.assertEqual(type(round(-8, 0)), float)
+        self.assertEqual(type(round(-8, 1)), float)
 
         # test new kwargs
         self.assertEqual(round(number=-8.0, ndigits=-1), -10.0)
@@ -1487,6 +1507,11 @@ class BuiltinTest(unittest.TestCase):
         self.assertRaises(TypeError, round, 1, 2, 3)
         self.assertRaises(TypeError, round, TestNoRound())
 
+        t = TestNoRound()
+        t.__round__ = lambda *args: args
+        self.assertRaises(TypeError, round, t)
+        self.assertRaises(TypeError, round, t, 0)
+
     def test_setattr(self):
         setattr(sys, 'spam', 1)
         self.assertEqual(sys.spam, 1)
@@ -1529,6 +1554,18 @@ class BuiltinTest(unittest.TestCase):
         self.assertRaises(ValueError, sum, BadSeq())
 
     def test_trunc(self):
+
+        self.assertEqual(trunc(1), 1)
+        self.assertEqual(trunc(-1), -1)
+        self.assertEqual(type(trunc(1)), int)
+        self.assertEqual(type(trunc(1.5)), int)
+        self.assertEqual(trunc(1.5), 1)
+        self.assertEqual(trunc(-1.5), -1)
+        self.assertEqual(trunc(1.999999), 1)
+        self.assertEqual(trunc(-1.999999), -1)
+        self.assertEqual(trunc(-0.999999), -0)
+        self.assertEqual(trunc(-100.999), -100)
+
         class TestTrunc:
             def __trunc__(self):
                 return 23
@@ -1542,6 +1579,11 @@ class BuiltinTest(unittest.TestCase):
         self.assertRaises(TypeError, trunc, 1, 2)
         self.assertRaises(TypeError, trunc, TestNoTrunc())
 
+        t = TestNoTrunc()
+        t.__trunc__ = lambda *args: args
+        self.assertRaises(TypeError, trunc, t)
+        self.assertRaises(TypeError, trunc, t, 0)
+
     def test_tuple(self):
         self.assertEqual(tuple(()), ())
         t0_3 = (0, 1, 2, 3)
index 908258cb3a69fd47a9f92b1aa5ed7ae42e02c8d5..09efa12c65d2436f15dfc9038dad11074ca3a980 100644 (file)
@@ -743,14 +743,7 @@ float_bool(PyFloatObject *v)
 }
 
 static PyObject *
-float_long(PyObject *v)
-{
-       double x = PyFloat_AsDouble(v);
-       return PyLong_FromDouble(x);
-}
-
-static PyObject *
-float_int(PyObject *v)
+float_trunc(PyObject *v)
 {
        double x = PyFloat_AsDouble(v);
        double wholepart;       /* integral portion of x, rounded toward 0 */
@@ -775,6 +768,55 @@ float_int(PyObject *v)
        return PyLong_FromDouble(wholepart);
 }
 
+static PyObject *
+float_round(PyObject *v, PyObject *args)
+{
+#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
+       double x;
+       double f;
+       double flr, cil;
+       double rounded;
+       int i;
+       int ndigits = UNDEF_NDIGITS;
+
+       if (!PyArg_ParseTuple(args, "|i", &ndigits))
+               return NULL;
+
+       x = PyFloat_AsDouble(v);
+
+       if (ndigits != UNDEF_NDIGITS) {
+               f = 1.0;
+               i = abs(ndigits);
+               while  (--i >= 0)
+                       f = f*10.0;
+               if (ndigits < 0)
+                       x /= f;
+               else
+                       x *= f;
+       }
+
+       flr = floor(x);
+       cil = ceil(x);
+
+       if (x-flr > 0.5)
+               rounded = cil;
+       else if (x-flr == 0.5) 
+               rounded = fmod(flr, 2) == 0 ? flr : cil;
+       else
+               rounded = flr;
+
+       if (ndigits != UNDEF_NDIGITS) {
+               if (ndigits < 0)
+                       rounded *= f;
+               else
+                       rounded /= f;
+               return PyFloat_FromDouble(rounded);
+       }
+
+       return PyLong_FromDouble(rounded);
+#undef UNDEF_NDIGITS
+}
+
 static PyObject *
 float_float(PyObject *v)
 {
@@ -976,6 +1018,11 @@ float_getzero(PyObject *v, void *closure)
 static PyMethodDef float_methods[] = {
        {"conjugate",   (PyCFunction)float_float,       METH_NOARGS,
         "Returns self, the complex conjugate of any float."},
+       {"__trunc__",   (PyCFunction)float_trunc, METH_NOARGS,
+         "Returns the Integral closest to x between 0 and x."},
+       {"__round__",   (PyCFunction)float_round, METH_VARARGS,
+         "Returns the Integral closest to x, rounding half toward even.\n"
+         "When an argument is passed, works like built-in round(x, ndigits)."},
        {"__getnewargs__",      (PyCFunction)float_getnewargs,  METH_NOARGS},
        {"__getformat__",       (PyCFunction)float_getformat,   
         METH_O|METH_CLASS,             float_getformat_doc},
@@ -1020,8 +1067,8 @@ static PyNumberMethods float_as_number = {
        0,              /*nb_xor*/
        0,              /*nb_or*/
        (coercion)0,    /*nb_coerce*/
-       float_int,      /*nb_int*/
-       float_long,     /*nb_long*/
+       float_trunc,    /*nb_int*/
+       float_trunc,    /*nb_long*/
        float_float,    /*nb_float*/
        0,              /* nb_oct */
        0,              /* nb_hex */
index 518e60763a3df2c7b5282fa7966fa7082ec02e78..ddf359d0eac2a136c6d7490cc4f9081e21f05120 100644 (file)
@@ -3592,9 +3592,45 @@ long_getN(PyLongObject *v, void *context) {
        return PyLong_FromLong((intptr_t)context);
 }
 
+static PyObject *
+long_round(PyObject *self, PyObject *args)
+{
+#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
+       int ndigits = UNDEF_NDIGITS;
+       double x;
+       PyObject *res;
+       
+       if (!PyArg_ParseTuple(args, "|i", &ndigits))
+               return NULL;
+
+       if (ndigits == UNDEF_NDIGITS)
+               return long_long(self);
+
+       /* If called with two args, defer to float.__round__(). */
+       x = PyLong_AsDouble(self);
+       if (x == -1.0 && PyErr_Occurred())
+               return NULL;
+       self = PyFloat_FromDouble(x);
+       if (self == NULL)
+               return NULL;
+       res = PyObject_CallMethod(self, "__round__", "i", ndigits);
+       Py_DECREF(self);
+       return res;
+#undef UNDEF_NDIGITS
+}
+
 static PyMethodDef long_methods[] = {
        {"conjugate",   (PyCFunction)long_long, METH_NOARGS,
         "Returns self, the complex conjugate of any int."},
+       {"__trunc__",   (PyCFunction)long_long, METH_NOARGS,
+         "Truncating an Integral returns itself."},
+       {"__floor__",   (PyCFunction)long_long, METH_NOARGS,
+         "Flooring an Integral returns itself."},
+       {"__ceil__",    (PyCFunction)long_long, METH_NOARGS,
+         "Ceiling of an Integral returns itself."},
+       {"__round__",   (PyCFunction)long_round, METH_VARARGS,
+         "Rounding an Integral returns itself.\n"
+        "Rounding with an ndigits arguments defers to float.__round__."},
        {"__getnewargs__",      (PyCFunction)long_getnewargs,   METH_NOARGS},
        {NULL,          NULL}           /* sentinel */
 };
index b55dd5194803c6ccf8853f637baa77935dc9ab93..9bbf64b75bf30912f1fd4b77629070d67ed187e0 100644 (file)
@@ -1373,63 +1373,44 @@ For most object types, eval(repr(object)) == object.");
 static PyObject *
 builtin_round(PyObject *self, PyObject *args, PyObject *kwds)
 {
-       double number;
-       double f;
-       int ndigits = 0;
-       int i;
+#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
+       static PyObject *round_str = NULL;
+       int ndigits = UNDEF_NDIGITS;
        static char *kwlist[] = {"number", "ndigits", 0};
-       PyObject* real;
+       PyObject *number, *round;
 
        if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i:round",
-                kwlist, &real, &ndigits))
+                kwlist, &number, &ndigits))
                 return NULL;
 
-       if (ndigits == 0) {
-               PyObject *res;
-               PyObject *d = PyObject_GetAttrString(real, "__round__");
-               if (d == NULL && !PyFloat_Check(real)) {
-                       PyErr_SetString(PyExc_TypeError,
-                                       "round() argument must have __round__ attribute or be a float");
+       if (round_str == NULL) {
+               round_str = PyUnicode_FromString("__round__");
+               if (round_str == NULL)
                        return NULL;
-               } 
-               if (d == NULL) {
-                       PyErr_Clear();
-               } else {
-                       res = PyObject_CallFunction(d, "");
-                       Py_DECREF(d);
-                       return res;
-               } 
-       } else if (!PyFloat_Check(real)) {
-               PyErr_SetString(PyExc_TypeError,
-                               "round() argument must have __round__ attribute or be a float");
+       }
+
+       round = _PyType_Lookup(Py_Type(number), round_str);
+       if (round == NULL) {
+               PyErr_Format(PyExc_TypeError,
+                            "type %.100s doesn't define __round__ method",
+                            Py_Type(number)->tp_name);
                return NULL;
        }
 
-       number = PyFloat_AsDouble(real);
-       f = 1.0;
-       i = abs(ndigits);
-       while  (--i >= 0)
-               f = f*10.0;
-       if (ndigits < 0)
-               number /= f;
-       else
-               number *= f;
-       if (number >= 0.0)
-               number = floor(number + 0.5);
+       if (ndigits == UNDEF_NDIGITS)
+                return PyObject_CallFunction(round, "O", number);
        else
-               number = ceil(number - 0.5);
-       if (ndigits < 0)
-               number *= f;
-       else
-               number /= f;
-       return PyFloat_FromDouble(number);
+                return PyObject_CallFunction(round, "Oi", number, ndigits);
+#undef UNDEF_NDIGITS
 }
 
 PyDoc_STRVAR(round_doc,
 "round(number[, ndigits]) -> floating point number\n\
 \n\
 Round a number to a given precision in decimal digits (default 0 digits).\n\
-This always returns a floating point number.  Precision may be negative.");
+This returns an int when called with one argument, otherwise a float.\n\
+Precision may be negative.");
+
 
 static PyObject *
 builtin_sorted(PyObject *self, PyObject *args, PyObject *kwds)
@@ -1511,18 +1492,25 @@ Without arguments, equivalent to locals().\n\
 With an argument, equivalent to object.__dict__.");
 
 static PyObject *
-builtin_trunc(PyObject *self, PyObject *v)
+builtin_trunc(PyObject *self, PyObject *number)
 {
-       PyObject *res;
-       PyObject *d = PyObject_GetAttrString(v, "__trunc__");
-       if (d == NULL) {
-               PyErr_SetString(PyExc_TypeError,
-                   "trunc() argument must have __trunc__ attribute");
+       static PyObject *trunc_str = NULL;
+       PyObject *trunc;
+
+       if (trunc_str == NULL) {
+               trunc_str = PyUnicode_FromString("__trunc__");
+               if (trunc_str == NULL)
+                       return NULL;
+       }
+
+       trunc = _PyType_Lookup(Py_Type(number), trunc_str);
+       if (trunc == NULL) {
+               PyErr_Format(PyExc_TypeError,
+                            "type %.100s doesn't define __trunc__ method",
+                            Py_Type(number)->tp_name);
                return NULL;
        }
-       res = PyObject_CallFunction(d, "");
-       Py_DECREF(d);
-       return res;
+       return PyObject_CallFunction(trunc, "O", number);
 }
 
 PyDoc_STRVAR(trunc_doc,