]> granicus.if.org Git - python/commitdiff
Patch #1549 by Thomas Herve.
authorGuido van Rossum <guido@python.org>
Wed, 19 Dec 2007 22:51:13 +0000 (22:51 +0000)
committerGuido van Rossum <guido@python.org>
Wed, 19 Dec 2007 22:51:13 +0000 (22:51 +0000)
This changes the rules for when __hash__ is inherited slightly,
by allowing it to be inherited when one or more of __lt__, __le__,
__gt__, __ge__ are overridden, as long as __eq__ and __ne__ aren't.

Lib/test/test_richcmp.py
Objects/typeobject.c

index f412a89055cdb117b19f2215632e9d2c1bb7c3d3..db6d31ff9527bec6f1cbd0f1dd605e9d392ba28a 100644 (file)
@@ -85,6 +85,35 @@ class Vector:
             raise ValueError, "Cannot compare vectors of different length"
         return other
 
+
+class SimpleOrder(object):
+    """
+    A simple class that defines order but not full comparison.
+    """
+
+    def __init__(self, value):
+        self.value = value
+
+    def __lt__(self, other):
+        if not isinstance(other, SimpleOrder):
+            return True
+        return self.value < other.value
+
+    def __gt__(self, other):
+        if not isinstance(other, SimpleOrder):
+            return False
+        return self.value > other.value
+
+
+class DumbEqualityWithoutHash(object):
+    """
+    A class that define __eq__, but no __hash__: it shouldn't be hashable.
+    """
+
+    def __eq__(self, other):
+        return False
+
+
 opmap = {
     "lt": (lambda a,b: a< b, operator.lt, operator.__lt__),
     "le": (lambda a,b: a<=b, operator.le, operator.__le__),
@@ -330,8 +359,39 @@ class ListTest(unittest.TestCase):
         for op in opmap["lt"]:
             self.assertIs(op(x, y), True)
 
+
+class HashableTest(unittest.TestCase):
+    """
+    Test hashability of classes with rich operators defined.
+    """
+
+    def test_simpleOrderHashable(self):
+        """
+        A class that only defines __gt__ and/or __lt__ should be hashable.
+        """
+        a = SimpleOrder(1)
+        b = SimpleOrder(2)
+        self.assert_(a < b)
+        self.assert_(b > a)
+        self.assert_(a.__hash__ is not None)
+
+    def test_notHashableException(self):
+        """
+        If a class is not hashable, it should raise a TypeError with an
+        understandable message.
+        """
+        a = DumbEqualityWithoutHash()
+        try:
+            hash(a)
+        except TypeError, e:
+            self.assertEquals(str(e),
+                              "unhashable type: 'DumbEqualityWithoutHash'")
+        else:
+            raise test_support.TestFailed("Should not be here")
+
+
 def test_main():
-    test_support.run_unittest(VectorTest, NumberTest, MiscTest, DictTest, ListTest)
+    test_support.run_unittest(VectorTest, NumberTest, MiscTest, DictTest, ListTest, HashableTest)
 
 if __name__ == "__main__":
     test_main()
index e41813706c1308847218d12364a9ce4f11f837a0..74a00b1ba5c62068b71b20f4d4e2775bf6074fa3 100644 (file)
@@ -3230,28 +3230,22 @@ inherit_special(PyTypeObject *type, PyTypeObject *base)
                type->tp_flags |= Py_TPFLAGS_DICT_SUBCLASS;
 }
 
-/* Map rich comparison operators to their __xx__ namesakes */
-static char *name_op[] = {
-    "__lt__",
-    "__le__",
-    "__eq__",
-    "__ne__",
-    "__gt__",
-    "__ge__",
-    "__cmp__",
-       /* These are only for overrides_hash(): */
-    "__hash__",
+static char *hash_name_op[] = {
+       "__eq__",
+       "__cmp__",
+       "__hash__",
+       NULL
 };
 
 static int
 overrides_hash(PyTypeObject *type)
 {
-       int i;
+       char **p;
        PyObject *dict = type->tp_dict;
 
        assert(dict != NULL);
-       for (i = 0; i < 8; i++) {
-               if (PyDict_GetItemString(dict, name_op[i]) != NULL)
+       for (p = hash_name_op; *p; p++) {
+               if (PyDict_GetItemString(dict, *p) != NULL)
                        return 1;
        }
        return 0;
@@ -4846,7 +4840,7 @@ slot_tp_hash(PyObject *self)
 
        func = lookup_method(self, "__hash__", &hash_str);
 
-       if (func != NULL) {
+       if (func != NULL && func != Py_None) {
                PyObject *res = PyEval_CallObject(func, NULL);
                Py_DECREF(func);
                if (res == NULL)
@@ -4971,6 +4965,15 @@ slot_tp_setattro(PyObject *self, PyObject *name, PyObject *value)
        return 0;
 }
 
+static char *name_op[] = {
+    "__lt__",
+    "__le__",
+    "__eq__",
+    "__ne__",
+    "__gt__",
+    "__ge__",
+};
+
 static PyObject *
 half_richcompare(PyObject *self, PyObject *other, int op)
 {