]> granicus.if.org Git - python/commitdiff
Keir Mierle's set operations for dict views (keys/items only of course).
authorGuido van Rossum <guido@python.org>
Fri, 24 Aug 2007 23:41:22 +0000 (23:41 +0000)
committerGuido van Rossum <guido@python.org>
Fri, 24 Aug 2007 23:41:22 +0000 (23:41 +0000)
Lib/test/test_dict.py
Objects/dictobject.c

index d81ecf59c49977e947962ca1379b497b41d1cab0..60f3f402a95fcbb9456a152c16e501253e5521e8 100644 (file)
@@ -458,6 +458,23 @@ class DictTest(unittest.TestCase):
         self.assertRaises(RuntimeError, lambda: d2.items() < d3.items())
         self.assertRaises(RuntimeError, lambda: d3.items() > d2.items())
 
+    def test_dictview_set_operations(self):
+        k1 = {1:1, 2:2}.keys()
+        k2 = {1:1, 2:2, 3:3}.keys()
+        k3 = {4:4}.keys()
+
+        self.assertEquals(k1 - k2, set())
+        self.assertEquals(k1 - k3, {1,2})
+        self.assertEquals(k2 - k1, {3})
+        self.assertEquals(k3 - k1, {4})
+        self.assertEquals(k1 & k2, {1,2})
+        self.assertEquals(k1 & k3, set())
+        self.assertEquals(k1 | k2, {1,2,3})
+        self.assertEquals(k1 ^ k2, {3})
+        self.assertEquals(k1 ^ k3, {1,2,4})
+
+        # XXX similar tests for .items()
+
     def test_missing(self):
         # Make sure dict doesn't have a __missing__ method
         self.assertEqual(hasattr(dict, "__missing__"), False)
index 9ef1fcc340fe16f86498a74a1e5909e18b0669b2..539d73418a654750d9526dc034d863b057ef1f58 100644 (file)
@@ -2489,6 +2489,98 @@ static PySequenceMethods dictkeys_as_sequence = {
        (objobjproc)dictkeys_contains,  /* sq_contains */
 };
 
+static PyObject*
+dictviews_sub(PyObject* self, PyObject *other)
+{
+       PyObject *result = PySet_New(self);
+       PyObject *tmp;
+       if (result == NULL)
+               return NULL;
+
+       tmp = PyObject_CallMethod(result, "difference_update", "O", other);
+       if (tmp == NULL) {
+               Py_DECREF(result);
+               return NULL;
+       }
+
+       Py_DECREF(tmp);
+       return result;
+}
+
+static PyObject*
+dictviews_and(PyObject* self, PyObject *other)
+{
+       PyObject *result = PySet_New(self);
+       PyObject *tmp;
+       if (result == NULL)
+               return NULL;
+
+       tmp = PyObject_CallMethod(result, "intersection_update", "O", other);
+       if (tmp == NULL) {
+               Py_DECREF(result);
+               return NULL;
+       }
+
+       Py_DECREF(tmp);
+       return result;
+}
+
+static PyObject*
+dictviews_or(PyObject* self, PyObject *other)
+{
+       PyObject *result = PySet_New(self);
+       PyObject *tmp;
+       if (result == NULL)
+               return NULL;
+
+       tmp = PyObject_CallMethod(result, "update", "O", other);
+       if (tmp == NULL) {
+               Py_DECREF(result);
+               return NULL;
+       }
+
+       Py_DECREF(tmp);
+       return result;
+}
+
+static PyObject*
+dictviews_xor(PyObject* self, PyObject *other)
+{
+       PyObject *result = PySet_New(self);
+       PyObject *tmp;
+       if (result == NULL)
+               return NULL;
+
+       tmp = PyObject_CallMethod(result, "symmetric_difference_update", "O",
+                                 other);
+       if (tmp == NULL) {
+               Py_DECREF(result);
+               return NULL;
+       }
+
+       Py_DECREF(tmp);
+       return result;
+}
+
+static PyNumberMethods dictviews_as_number = {
+       0,                              /*nb_add*/
+       (binaryfunc)dictviews_sub,      /*nb_subtract*/
+       0,                              /*nb_multiply*/
+       0,                              /*nb_remainder*/
+       0,                              /*nb_divmod*/
+       0,                              /*nb_power*/
+       0,                              /*nb_negative*/
+       0,                              /*nb_positive*/
+       0,                              /*nb_absolute*/
+       0,                              /*nb_bool*/
+       0,                              /*nb_invert*/
+       0,                              /*nb_lshift*/
+       0,                              /*nb_rshift*/
+       (binaryfunc)dictviews_and,      /*nb_and*/
+       (binaryfunc)dictviews_xor,      /*nb_xor*/
+       (binaryfunc)dictviews_or,       /*nb_or*/
+};
+
 static PyMethodDef dictkeys_methods[] = {
        {NULL,          NULL}           /* sentinel */
 };
@@ -2505,7 +2597,7 @@ PyTypeObject PyDictKeys_Type = {
        0,                                      /* tp_setattr */
        0,                                      /* tp_compare */
        0,                                      /* tp_repr */
-       0,                                      /* tp_as_number */
+       &dictviews_as_number,                   /* tp_as_number */
        &dictkeys_as_sequence,                  /* tp_as_sequence */
        0,                                      /* tp_as_mapping */
        0,                                      /* tp_hash */
@@ -2589,7 +2681,7 @@ PyTypeObject PyDictItems_Type = {
        0,                                      /* tp_setattr */
        0,                                      /* tp_compare */
        0,                                      /* tp_repr */
-       0,                                      /* tp_as_number */
+       &dictviews_as_number,                   /* tp_as_number */
        &dictitems_as_sequence,                 /* tp_as_sequence */
        0,                                      /* tp_as_mapping */
        0,                                      /* tp_hash */