]> granicus.if.org Git - python/commitdiff
Implement __contains__ for dict_keys and dict_items.
authorGuido van Rossum <guido@python.org>
Sat, 10 Feb 2007 18:55:06 +0000 (18:55 +0000)
committerGuido van Rossum <guido@python.org>
Sat, 10 Feb 2007 18:55:06 +0000 (18:55 +0000)
(Not for dict_values, where it can't be done faster than
the default implementation which just iterates the elements.)

Lib/test/test_dictviews.py
Objects/dictobject.c

index c0c1da1bf25b42f12bef6732abdb98b0c974089b..4c436f7eae2db143d478ebd7c44ed5d04a5d913f 100644 (file)
@@ -3,17 +3,39 @@ from test import test_support
 
 class DictSetTest(unittest.TestCase):
 
+    def test_constructors_not_callable(self):
+        kt = type({}.KEYS())
+        self.assertRaises(TypeError, kt, {})
+        self.assertRaises(TypeError, kt)
+        it = type({}.ITEMS())
+        self.assertRaises(TypeError, it, {})
+        self.assertRaises(TypeError, it)
+        vt = type({}.VALUES())
+        self.assertRaises(TypeError, vt, {})
+        self.assertRaises(TypeError, vt)
+
     def test_dict_keys(self):
         d = {1: 10, "a": "ABC"}
         keys = d.KEYS()
         self.assertEqual(set(keys), {1, "a"})
         self.assertEqual(len(keys), 2)
+        self.assert_(1 in keys)
+        self.assert_("a" in keys)
+        self.assert_(10 not in keys)
+        self.assert_("Z" not in keys)
 
     def test_dict_items(self):
         d = {1: 10, "a": "ABC"}
         items = d.ITEMS()
         self.assertEqual(set(items), {(1, 10), ("a", "ABC")})
         self.assertEqual(len(items), 2)
+        self.assert_((1, 10) in items)
+        self.assert_(("a", "ABC") in items)
+        self.assert_((1, 11) not in items)
+        self.assert_(1 not in items)
+        self.assert_(() not in items)
+        self.assert_((1,) not in items)
+        self.assert_((1, 2, 3) not in items)
 
     def test_dict_values(self):
         d = {1: 10, "a": "ABC"}
index e2e98db02f1e5dfad885684a341e25da06ce1284..ec14fcbc316f8d9c6cfb12ca6b9f1cc458e6f063 100644 (file)
@@ -2336,37 +2336,40 @@ PyTypeObject PyDictIterItem_Type = {
 };
 
 
+/***********************************************/
 /* View objects for keys(), items(), values(). */
+/***********************************************/
+
 /* While this is incomplete, we use KEYS(), ITEMS(), VALUES(). */
 
 /* The instance lay-out is the same for all three; but the type differs. */
 
 typedef struct {
        PyObject_HEAD
-       dictobject *ds_dict;
+       dictobject *dv_dict;
 } dictviewobject;
 
 
 static void
-dictview_dealloc(dictviewobject *ds)
+dictview_dealloc(dictviewobject *dv)
 {
-       Py_XDECREF(ds->ds_dict);
-       PyObject_Del(ds);
+       Py_XDECREF(dv->dv_dict);
+       PyObject_Del(dv);
 }
 
 static Py_ssize_t
-dictview_len(dictviewobject *ds)
+dictview_len(dictviewobject *dv)
 {
        Py_ssize_t len = 0;
-       if (ds->ds_dict != NULL)
-               len = ds->ds_dict->ma_used;
+       if (dv->dv_dict != NULL)
+               len = dv->dv_dict->ma_used;
        return len;
 }
 
 static PyObject *
 dictview_new(PyObject *dict, PyTypeObject *type)
 {
-       dictviewobject *ds;
+       dictviewobject *dv;
        if (dict == NULL) {
                PyErr_BadInternalCall();
                return NULL;
@@ -2378,23 +2381,31 @@ dictview_new(PyObject *dict, PyTypeObject *type)
                             type->tp_name, dict->ob_type->tp_name);
                return NULL;
        }
-       ds = PyObject_New(dictviewobject, type);
-       if (ds == NULL)
+       dv = PyObject_New(dictviewobject, type);
+       if (dv == NULL)
                return NULL;
        Py_INCREF(dict);
-       ds->ds_dict = (dictobject *)dict;
-       return (PyObject *)ds;
+       dv->dv_dict = (dictobject *)dict;
+       return (PyObject *)dv;
 }
 
-/* dict_keys */
+/*** dict_keys ***/
 
 static PyObject *
-dictkeys_iter(dictviewobject *ds)
+dictkeys_iter(dictviewobject *dv)
 {
-       if (ds->ds_dict == NULL) {
+       if (dv->dv_dict == NULL) {
                Py_RETURN_NONE;
        }
-       return dictiter_new(ds->ds_dict, &PyDictIterKey_Type);
+       return dictiter_new(dv->dv_dict, &PyDictIterKey_Type);
+}
+
+static int
+dictkeys_contains(dictviewobject *dv, PyObject *obj)
+{
+       if (dv->dv_dict == NULL)
+               return 0;
+       return PyDict_Contains((PyObject *)dv->dv_dict, obj);
 }
 
 static PySequenceMethods dictkeys_as_sequence = {
@@ -2405,7 +2416,7 @@ static PySequenceMethods dictkeys_as_sequence = {
        0,                              /* sq_slice */
        0,                              /* sq_ass_item */
        0,                              /* sq_ass_slice */
-       (objobjproc)0,                  /* sq_contains */
+       (objobjproc)dictkeys_contains,  /* sq_contains */
 };
 
 static PyMethodDef dictkeys_methods[] = {
@@ -2452,15 +2463,34 @@ dictkeys_new(PyObject *dict)
        return dictview_new(dict, &PyDictKeys_Type);
 }
 
-/* dict_items */
+/*** dict_items ***/
 
 static PyObject *
-dictitems_iter(dictviewobject *ds)
+dictitems_iter(dictviewobject *dv)
 {
-       if (ds->ds_dict == NULL) {
+       if (dv->dv_dict == NULL) {
                Py_RETURN_NONE;
        }
-       return dictiter_new(ds->ds_dict, &PyDictIterItem_Type);
+       return dictiter_new(dv->dv_dict, &PyDictIterItem_Type);
+}
+
+static int
+dictitems_contains(dictviewobject *dv, PyObject *obj)
+{
+       PyObject *key, *value, *found;
+       if (dv->dv_dict == NULL)
+               return 0;
+       if (!PyTuple_Check(obj) || PyTuple_GET_SIZE(obj) != 2)
+               return 0;
+       key = PyTuple_GET_ITEM(obj, 0);
+       value = PyTuple_GET_ITEM(obj, 1);
+       found = PyDict_GetItem((PyObject *)dv->dv_dict, key);
+       if (found == NULL) {
+               if (PyErr_Occurred())
+                       return -1;
+               return 0;
+       }
+       return PyObject_RichCompareBool(value, found, Py_EQ);
 }
 
 static PySequenceMethods dictitems_as_sequence = {
@@ -2471,7 +2501,7 @@ static PySequenceMethods dictitems_as_sequence = {
        0,                              /* sq_slice */
        0,                              /* sq_ass_item */
        0,                              /* sq_ass_slice */
-       (objobjproc)0,                  /* sq_contains */
+       (objobjproc)dictitems_contains, /* sq_contains */
 };
 
 static PyMethodDef dictitems_methods[] = {
@@ -2518,15 +2548,15 @@ dictitems_new(PyObject *dict)
        return dictview_new(dict, &PyDictItems_Type);
 }
 
-/* dict_values */
+/*** dict_values ***/
 
 static PyObject *
-dictvalues_iter(dictviewobject *ds)
+dictvalues_iter(dictviewobject *dv)
 {
-       if (ds->ds_dict == NULL) {
+       if (dv->dv_dict == NULL) {
                Py_RETURN_NONE;
        }
-       return dictiter_new(ds->ds_dict, &PyDictIterValue_Type);
+       return dictiter_new(dv->dv_dict, &PyDictIterValue_Type);
 }
 
 static PySequenceMethods dictvalues_as_sequence = {