]> granicus.if.org Git - python/commitdiff
Issue #6970: Remove redundant calls made when comparing objects.
authorMark Dickinson <dickinsm@gmail.com>
Sun, 15 Nov 2009 13:58:49 +0000 (13:58 +0000)
committerMark Dickinson <dickinsm@gmail.com>
Sun, 15 Nov 2009 13:58:49 +0000 (13:58 +0000)
Lib/test/test_binop.py
Misc/NEWS
Objects/object.c
Objects/typeobject.c

index 5d77f234fe6f8320c6c5f68fe383a0b58661b22c..c7987b2013ae43901dbc0149d665929e88b95509 100644 (file)
@@ -2,6 +2,7 @@
 
 import unittest
 from test import support
+from operator import eq, ne, lt, gt, le, ge
 
 def gcd(a, b):
     """Greatest common divisor using Euclid's algorithm."""
@@ -305,9 +306,78 @@ class RatTestCase(unittest.TestCase):
 
     # XXX Ran out of steam; TO DO: divmod, div, future division
 
-def test_main():
-    support.run_unittest(RatTestCase)
 
+class OperationLogger:
+    """Base class for classes with operation logging."""
+    def __init__(self, logger):
+        self.logger = logger
+    def log_operation(self, *args):
+        self.logger(*args)
+
+def op_sequence(op, *classes):
+    """Return the sequence of operations that results from applying
+    the operation `op` to instances of the given classes."""
+    log = []
+    instances = []
+    for c in classes:
+        instances.append(c(log.append))
+
+    try:
+        op(*instances)
+    except TypeError:
+        pass
+    return log
+
+class A(OperationLogger):
+    def __eq__(self, other):
+        self.log_operation('A.__eq__')
+        return NotImplemented
+    def __le__(self, other):
+        self.log_operation('A.__le__')
+        return NotImplemented
+    def __ge__(self, other):
+        self.log_operation('A.__ge__')
+        return NotImplemented
+
+class B(OperationLogger):
+    def __eq__(self, other):
+        self.log_operation('B.__eq__')
+        return NotImplemented
+    def __le__(self, other):
+        self.log_operation('B.__le__')
+        return NotImplemented
+    def __ge__(self, other):
+        self.log_operation('B.__ge__')
+        return NotImplemented
+
+class C(B):
+    def __eq__(self, other):
+        self.log_operation('C.__eq__')
+        return NotImplemented
+    def __le__(self, other):
+        self.log_operation('C.__le__')
+        return NotImplemented
+    def __ge__(self, other):
+        self.log_operation('C.__ge__')
+        return NotImplemented
+
+class OperationOrderTests(unittest.TestCase):
+    def test_comparison_orders(self):
+        self.assertEqual(op_sequence(eq, A, A), ['A.__eq__', 'A.__eq__'])
+        self.assertEqual(op_sequence(eq, A, B), ['A.__eq__', 'B.__eq__'])
+        self.assertEqual(op_sequence(eq, B, A), ['B.__eq__', 'A.__eq__'])
+        # C is a subclass of B, so C.__eq__ is called first
+        self.assertEqual(op_sequence(eq, B, C), ['C.__eq__', 'B.__eq__'])
+        self.assertEqual(op_sequence(eq, C, B), ['C.__eq__', 'B.__eq__'])
+
+        self.assertEqual(op_sequence(le, A, A), ['A.__le__', 'A.__ge__'])
+        self.assertEqual(op_sequence(le, A, B), ['A.__le__', 'B.__ge__'])
+        self.assertEqual(op_sequence(le, B, A), ['B.__le__', 'A.__ge__'])
+        self.assertEqual(op_sequence(le, B, C), ['C.__ge__', 'B.__le__'])
+        self.assertEqual(op_sequence(le, C, B), ['C.__le__', 'B.__ge__'])
+
+def test_main():
+    support.run_unittest(RatTestCase, OperationOrderTests)
 
 if __name__ == "__main__":
     test_main()
index 1577baf80b69457db78ce57b439bb60a7cce8cb3..9480d9f61e6bad4ca15777f29c50e7e48beb9503 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,9 @@ What's New in Python 3.2 Alpha 1?
 Core and Builtins
 -----------------
 
+- Issue #6970: Remove redundant calls when comparing objects that don't
+  implement the relevant rich comparison methods.
+
 - Issue #7298: fixes for range and reversed(range(...)).  Iteration
   over range(a, b, c) incorrectly gave an empty iterator when a, b and
   c fit in C long but the length of the range did not.  Also fix
index 90cdc74195a875cf66bff5fff364a899c64a4930..002acd04be68a8109e4073c5b27e78a36f437e5e 100644 (file)
@@ -544,10 +544,12 @@ do_richcompare(PyObject *v, PyObject *w, int op)
 {
        richcmpfunc f;
        PyObject *res;
+       int checked_reverse_op = 0;
 
        if (v->ob_type != w->ob_type &&
            PyType_IsSubtype(w->ob_type, v->ob_type) &&
            (f = w->ob_type->tp_richcompare) != NULL) {
+               checked_reverse_op = 1;
                res = (*f)(w, v, _Py_SwappedOp[op]);
                if (res != Py_NotImplemented)
                        return res;
@@ -559,7 +561,7 @@ do_richcompare(PyObject *v, PyObject *w, int op)
                        return res;
                Py_DECREF(res);
        }
-       if ((f = w->ob_type->tp_richcompare) != NULL) {
+       if (!checked_reverse_op && (f = w->ob_type->tp_richcompare) != NULL) {
                res = (*f)(w, v, _Py_SwappedOp[op]);
                if (res != Py_NotImplemented)
                        return res;
index 24866ff465f6be31edaea026e2a1fd8a94c41da2..be4b6f861ef7729675af771de9fed0ee5ec06771 100644 (file)
@@ -5068,7 +5068,7 @@ static char *name_op[] = {
 };
 
 static PyObject *
-half_richcompare(PyObject *self, PyObject *other, int op)
+slot_tp_richcompare(PyObject *self, PyObject *other, int op)
 {
        PyObject *func, *args, *res;
        static PyObject *op_str[6];
@@ -5090,28 +5090,6 @@ half_richcompare(PyObject *self, PyObject *other, int op)
        return res;
 }
 
-static PyObject *
-slot_tp_richcompare(PyObject *self, PyObject *other, int op)
-{
-       PyObject *res;
-
-       if (Py_TYPE(self)->tp_richcompare == slot_tp_richcompare) {
-               res = half_richcompare(self, other, op);
-               if (res != Py_NotImplemented)
-                       return res;
-               Py_DECREF(res);
-       }
-       if (Py_TYPE(other)->tp_richcompare == slot_tp_richcompare) {
-               res = half_richcompare(other, self, _Py_SwappedOp[op]);
-               if (res != Py_NotImplemented) {
-                       return res;
-               }
-               Py_DECREF(res);
-       }
-       Py_INCREF(Py_NotImplemented);
-       return Py_NotImplemented;
-}
-
 static PyObject *
 slot_tp_iter(PyObject *self)
 {