]> granicus.if.org Git - python/commitdiff
Make int() and long() fall back to __trunc__(). See issue 2002.
authorJeffrey Yasskin <jyasskin@gmail.com>
Mon, 4 Feb 2008 01:04:35 +0000 (01:04 +0000)
committerJeffrey Yasskin <jyasskin@gmail.com>
Mon, 4 Feb 2008 01:04:35 +0000 (01:04 +0000)
Include/abstract.h
Lib/rational.py
Lib/test/test_builtin.py
Objects/abstract.c
Objects/classobject.c

index b7fde09d528b5b959ff113caee4e6a5d83bfa5f3..e6cbb7b5be4e727ff0c6ab7e92544fdf313c1bd6 100644 (file)
@@ -760,6 +760,19 @@ xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx*/
 
      PyAPI_FUNC(Py_ssize_t) PyNumber_AsSsize_t(PyObject *o, PyObject *exc);
 
+       /*
+         Returns the Integral instance converted to an int. The
+         instance is expected to be int or long or have an __int__
+         method. Steals integral's reference. error_format will be
+         used to create the TypeError if integral isn't actually an
+         Integral instance. error_format should be a format string
+         that can accept a char* naming integral's type.
+       */
+
+     PyAPI_FUNC(PyObject *) _PyNumber_ConvertIntegralToInt(
+             PyObject *integral,
+             const char* error_format);
+
        /*
         Returns the object converted to Py_ssize_t by going through
         PyNumber_Index first.  If an overflow error occurs while
index c76cba3d074a302ecf26979f0c4ba24e74bc9575..dcdaad494b7d04f6f66766fcd9abf1d69fc63883 100755 (executable)
@@ -424,8 +424,6 @@ class Rational(RationalAbc):
         else:
             return a.numerator // a.denominator
 
-    __int__ = __trunc__
-
     def __hash__(self):
         """hash(self)
 
index cfc900335d9b11a4fcaec691db4cb975d43b2b8d..9612a4b687299c8f52d2c4d2491b061b420cbfbb 100644 (file)
@@ -934,6 +934,14 @@ class BuiltinTest(unittest.TestCase):
 
     def test_intconversion(self):
         # Test __int__()
+        class ClassicMissingMethods:
+            pass
+        self.assertRaises(AttributeError, int, ClassicMissingMethods())
+
+        class MissingMethods(object):
+            pass
+        self.assertRaises(TypeError, int, MissingMethods())
+
         class Foo0:
             def __int__(self):
                 return 42
@@ -965,6 +973,49 @@ class BuiltinTest(unittest.TestCase):
         self.assertEqual(int(Foo4()), 42L)
         self.assertRaises(TypeError, int, Foo5())
 
+        class Classic:
+            pass
+        for base in (object, Classic):
+            class IntOverridesTrunc(base):
+                def __int__(self):
+                    return 42
+                def __trunc__(self):
+                    return -12
+            self.assertEqual(int(IntOverridesTrunc()), 42)
+
+            class JustTrunc(base):
+                def __trunc__(self):
+                    return 42
+            self.assertEqual(int(JustTrunc()), 42)
+
+            for trunc_result_base in (object, Classic):
+                class Integral(trunc_result_base):
+                    def __int__(self):
+                        return 42
+
+                class TruncReturnsNonInt(base):
+                    def __trunc__(self):
+                        return Integral()
+                self.assertEqual(int(TruncReturnsNonInt()), 42)
+
+                class NonIntegral(trunc_result_base):
+                    def __trunc__(self):
+                        # Check that we avoid infinite recursion.
+                        return NonIntegral()
+
+                class TruncReturnsNonIntegral(base):
+                    def __trunc__(self):
+                        return NonIntegral()
+                try:
+                    int(TruncReturnsNonIntegral())
+                except TypeError as e:
+                    self.assertEquals(str(e),
+                                      "__trunc__ returned non-Integral"
+                                      " (type NonIntegral)")
+                else:
+                    self.fail("Failed to raise TypeError with %s" %
+                              ((base, trunc_result_base),))
+
     def test_intern(self):
         self.assertRaises(TypeError, intern)
         s = "never interned before"
@@ -1207,6 +1258,14 @@ class BuiltinTest(unittest.TestCase):
 
     def test_longconversion(self):
         # Test __long__()
+        class ClassicMissingMethods:
+            pass
+        self.assertRaises(AttributeError, long, ClassicMissingMethods())
+
+        class MissingMethods(object):
+            pass
+        self.assertRaises(TypeError, long, MissingMethods())
+
         class Foo0:
             def __long__(self):
                 return 42L
@@ -1238,6 +1297,49 @@ class BuiltinTest(unittest.TestCase):
         self.assertEqual(long(Foo4()), 42)
         self.assertRaises(TypeError, long, Foo5())
 
+        class Classic:
+            pass
+        for base in (object, Classic):
+            class LongOverridesTrunc(base):
+                def __long__(self):
+                    return 42
+                def __trunc__(self):
+                    return -12
+            self.assertEqual(long(LongOverridesTrunc()), 42)
+
+            class JustTrunc(base):
+                def __trunc__(self):
+                    return 42
+            self.assertEqual(long(JustTrunc()), 42)
+
+            for trunc_result_base in (object, Classic):
+                class Integral(trunc_result_base):
+                    def __int__(self):
+                        return 42
+
+                class TruncReturnsNonLong(base):
+                    def __trunc__(self):
+                        return Integral()
+                self.assertEqual(long(TruncReturnsNonLong()), 42)
+
+                class NonIntegral(trunc_result_base):
+                    def __trunc__(self):
+                        # Check that we avoid infinite recursion.
+                        return NonIntegral()
+
+                class TruncReturnsNonIntegral(base):
+                    def __trunc__(self):
+                        return NonIntegral()
+                try:
+                    long(TruncReturnsNonIntegral())
+                except TypeError as e:
+                    self.assertEquals(str(e),
+                                      "__trunc__ returned non-Integral"
+                                      " (type NonIntegral)")
+                else:
+                    self.fail("Failed to raise TypeError with %s" %
+                              ((base, trunc_result_base),))
+
     def test_map(self):
         self.assertEqual(
             map(None, 'hello world'),
index 830fe8217956901941a39f22e17b27e3cacd032c..a3e159a105e016bf15c11473c0c33b1237629432 100644 (file)
@@ -1034,13 +1034,65 @@ PyNumber_AsSsize_t(PyObject *item, PyObject *err)
 }
 
 
+PyObject *
+_PyNumber_ConvertIntegralToInt(PyObject *integral, const char* error_format)
+{
+       const char *type_name;
+       static PyObject *int_name = NULL;
+       if (int_name == NULL) {
+               int_name = PyString_InternFromString("__int__");
+               if (int_name == NULL)
+                       return NULL;
+       }
+
+       if (integral && (!PyInt_Check(integral) &&
+                        !PyLong_Check(integral))) {
+               /* Don't go through tp_as_number->nb_int to avoid
+                  hitting the classic class fallback to __trunc__. */
+               PyObject *int_func = PyObject_GetAttr(integral, int_name);
+               if (int_func == NULL) {
+                       PyErr_Clear(); /* Raise a different error. */
+                       goto non_integral_error;
+               }
+               Py_DECREF(integral);
+               integral = PyEval_CallObject(int_func, NULL);
+               Py_DECREF(int_func);
+               if (integral && (!PyInt_Check(integral) &&
+                                 !PyLong_Check(integral))) {
+                       goto non_integral_error;
+               }
+       }
+       return integral;
+
+non_integral_error:
+       if (PyInstance_Check(integral)) {
+               type_name = PyString_AS_STRING(((PyInstanceObject *)integral)
+                                              ->in_class->cl_name);
+       }
+       else {
+               type_name = integral->ob_type->tp_name;
+       }
+       PyErr_Format(PyExc_TypeError, error_format, type_name);
+       Py_DECREF(integral);
+       return NULL;
+}
+
+
 PyObject *
 PyNumber_Int(PyObject *o)
 {
        PyNumberMethods *m;
+       static PyObject *trunc_name = NULL;
+       PyObject *trunc_func;
        const char *buffer;
        Py_ssize_t buffer_len;
 
+       if (trunc_name == NULL) {
+               trunc_name = PyString_InternFromString("__trunc__");
+               if (trunc_name == NULL)
+                       return NULL;
+       }
+
        if (o == NULL)
                return null_error();
        if (PyInt_CheckExact(o)) {
@@ -1049,6 +1101,7 @@ PyNumber_Int(PyObject *o)
        }
        m = o->ob_type->tp_as_number;
        if (m && m->nb_int) { /* This should include subclasses of int */
+               /* Classic classes always take this branch. */
                PyObject *res = m->nb_int(o);
                if (res && (!PyInt_Check(res) && !PyLong_Check(res))) {
                        PyErr_Format(PyExc_TypeError,
@@ -1063,6 +1116,18 @@ PyNumber_Int(PyObject *o)
                PyIntObject *io = (PyIntObject*)o;
                return PyInt_FromLong(io->ob_ival);
        }
+       trunc_func = PyObject_GetAttr(o, trunc_name);
+       if (trunc_func) {
+               PyObject *truncated = PyEval_CallObject(trunc_func, NULL);
+               Py_DECREF(trunc_func);
+               /* __trunc__ is specified to return an Integral type, but
+                  int() needs to return an int. */
+               return _PyNumber_ConvertIntegralToInt(
+                       truncated,
+                       "__trunc__ returned non-Integral (type %.200s)");
+       }
+       PyErr_Clear();  /* It's not an error if  o.__trunc__ doesn't exist. */
+
        if (PyString_Check(o))
                return int_from_string(PyString_AS_STRING(o),
                                       PyString_GET_SIZE(o));
@@ -1102,13 +1167,22 @@ PyObject *
 PyNumber_Long(PyObject *o)
 {
        PyNumberMethods *m;
+       static PyObject *trunc_name = NULL;
+       PyObject *trunc_func;
        const char *buffer;
        Py_ssize_t buffer_len;
 
+       if (trunc_name == NULL) {
+               trunc_name = PyString_InternFromString("__trunc__");
+               if (trunc_name == NULL)
+                       return NULL;
+       }
+
        if (o == NULL)
                return null_error();
        m = o->ob_type->tp_as_number;
        if (m && m->nb_long) { /* This should include subclasses of long */
+               /* Classic classes always take this branch. */
                PyObject *res = m->nb_long(o);
                if (res && (!PyInt_Check(res) && !PyLong_Check(res))) {
                        PyErr_Format(PyExc_TypeError,
@@ -1121,6 +1195,26 @@ PyNumber_Long(PyObject *o)
        }
        if (PyLong_Check(o)) /* A long subclass without nb_long */
                return _PyLong_Copy((PyLongObject *)o);
+       trunc_func = PyObject_GetAttr(o, trunc_name);
+       if (trunc_func) {
+               PyObject *truncated = PyEval_CallObject(trunc_func, NULL);
+               PyObject *int_instance;
+               Py_DECREF(trunc_func);
+               /* __trunc__ is specified to return an Integral type,
+                  but long() needs to return a long. */
+               int_instance = _PyNumber_ConvertIntegralToInt(
+                       truncated,
+                       "__trunc__ returned non-Integral (type %.200s)");
+               if (int_instance && PyInt_Check(int_instance)) {
+                       /* Make sure that long() returns a long instance. */
+                       long value = PyInt_AS_LONG(int_instance);
+                       Py_DECREF(int_instance);
+                       return PyLong_FromLong(value);
+               }
+               return int_instance;
+       }
+       PyErr_Clear();  /* It's not an error if  o.__trunc__ doesn't exist. */
+
        if (PyString_Check(o))
                /* need to do extra error checking that PyLong_FromString()
                 * doesn't do.  In particular long('9.5') must raise an
index b4b17f90777a1d87b9a3893c3f6746749b333c88..9f364e2e88a6ec983590a51fd84e209b9d5e5f2e 100644 (file)
@@ -1798,7 +1798,29 @@ instance_index(PyInstanceObject *self)
 
 
 UNARY(instance_invert, "__invert__")
-UNARY(instance_int, "__int__")
+UNARY(_instance_trunc, "__trunc__")
+
+static PyObject *
+instance_int(PyInstanceObject *self)
+{
+       PyObject *truncated;
+       static PyObject *int_name;
+       if (int_name == NULL) {
+               int_name = PyString_InternFromString("__int__");
+               if (int_name == NULL)
+                       return NULL;
+       }
+       if (PyObject_HasAttr((PyObject*)self, int_name))
+               return generic_unary_op(self, int_name);
+
+       truncated = _instance_trunc(self);
+       /* __trunc__ is specified to return an Integral type, but
+          int() needs to return an int. */
+       return _PyNumber_ConvertIntegralToInt(
+               truncated,
+               "__trunc__ returned non-Integral (type %.200s)");
+}
+
 UNARY_FB(instance_long, "__long__", instance_int)
 UNARY(instance_float, "__float__")
 UNARY(instance_oct, "__oct__")