]> granicus.if.org Git - python/commitdiff
Rich comparisons.
authorGuido van Rossum <guido@python.org>
Wed, 17 Jan 2001 15:28:20 +0000 (15:28 +0000)
committerGuido van Rossum <guido@python.org>
Wed, 17 Jan 2001 15:28:20 +0000 (15:28 +0000)
- Got rid of instance_cmp(); refactored instance_compare().

- Added instance_richcompare() which calls __lt__() etc.

Some unrelated stuff mixed in:

- Aligned comments in various large struct initializers.

- Better test to avoid recursion if __coerce__ returns self as the
  first argument (this is an unrelated fix by Neil Schemenauer!).

- Style nit: don't use Py_DECREF(Py_NotImplemented); use
  Py_DECREF(result) -- it just looks better. :-)

Objects/classobject.c

index 4dc72d24c0ef9ef4803fcc84e22adc046d0e4078..7f76d6ed5ab458812c6ad03092d637953a860b47 100644 (file)
@@ -386,24 +386,24 @@ PyTypeObject PyClass_Type = {
        "class",
        sizeof(PyClassObject) + PyGC_HEAD_SIZE,
        0,
-       (destructor)class_dealloc, /*tp_dealloc*/
-       0,              /*tp_print*/
-       0,              /*tp_getattr*/
-       0,              /*tp_setattr*/
-       0,              /*tp_compare*/
-       (reprfunc)class_repr, /*tp_repr*/
-       0,              /*tp_as_number*/
-       0,              /*tp_as_sequence*/
-       0,              /*tp_as_mapping*/
-       0,              /*tp_hash*/
-       0,              /*tp_call*/
-       (reprfunc)class_str, /*tp_str*/
-       (getattrofunc)class_getattr, /*tp_getattro*/
-       (setattrofunc)class_setattr, /*tp_setattro*/
-       0,              /* tp_as_buffer */
-       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_GC, /*tp_flags*/
-       0,              /* tp_doc */
-       (traverseproc)class_traverse,   /* tp_traverse */
+       (destructor)class_dealloc,              /* tp_dealloc */
+       0,                                      /* tp_print */
+       0,                                      /* tp_getattr */
+       0,                                      /* tp_setattr */
+       0,                                      /* tp_compare */
+       (reprfunc)class_repr,                   /* tp_repr */
+       0,                                      /* tp_as_number */
+       0,                                      /* tp_as_sequence */
+       0,                                      /* tp_as_mapping */
+       0,                                      /* tp_hash */
+       0,                                      /* tp_call */
+       (reprfunc)class_str,                    /* tp_str */
+       (getattrofunc)class_getattr,            /* tp_getattro */
+       (setattrofunc)class_setattr,            /* tp_setattro */
+       0,                                      /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_GC,     /* tp_flags */
+       0,                                      /* tp_doc */
+       (traverseproc)class_traverse,           /* tp_traverse */
 };
 
 int
@@ -909,9 +909,9 @@ instance_ass_subscript(PyInstanceObject *inst, PyObject *key, PyObject *value)
 }
 
 static PyMappingMethods instance_as_mapping = {
-       (inquiry)instance_length, /*mp_length*/
-       (binaryfunc)instance_subscript, /*mp_subscript*/
-       (objobjargproc)instance_ass_subscript, /*mp_ass_subscript*/
+       (inquiry)instance_length,               /* mp_length */
+       (binaryfunc)instance_subscript,         /* mp_subscript */
+       (objobjargproc)instance_ass_subscript,  /* mp_ass_subscript */
 };
 
 static PyObject *
@@ -1134,14 +1134,14 @@ static int instance_contains(PyInstanceObject *inst, PyObject *member)
 
 static PySequenceMethods
 instance_as_sequence = {
-       (inquiry)instance_length, /*sq_length*/
-       0, /*sq_concat*/
-       0, /*sq_repeat*/
-       (intargfunc)instance_item, /*sq_item*/
-       (intintargfunc)instance_slice, /*sq_slice*/
-       (intobjargproc)instance_ass_item, /*sq_ass_item*/
-       (intintobjargproc)instance_ass_slice, /*sq_ass_slice*/
-       (objobjproc)instance_contains, /* sq_contains */
+       (inquiry)instance_length,               /* sq_length */
+       0,                                      /* sq_concat */
+       0,                                      /* sq_repeat */
+       (intargfunc)instance_item,              /* sq_item */
+       (intintargfunc)instance_slice,          /* sq_slice */
+       (intobjargproc)instance_ass_item,       /* sq_ass_item */
+       (intintobjargproc)instance_ass_slice,   /* sq_ass_slice */
+       (objobjproc)instance_contains,          /* sq_contains */
 };
 
 static PyObject *
@@ -1232,10 +1232,10 @@ half_binop(PyObject *v, PyObject *w, char *opname, binaryfunc thisfunc,
        }
        v1 = PyTuple_GetItem(coerced, 0);
        w = PyTuple_GetItem(coerced, 1);
-       if (v1 == v) {
+       if (v1->ob_type == v->ob_type && PyInstance_Check(v)) {
                /* prevent recursion if __coerce__ returns self as the first
                 * argument */
-               result = generic_binary_op(v, w, opname);
+               result = generic_binary_op(v1, w, opname);
        } else {
                if (swapped)
                        result = (thisfunc)(w, v1);
@@ -1253,7 +1253,7 @@ do_binop(PyObject *v, PyObject *w, char *opname, char *ropname,
 {
        PyObject *result = half_binop(v, w, opname, thisfunc, 0);
        if (result == Py_NotImplemented) {
-               Py_DECREF(Py_NotImplemented);
+               Py_DECREF(result);
                result = half_binop(w, v, ropname, thisfunc, 1);
        }
        return result;
@@ -1265,7 +1265,7 @@ do_binop_inplace(PyObject *v, PyObject *w, char *iopname, char *opname,
 {
        PyObject *result = half_binop(v, w, iopname, thisfunc, 0);
        if (result == Py_NotImplemented) {
-               Py_DECREF(Py_NotImplemented);
+               Py_DECREF(result);
                result = do_binop(v, w, opname, ropname, thisfunc);
        }
        return result;
@@ -1371,62 +1371,119 @@ BINARY_INPLACE(instance_imul, "mul", PyNumber_InPlaceMultiply)
 BINARY_INPLACE(instance_idiv, "div", PyNumber_InPlaceDivide)
 BINARY_INPLACE(instance_imod, "mod", PyNumber_InPlaceRemainder)
 
-static PyObject *
-do_cmp(PyObject *v, PyObject *w)
+/* Try a 3-way comparison, returning an int; v is an instance.  Return:
+   -2 for an exception;
+   -1 if v < w;
+   0 if v == w;
+   1 if v > w;
+   2 if this particular 3-way comparison is not implemented or undefined.
+*/
+static int
+half_cmp(PyObject *v, PyObject *w)
 {
-       int cmp = PyObject_Compare(v, w);
-       if (PyErr_Occurred()) {
-               return NULL;
+       static PyObject *cmp_obj;
+       PyObject *args;
+       PyObject *cmpfunc;
+       PyObject *result;
+       long l;
+
+       assert(PyInstance_Check(v));
+
+       if (cmp_obj == NULL) {
+               cmp_obj = PyString_InternFromString("__cmp__");
+               if (cmp_obj == NULL)
+                       return -2;
        }
-       return PyInt_FromLong(cmp);
-}
 
-static PyObject *
-instance_cmp(PyObject *v, PyObject *w)
-{
-       PyObject *result = half_binop(v, w, "__cmp__", do_cmp, 0);
+       cmpfunc = PyObject_GetAttr(v, cmp_obj);
+       if (cmpfunc == NULL) {
+               PyErr_Clear();
+               return 2;
+       }
+
+       args = Py_BuildValue("(O)", w);
+       if (args == NULL)
+               return -2;
+
+       result = PyEval_CallObject(cmpfunc, args);
+       Py_DECREF(args);
+       Py_DECREF(cmpfunc);
+
+       if (result == NULL)
+               return -2;
+
        if (result == Py_NotImplemented) {
-               Py_DECREF(Py_NotImplemented);
-                /* __rcmp__ is not called on instances, instead they
-                 * automaticly reverse the arguments and return the negative of
-                 * __cmp__ if it exists */
-               result = half_binop(w, v, "__cmp__", do_cmp, 0);
-                        
-                if (result != Py_NotImplemented && result != NULL) {
-                       PyObject *r = PyNumber_Negative(result);
-                       Py_DECREF(result);
-                       result = r;
-               }
+               Py_DECREF(result);
+               return 2;
        }
-       return result;
+
+       l = PyInt_AsLong(result);
+       Py_DECREF(result);
+       if (l == -1 && PyErr_Occurred()) {
+               PyErr_SetString(PyExc_TypeError,
+                            "comparison did not return an int");
+               return -2;
+       }
+
+       return l < 0 ? -1 : l > 0 ? 1 : 0;
 }
 
+/* Try a 3-way comparison, returning an int; either v or w is an instance.
+   We first try a coercion.  Return:
+   -2 for an exception;
+   -1 if v < w;
+   0 if v == w;
+   1 if v > w;
+   2 if this particular 3-way comparison is not implemented or undefined.
+   THIS IS ONLY CALLED FROM object.c!
+*/
 static int
-instance_compare(PyObject *inst, PyObject *other)
+instance_compare(PyObject *v, PyObject *w)
 {
-       PyObject *result;
-       long outcome;
-       result = instance_cmp(inst, other);
-       if (result == NULL) {
-               return -1;
+       int c;
+
+       c = PyNumber_CoerceEx(&v, &w);
+       if (c < 0)
+               return -2;
+       if (c == 0) {
+               /* If neither is now an instance, use regular comparison */
+               if (!PyInstance_Check(v) && !PyInstance_Check(w)) {
+                       c = PyObject_Compare(v, w);
+                       Py_DECREF(v);
+                       Py_DECREF(w);
+                       if (PyErr_Occurred())
+                               return -2;
+                       return c < 0 ? -1 : c > 0 ? 1 : 0;
+               }
        }
-       if (result == Py_NotImplemented) {
-               Py_DECREF(result);
-               return -1;
+       else {
+               /* The coercion didn't do anything.
+                  Treat this the same as returning v and w unchanged. */
+               Py_INCREF(v);
+               Py_INCREF(w);
        }
-       if (!PyInt_Check(result)) {
-               Py_DECREF(result);
-               PyErr_SetString(PyExc_TypeError,
-                               "comparison did not return an int");
-               return -1;
+
+       if (PyInstance_Check(v)) {
+               c = half_cmp(v, w);
+               if (c <= 1) {
+                       Py_DECREF(v);
+                       Py_DECREF(w);
+                       return c;
+               }
        }
-       outcome = PyInt_AsLong(result);
-       Py_DECREF(result);
-       if (outcome < 0)
-               return -1;
-       else if (outcome > 0)
-               return 1;
-       return 0;
+       if (PyInstance_Check(w)) {
+               c = half_cmp(w, v);
+               if (c <= 1) {
+                       Py_DECREF(v);
+                       Py_DECREF(w);
+                       if (c >= -1)
+                               c = -c;
+                       return c;
+               }
+       }
+       Py_DECREF(v);
+       Py_DECREF(w);
+       return 2;
 }
 
 static int
@@ -1550,43 +1607,116 @@ instance_ipow(PyObject *v, PyObject *w, PyObject *z)
 }
 
 
+/* Map rich comparison operators to their __xx__ namesakes */
+static char *name_op[] = {
+       "__lt__",
+       "__le__",
+       "__eq__",
+       "__ne__",
+       "__gt__",
+       "__ge__",
+};
+
+static PyObject *
+half_richcompare(PyObject *v, PyObject *w, int op)
+{
+       PyObject *name;
+       PyObject *method;
+       PyObject *args;
+       PyObject *res;
+
+       assert(PyInstance_Check(v));
+
+       name = PyString_InternFromString(name_op[op]);
+       if (name == NULL)
+               return NULL;
+
+       method = PyObject_GetAttr(v, name);
+       Py_DECREF(name);
+       if (method == NULL) {
+               if (!PyErr_ExceptionMatches(PyExc_AttributeError))
+                       return NULL;
+               PyErr_Clear();
+               res = Py_NotImplemented;
+               Py_INCREF(res);
+               return res;
+       }
+
+       args = Py_BuildValue("(O)", w);
+       if (args == NULL) {
+               Py_DECREF(method);
+               return NULL;
+       }
+
+       res = PyEval_CallObject(method, args);
+       Py_DECREF(args);
+       Py_DECREF(method);
+
+       return res;
+}
+
+/* Map rich comparison operators to their swapped version, e.g. LT --> GT */
+static int swapped_op[] = {Py_GT, Py_GE, Py_EQ, Py_NE, Py_LT, Py_LE};
+
+static PyObject *
+instance_richcompare(PyObject *v, PyObject *w, int op)
+{
+       PyObject *res;
+
+       if (PyInstance_Check(v)) {
+               res = half_richcompare(v, w, op);
+               if (res != Py_NotImplemented)
+                       return res;
+               Py_DECREF(res);
+       }
+
+       if (PyInstance_Check(w)) {
+               res = half_richcompare(w, v, swapped_op[op]);
+               if (res != Py_NotImplemented)
+                       return res;
+               Py_DECREF(res);
+       }
+
+       Py_INCREF(Py_NotImplemented);
+       return Py_NotImplemented;
+}
+
 
 static PyNumberMethods instance_as_number = {
-       (binaryfunc)instance_add, /*nb_add*/
-       (binaryfunc)instance_sub, /*nb_subtract*/
-       (binaryfunc)instance_mul, /*nb_multiply*/
-       (binaryfunc)instance_div, /*nb_divide*/
-       (binaryfunc)instance_mod, /*nb_remainder*/
-       (binaryfunc)instance_divmod, /*nb_divmod*/
-       (ternaryfunc)instance_pow, /*nb_power*/
-       (unaryfunc)instance_neg, /*nb_negative*/
-       (unaryfunc)instance_pos, /*nb_positive*/
-       (unaryfunc)instance_abs, /*nb_absolute*/
-       (inquiry)instance_nonzero, /*nb_nonzero*/
-       (unaryfunc)instance_invert, /*nb_invert*/
-       (binaryfunc)instance_lshift, /*nb_lshift*/
-       (binaryfunc)instance_rshift, /*nb_rshift*/
-       (binaryfunc)instance_and, /*nb_and*/
-       (binaryfunc)instance_xor, /*nb_xor*/
-       (binaryfunc)instance_or, /*nb_or*/
-       (coercion)instance_coerce, /*nb_coerce*/
-       (unaryfunc)instance_int, /*nb_int*/
-       (unaryfunc)instance_long, /*nb_long*/
-       (unaryfunc)instance_float, /*nb_float*/
-       (unaryfunc)instance_oct, /*nb_oct*/
-       (unaryfunc)instance_hex, /*nb_hex*/
-       (binaryfunc)instance_iadd, /*nb_inplace_add*/
-       (binaryfunc)instance_isub, /*nb_inplace_subtract*/
-       (binaryfunc)instance_imul, /*nb_inplace_multiply*/
-       (binaryfunc)instance_idiv, /*nb_inplace_divide*/
-       (binaryfunc)instance_imod, /*nb_inplace_remainder*/
-       (ternaryfunc)instance_ipow, /*nb_inplace_power*/
-       (binaryfunc)instance_ilshift, /*nb_inplace_lshift*/
-       (binaryfunc)instance_irshift, /*nb_inplace_rshift*/
-       (binaryfunc)instance_iand, /*nb_inplace_and*/
-       (binaryfunc)instance_ixor, /*nb_inplace_xor*/
-       (binaryfunc)instance_ior, /*nb_inplace_or*/
-       (binaryfunc)instance_cmp, /*nb_cmp*/
+       (binaryfunc)instance_add,               /* nb_add */
+       (binaryfunc)instance_sub,               /* nb_subtract */
+       (binaryfunc)instance_mul,               /* nb_multiply */
+       (binaryfunc)instance_div,               /* nb_divide */
+       (binaryfunc)instance_mod,               /* nb_remainder */
+       (binaryfunc)instance_divmod,            /* nb_divmod */
+       (ternaryfunc)instance_pow,              /* nb_power */
+       (unaryfunc)instance_neg,                /* nb_negative */
+       (unaryfunc)instance_pos,                /* nb_positive */
+       (unaryfunc)instance_abs,                /* nb_absolute */
+       (inquiry)instance_nonzero,              /* nb_nonzero */
+       (unaryfunc)instance_invert,             /* nb_invert */
+       (binaryfunc)instance_lshift,            /* nb_lshift */
+       (binaryfunc)instance_rshift,            /* nb_rshift */
+       (binaryfunc)instance_and,               /* nb_and */
+       (binaryfunc)instance_xor,               /* nb_xor */
+       (binaryfunc)instance_or,                /* nb_or */
+       (coercion)instance_coerce,              /* nb_coerce */
+       (unaryfunc)instance_int,                /* nb_int */
+       (unaryfunc)instance_long,               /* nb_long */
+       (unaryfunc)instance_float,              /* nb_float */
+       (unaryfunc)instance_oct,                /* nb_oct */
+       (unaryfunc)instance_hex,                /* nb_hex */
+       (binaryfunc)instance_iadd,              /* nb_inplace_add */
+       (binaryfunc)instance_isub,              /* nb_inplace_subtract */
+       (binaryfunc)instance_imul,              /* nb_inplace_multiply */
+       (binaryfunc)instance_idiv,              /* nb_inplace_divide */
+       (binaryfunc)instance_imod,              /* nb_inplace_remainder */
+       (ternaryfunc)instance_ipow,             /* nb_inplace_power */
+       (binaryfunc)instance_ilshift,           /* nb_inplace_lshift */
+       (binaryfunc)instance_irshift,           /* nb_inplace_rshift */
+       (binaryfunc)instance_iand,              /* nb_inplace_and */
+       (binaryfunc)instance_ixor,              /* nb_inplace_xor */
+       (binaryfunc)instance_ior,               /* nb_inplace_or */
 };
 
 PyTypeObject PyInstance_Type = {
@@ -1595,24 +1725,26 @@ PyTypeObject PyInstance_Type = {
        "instance",
        sizeof(PyInstanceObject) + PyGC_HEAD_SIZE,
        0,
-       (destructor)instance_dealloc, /*tp_dealloc*/
-       0,                      /*tp_print*/
-       0,                      /*tp_getattr*/
-       0,                      /*tp_setattr*/
-       instance_compare,       /*tp_compare*/
-       (reprfunc)instance_repr, /*tp_repr*/
-       &instance_as_number,    /*tp_as_number*/
-       &instance_as_sequence,  /*tp_as_sequence*/
-       &instance_as_mapping,   /*tp_as_mapping*/
-       (hashfunc)instance_hash, /*tp_hash*/
-       0,                      /*tp_call*/
-       0,                      /*tp_str*/
-       (getattrofunc)instance_getattr, /*tp_getattro*/
-       (setattrofunc)instance_setattr, /*tp_setattro*/
-       0, /* tp_as_buffer */
-       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_GC | Py_TPFLAGS_NEWSTYLENUMBER, /*tp_flags*/
-       0,              /* tp_doc */
+       (destructor)instance_dealloc,           /* tp_dealloc */
+       0,                                      /* tp_print */
+       0,                                      /* tp_getattr */
+       0,                                      /* tp_setattr */
+       instance_compare,                       /* tp_compare */
+       (reprfunc)instance_repr,                /* tp_repr */
+       &instance_as_number,                    /* tp_as_number */
+       &instance_as_sequence,                  /* tp_as_sequence */
+       &instance_as_mapping,                   /* tp_as_mapping */
+       (hashfunc)instance_hash,                /* tp_hash */
+       0,                                      /* tp_call */
+       0,                                      /* tp_str */
+       (getattrofunc)instance_getattr,         /* tp_getattro */
+       (setattrofunc)instance_setattr,         /* tp_setattro */
+       0,                                      /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_GC | Py_TPFLAGS_CHECKTYPES,/*tp_flags*/
+       0,                                      /* tp_doc */
        (traverseproc)instance_traverse,        /* tp_traverse */
+       0,                                      /* tp_clear */
+       instance_richcompare,                   /* tp_richcompare */
 };
 
 
@@ -1854,23 +1986,23 @@ PyTypeObject PyMethod_Type = {
        "instance method",
        sizeof(PyMethodObject) + PyGC_HEAD_SIZE,
        0,
-       (destructor)instancemethod_dealloc, /*tp_dealloc*/
-       0,                      /*tp_print*/
-       0,                      /*tp_getattr*/
-       0,                      /*tp_setattr*/
-       (cmpfunc)instancemethod_compare, /*tp_compare*/
-       (reprfunc)instancemethod_repr, /*tp_repr*/
-       0,                      /*tp_as_number*/
-       0,                      /*tp_as_sequence*/
-       0,                      /*tp_as_mapping*/
-       (hashfunc)instancemethod_hash, /*tp_hash*/
-       0,                      /*tp_call*/
-       0,                      /*tp_str*/
-       (getattrofunc)instancemethod_getattro, /*tp_getattro*/
-       (setattrofunc)instancemethod_setattro, /*tp_setattro*/
-       0,                      /* tp_as_buffer */
-       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_GC, /*tp_flags*/
-       0,                      /* tp_doc */
+       (destructor)instancemethod_dealloc,     /* tp_dealloc */
+       0,                                      /* tp_print */
+       0,                                      /* tp_getattr */
+       0,                                      /* tp_setattr */
+       (cmpfunc)instancemethod_compare,        /* tp_compare */
+       (reprfunc)instancemethod_repr,          /* tp_repr */
+       0,                                      /* tp_as_number */
+       0,                                      /* tp_as_sequence */
+       0,                                      /* tp_as_mapping */
+       (hashfunc)instancemethod_hash,          /* tp_hash */
+       0,                                      /* tp_call */
+       0,                                      /* tp_str */
+       (getattrofunc)instancemethod_getattro,  /* tp_getattro */
+       (setattrofunc)instancemethod_setattro,  /* tp_setattro */
+       0,                                      /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_GC,     /* tp_flags */
+       0,                                      /* tp_doc */
        (traverseproc)instancemethod_traverse,  /* tp_traverse */
 };