]> granicus.if.org Git - python/commitdiff
bpo-27575: port set intersection logic into dictview intersection (GH-7696)
authorForest Gregg <fgregg@users.noreply.github.com>
Mon, 26 Aug 2019 07:17:43 +0000 (02:17 -0500)
committerRaymond Hettinger <rhettinger@users.noreply.github.com>
Mon, 26 Aug 2019 07:17:43 +0000 (00:17 -0700)
Lib/test/test_dictviews.py
Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst [new file with mode: 0644]
Objects/dictobject.c

index 2763cbfb4bbc64f12d19cc28092901c08dea4b95..b15cfebc98912d52c9b3760c251b9d9c2abb2859 100644 (file)
@@ -92,6 +92,12 @@ class DictSetTest(unittest.TestCase):
         d1 = {'a': 1, 'b': 2}
         d2 = {'b': 3, 'c': 2}
         d3 = {'d': 4, 'e': 5}
+        d4 = {'d': 4}
+
+        class CustomSet(set):
+            def intersection(self, other):
+                return CustomSet(super().intersection(other))
+
         self.assertEqual(d1.keys() & d1.keys(), {'a', 'b'})
         self.assertEqual(d1.keys() & d2.keys(), {'b'})
         self.assertEqual(d1.keys() & d3.keys(), set())
@@ -99,6 +105,14 @@ class DictSetTest(unittest.TestCase):
         self.assertEqual(d1.keys() & set(d2.keys()), {'b'})
         self.assertEqual(d1.keys() & set(d3.keys()), set())
         self.assertEqual(d1.keys() & tuple(d1.keys()), {'a', 'b'})
+        self.assertEqual(d3.keys() & d4.keys(), {'d'})
+        self.assertEqual(d4.keys() & d3.keys(), {'d'})
+        self.assertEqual(d4.keys() & set(d3.keys()), {'d'})
+        self.assertIsInstance(d4.keys() & frozenset(d3.keys()), set)
+        self.assertIsInstance(frozenset(d3.keys()) & d4.keys(), set)
+        self.assertIs(type(d4.keys() & CustomSet(d3.keys())), set)
+        self.assertIs(type(d1.keys() & []), set)
+        self.assertIs(type([] & d1.keys()), set)
 
         self.assertEqual(d1.keys() | d1.keys(), {'a', 'b'})
         self.assertEqual(d1.keys() | d2.keys(), {'a', 'b', 'c'})
diff --git a/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst b/Misc/NEWS.d/next/Core and Builtins/2018-06-14-13-55-45.bpo-27575.mMYgzv.rst
new file mode 100644 (file)
index 0000000..2c250dc
--- /dev/null
@@ -0,0 +1,2 @@
+Improve speed of dictview intersection by directly using set intersection
+logic. Patch by David Su.
index f168ad5d2f00de95aae41c4a4f0f0b29aa47521b..fec3a87e9f5de7c9022acb45b918ab82450aabbf 100644 (file)
@@ -4169,24 +4169,97 @@ dictviews_sub(PyObject* self, PyObject *other)
     return result;
 }
 
-PyObject*
+static int
+dictitems_contains(_PyDictViewObject *dv, PyObject *obj);
+
+PyObject *
 _PyDictView_Intersect(PyObject* self, PyObject *other)
 {
-    PyObject *result = PySet_New(self);
+    PyObject *result;
+    PyObject *it;
+    PyObject *key;
+    Py_ssize_t len_self;
+    int rv;
+    int (*dict_contains)(_PyDictViewObject *, PyObject *);
     PyObject *tmp;
-    _Py_IDENTIFIER(intersection_update);
 
+    /* Python interpreter swaps parameters when dict view
+       is on right side of & */
+    if (!PyDictViewSet_Check(self)) {
+        PyObject *tmp = other;
+        other = self;
+        self = tmp;
+    }
+
+    len_self = dictview_len((_PyDictViewObject *)self);
+
+    /* if other is a set and self is smaller than other,
+       reuse set intersection logic */
+    if (Py_TYPE(other) == &PySet_Type && len_self <= PyObject_Size(other)) {
+        _Py_IDENTIFIER(intersection);
+        return _PyObject_CallMethodIdObjArgs(other, &PyId_intersection, self, NULL);
+    }
+
+    /* if other is another dict view, and it is bigger than self,
+       swap them */
+    if (PyDictViewSet_Check(other)) {
+        Py_ssize_t len_other = dictview_len((_PyDictViewObject *)other);
+        if (len_other > len_self) {
+            PyObject *tmp = other;
+            other = self;
+            self = tmp;
+        }
+    }
+
+    /* at this point, two things should be true
+       1. self is a dictview
+       2. if other is a dictview then it is smaller than self */
+    result = PySet_New(NULL);
     if (result == NULL)
         return NULL;
 
+    it = PyObject_GetIter(other);
+
+    _Py_IDENTIFIER(intersection_update);
     tmp = _PyObject_CallMethodIdOneArg(result, &PyId_intersection_update, other);
     if (tmp == NULL) {
         Py_DECREF(result);
         return NULL;
     }
-
     Py_DECREF(tmp);
+
+    if (PyDictKeys_Check(self)) {
+        dict_contains = dictkeys_contains;
+    }
+    /* else PyDictItems_Check(self) */
+    else {
+        dict_contains = dictitems_contains;
+    }
+
+    while ((key = PyIter_Next(it)) != NULL) {
+        rv = dict_contains((_PyDictViewObject *)self, key);
+        if (rv < 0) {
+            goto error;
+        }
+        if (rv) {
+            if (PySet_Add(result, key)) {
+                goto error;
+            }
+        }
+        Py_DECREF(key);
+    }
+    Py_DECREF(it);
+    if (PyErr_Occurred()) {
+        Py_DECREF(result);
+        return NULL;
+    }
     return result;
+
+error:
+    Py_DECREF(it);
+    Py_DECREF(result);
+    Py_DECREF(key);
+    return NULL;
 }
 
 static PyObject*