]> granicus.if.org Git - python/commitdiff
Hopefully fix 3-way comparisons. This unfortunately adds yet another
authorGuido van Rossum <guido@python.org>
Tue, 18 Sep 2001 20:38:53 +0000 (20:38 +0000)
committerGuido van Rossum <guido@python.org>
Tue, 18 Sep 2001 20:38:53 +0000 (20:38 +0000)
hack, and it's even more disgusting than a PyInstance_Check() call.
If the tp_compare slot is the slot used for overrides in Python,
it's always called.

Add some tests that show what should work too.

Include/object.h
Lib/test/test_descr.py
Objects/object.c
Objects/typeobject.c

index d9c35144b32bb6cf8fac087d156e3cd61107e370..160331ed23c83a975efca9e2bbc0c33674fa777c 100644 (file)
@@ -346,6 +346,10 @@ extern DL_IMPORT(int) PyNumber_CoerceEx(PyObject **, PyObject **);
 
 extern DL_IMPORT(void) (*PyObject_ClearWeakRefs)(PyObject *);
 
+/* A slot function whose address we need to compare */
+extern int _PyObject_SlotCompare(PyObject *, PyObject *);
+
+
 /* PyObject_Dir(obj) acts like Python __builtin__.dir(obj), returning a
    list of strings.  PyObject_Dir(NULL) is like __builtin__.dir(),
    returning the names of the current locals.  In this case, if there are
index fc00318673765576aef619e07d649c0ec3ffb12d..bd046052a592fbf4d9dbf1ea21a05e0b6a6e29b3 100644 (file)
@@ -1831,6 +1831,33 @@ def str_subclass_as_dict_key():
     verify(cistr('ONe') in d)
     verify(d.get(cistr('thrEE')) == 3)
 
+def classic_comparisons():
+    if verbose: print "Testing classic comparisons..."
+    for base in (int, object):
+        if verbose: print "        (base = %s)" % base
+        class C(base):
+            def __init__(self, value):
+                self.value = int(value)
+            def __cmp__(self, other):
+                if isinstance(other, C):
+                    return cmp(self.value, other.value)
+                if isinstance(other, int) or isinstance(other, long):
+                    return cmp(self.value, other)
+                return NotImplemented
+        c1 = C(1)
+        c2 = C(2)
+        c3 = C(3)
+        verify(c1 == 1)
+        c = {1: c1, 2: c2, 3: c3}
+        for x in 1, 2, 3:
+            for y in 1, 2, 3:
+                verify(cmp(c[x], c[y]) == cmp(x, y), "x=%d, y=%d" % (x, y))
+                for op in "<", "<=", "==", "!=", ">", ">=":
+                    verify(eval("c[x] %s c[y]" % op) == eval("x %s y" % op),
+                           "x=%d, y=%d" % (x, y))
+                verify(cmp(c[x], y) == cmp(x, y), "x=%d, y=%d" % (x, y))
+                verify(cmp(x, c[y]) == cmp(x, y), "x=%d, y=%d" % (x, y))
+
 
 def all():
     lists()
@@ -1869,6 +1896,7 @@ def all():
     keywords()
     restricted()
     str_subclass_as_dict_key()
+    classic_comparisons()
 
 all()
 
index c56c3be9175a613a59e2be27fba5862552521c3e..668bd4f33246fee6957f028877bafce94a0973db 100644 (file)
@@ -455,11 +455,25 @@ try_3way_compare(PyObject *v, PyObject *w)
        /* Comparisons involving instances are given to instance_compare,
           which has the same return conventions as this function. */
 
+       f = v->ob_type->tp_compare;
        if (PyInstance_Check(v))
-               return (*v->ob_type->tp_compare)(v, w);
+               return (*f)(v, w);
        if (PyInstance_Check(w))
                return (*w->ob_type->tp_compare)(v, w);
 
+       /* If both have the same (non-NULL) tp_compare, use it. */
+       if (f != NULL && f == w->ob_type->tp_compare) {
+               c = (*f)(v, w);
+               if (c < 0 && PyErr_Occurred())
+                       return -1;
+               return c < 0 ? -1 : c > 0 ? 1 : 0;
+       }
+
+       /* If either tp_compare is _PyObject_SlotCompare, that's safe. */
+       if (f == _PyObject_SlotCompare ||
+           w->ob_type->tp_compare == _PyObject_SlotCompare)
+               return _PyObject_SlotCompare(v, w);
+
        /* Try coercion; if it fails, give up */
        c = PyNumber_CoerceEx(&v, &w);
        if (c < 0)
index 792a9f3c286842fdc56329820217197d88ff802e..26ddabe0c267f23647174a21943e5e0bc5669045 100644 (file)
@@ -2761,17 +2761,18 @@ half_compare(PyObject *self, PyObject *other)
        return 2;
 }
 
-static int
-slot_tp_compare(PyObject *self, PyObject *other)
+/* This slot is published for the benefit of try_3way_compare in object.c */
+int
+_PyObject_SlotCompare(PyObject *self, PyObject *other)
 {
        int c;
 
-       if (self->ob_type->tp_compare == slot_tp_compare) {
+       if (self->ob_type->tp_compare == _PyObject_SlotCompare) {
                c = half_compare(self, other);
                if (c <= 1)
                        return c;
        }
-       if (other->ob_type->tp_compare == slot_tp_compare) {
+       if (other->ob_type->tp_compare == _PyObject_SlotCompare) {
                c = half_compare(other, self);
                if (c < -1)
                        return -2;
@@ -3190,7 +3191,7 @@ override_slots(PyTypeObject *type, PyObject *dict)
            PyDict_GetItemString(dict, "__repr__"))
                type->tp_print = NULL;
 
-       TPSLOT("__cmp__", tp_compare, slot_tp_compare);
+       TPSLOT("__cmp__", tp_compare, _PyObject_SlotCompare);
        TPSLOT("__repr__", tp_repr, slot_tp_repr);
        TPSLOT("__hash__", tp_hash, slot_tp_hash);
        TPSLOT("__call__", tp_call, slot_tp_call);