]> granicus.if.org Git - python/commitdiff
Fix obscure set crashers (#8420). Backport of d56b3cafb1e6, reviewed by Raymond.
authorÉric Araujo <merwok@netwok.org>
Wed, 23 Mar 2011 01:08:07 +0000 (02:08 +0100)
committerÉric Araujo <merwok@netwok.org>
Wed, 23 Mar 2011 01:08:07 +0000 (02:08 +0100)
Lib/test/test_set.py
Objects/setobject.c

index fdbfe19c91fc0de925d9388c34e1ef14211c5ef7..99d5c70e0a362f3860111abf42823a75a7424843 100644 (file)
@@ -1660,6 +1660,39 @@ class TestVariousIteratorArgs(unittest.TestCase):
                 self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                 self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
 
+class bad_eq:
+    def __eq__(self, other):
+        if be_bad:
+            set2.clear()
+            raise ZeroDivisionError
+        return self is other
+    def __hash__(self):
+        return 0
+
+class bad_dict_clear:
+    def __eq__(self, other):
+        if be_bad:
+            dict2.clear()
+        return self is other
+    def __hash__(self):
+        return 0
+
+class TestWeirdBugs(unittest.TestCase):
+    def test_8420_set_merge(self):
+        # This used to segfault
+        global be_bad, set2, dict2
+        be_bad = False
+        set1 = {bad_eq()}
+        set2 = {bad_eq() for i in range(75)}
+        be_bad = True
+        self.assertRaises(ZeroDivisionError, set1.update, set2)
+
+        be_bad = False
+        set1 = {bad_dict_clear()}
+        dict2 = {bad_dict_clear(): None}
+        be_bad = True
+        set1.symmetric_difference_update(dict2)
+
 # Application tests (based on David Eppstein's graph recipes ====================================
 
 def powerset(U):
@@ -1804,6 +1837,7 @@ def test_main(verbose=None):
         TestIdentities,
         TestVariousIteratorArgs,
         TestGraphs,
+        TestWeirdBugs,
         )
 
     support.run_unittest(*test_classes)
index 30afc7c1eefc16188f70f8b7c28bcea58613f94a..7aa1a7faee3a4d91ad77e5161c6680f87d422a65 100644 (file)
@@ -364,12 +364,14 @@ static int
 set_add_entry(register PySetObject *so, setentry *entry)
 {
     register Py_ssize_t n_used;
+    PyObject *key = entry->key;
+    long hash = entry->hash;
 
     assert(so->fill <= so->mask);  /* at least one empty slot */
     n_used = so->used;
-    Py_INCREF(entry->key);
-    if (set_insert_key(so, entry->key, (long) entry->hash) == -1) {
-        Py_DECREF(entry->key);
+    Py_INCREF(key);
+    if (set_insert_key(so, key, hash) == -1) {
+        Py_DECREF(key);
         return -1;
     }
     if (!(so->used > n_used && so->fill*3 >= (so->mask+1)*2))
@@ -637,6 +639,8 @@ static int
 set_merge(PySetObject *so, PyObject *otherset)
 {
     PySetObject *other;
+    PyObject *key;
+    long hash;
     register Py_ssize_t i;
     register setentry *entry;
 
@@ -657,11 +661,13 @@ set_merge(PySetObject *so, PyObject *otherset)
     }
     for (i = 0; i <= other->mask; i++) {
         entry = &other->table[i];
-        if (entry->key != NULL &&
-            entry->key != dummy) {
-            Py_INCREF(entry->key);
-            if (set_insert_key(so, entry->key, (long) entry->hash) == -1) {
-                Py_DECREF(entry->key);
+        key = entry->key;
+        hash = entry->hash;
+        if (key != NULL &&
+            key != dummy) {
+            Py_INCREF(key);
+            if (set_insert_key(so, key, hash) == -1) {
+                Py_DECREF(key);
                 return -1;
             }
         }
@@ -1642,15 +1648,22 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
         while (_PyDict_Next(other, &pos, &key, &value, &hash)) {
             setentry an_entry;
 
+            Py_INCREF(key);
             an_entry.hash = hash;
             an_entry.key = key;
+
             rv = set_discard_entry(so, &an_entry);
-            if (rv == -1)
+            if (rv == -1) {
+                Py_DECREF(key);
                 return NULL;
+            }
             if (rv == DISCARD_NOTFOUND) {
-                if (set_add_entry(so, &an_entry) == -1)
+                if (set_add_entry(so, &an_entry) == -1) {
+                    Py_DECREF(key);
                     return NULL;
+                }
             }
+            Py_DECREF(key);
         }
         Py_RETURN_NONE;
     }