]> granicus.if.org Git - python/commitdiff
Issue #1721812: Binary operations and copy operations on set/frozenset
authorRaymond Hettinger <python@rcn.com>
Sun, 16 Nov 2008 11:44:54 +0000 (11:44 +0000)
committerRaymond Hettinger <python@rcn.com>
Sun, 16 Nov 2008 11:44:54 +0000 (11:44 +0000)
subclasses need to return the base type, not the subclass itself.

Lib/test/test_set.py
Misc/NEWS
Objects/setobject.c

index 614c9c09b0e1614c942729b2ec65ca35891fd5c1..07319fc20d1ce1c80c23cd73ca4a2cf9adaf39d8 100644 (file)
@@ -71,7 +71,7 @@ class TestJointOps(unittest.TestCase):
         for c in self.letters:
             self.assertEqual(c in u, c in self.d or c in self.otherword)
         self.assertEqual(self.s, self.thetype(self.word))
-        self.assertEqual(type(u), self.thetype)
+        self.assertEqual(type(u), self.basetype)
         self.assertRaises(PassThru, self.s.union, check_pass_thru())
         self.assertRaises(TypeError, self.s.union, [[]])
         for C in set, frozenset, dict.fromkeys, str, list, tuple:
@@ -97,7 +97,7 @@ class TestJointOps(unittest.TestCase):
         for c in self.letters:
             self.assertEqual(c in i, c in self.d and c in self.otherword)
         self.assertEqual(self.s, self.thetype(self.word))
-        self.assertEqual(type(i), self.thetype)
+        self.assertEqual(type(i), self.basetype)
         self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
         for C in set, frozenset, dict.fromkeys, str, list, tuple:
             self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
@@ -142,7 +142,7 @@ class TestJointOps(unittest.TestCase):
         for c in self.letters:
             self.assertEqual(c in i, c in self.d and c not in self.otherword)
         self.assertEqual(self.s, self.thetype(self.word))
-        self.assertEqual(type(i), self.thetype)
+        self.assertEqual(type(i), self.basetype)
         self.assertRaises(PassThru, self.s.difference, check_pass_thru())
         self.assertRaises(TypeError, self.s.difference, [[]])
         for C in set, frozenset, dict.fromkeys, str, list, tuple:
@@ -169,7 +169,7 @@ class TestJointOps(unittest.TestCase):
         for c in self.letters:
             self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
         self.assertEqual(self.s, self.thetype(self.word))
-        self.assertEqual(type(i), self.thetype)
+        self.assertEqual(type(i), self.basetype)
         self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
         self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
         for C in set, frozenset, dict.fromkeys, str, list, tuple:
@@ -325,6 +325,7 @@ class TestJointOps(unittest.TestCase):
 
 class TestSet(TestJointOps):
     thetype = set
+    basetype = set
 
     def test_init(self):
         s = self.thetype()
@@ -357,6 +358,7 @@ class TestSet(TestJointOps):
         dup = self.s.copy()
         self.assertEqual(self.s, dup)
         self.assertNotEqual(id(self.s), id(dup))
+        self.assertEqual(type(dup), self.basetype)
 
     def test_add(self):
         self.s.add('Q')
@@ -595,6 +597,7 @@ class SetSubclass(set):
 
 class TestSetSubclass(TestSet):
     thetype = SetSubclass
+    basetype = set
 
 class SetSubclassWithKeywordArgs(set):
     def __init__(self, iterable=[], newarg=None):
@@ -608,6 +611,7 @@ class TestSetSubclassWithKeywordArgs(TestSet):
 
 class TestFrozenSet(TestJointOps):
     thetype = frozenset
+    basetype = frozenset
 
     def test_init(self):
         s = self.thetype(self.word)
@@ -673,6 +677,7 @@ class FrozenSetSubclass(frozenset):
 
 class TestFrozenSetSubclass(TestFrozenSet):
     thetype = FrozenSetSubclass
+    basetype = frozenset
 
     def test_constructor_identity(self):
         s = self.thetype(range(3))
index 035867c4291fac10a00463f3f07e88e6a0f2ca38..874f27bc041ca1d80e4d519382f5c0ec45566b01 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -13,6 +13,11 @@ What's New in Python 3.0 release candiate 3?
 Core and Builtins
 -----------------
 
+- Issue #1721812:  Binary set operations and copy() returned the input type
+  instead of the appropriate base type.  This was incorrect because set 
+  subclasses would be created without their __init__() method being called.
+  The corrected behavior brings sets into line with lists and dicts.
+
 - Issue #4296: Fix PyObject_RichCompareBool so that "x in [x]" evaluates to
   True, even when x doesn't compare equal to itself.  This was a regression
   from 2.6.
index d24e1af0d2bd757672b3fc6042b94851e295ff65..d08ff5f33775370729b0753f91de73c72323541a 100644 (file)
@@ -1017,6 +1017,18 @@ make_new_set(PyTypeObject *type, PyObject *iterable)
        return (PyObject *)so;
 }
 
+static PyObject *
+make_new_set_basetype(PyTypeObject *type, PyObject *iterable)
+{
+       if (type != &PySet_Type && type != &PyFrozenSet_Type) {
+               if (PyType_IsSubtype(type, &PySet_Type))
+                       type = &PySet_Type;
+               else
+                       type = &PyFrozenSet_Type;
+       }
+       return make_new_set(type, iterable);
+}
+
 /* The empty frozenset is a singleton */
 static PyObject *emptyfrozenset = NULL;
 
@@ -1129,7 +1141,7 @@ set_swap_bodies(PySetObject *a, PySetObject *b)
 static PyObject *
 set_copy(PySetObject *so)
 {
-       return make_new_set(Py_TYPE(so), (PyObject *)so);
+       return make_new_set_basetype(Py_TYPE(so), (PyObject *)so);
 }
 
 static PyObject *
@@ -1225,7 +1237,7 @@ set_intersection(PySetObject *so, PyObject *other)
        if ((PyObject *)so == other)
                return set_copy(so);
 
-       result = (PySetObject *)make_new_set(Py_TYPE(so), NULL);
+       result = (PySetObject *)make_new_set_basetype(Py_TYPE(so), NULL);
        if (result == NULL)
                return NULL;
 
@@ -1520,7 +1532,7 @@ set_difference(PySetObject *so, PyObject *other)
                return NULL;
        }
        
-       result = make_new_set(Py_TYPE(so), NULL);
+       result = make_new_set_basetype(Py_TYPE(so), NULL);
        if (result == NULL)
                return NULL;
 
@@ -1641,7 +1653,7 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other)
                Py_INCREF(other);
                otherset = (PySetObject *)other;
        } else {
-               otherset = (PySetObject *)make_new_set(Py_TYPE(so), other);
+               otherset = (PySetObject *)make_new_set_basetype(Py_TYPE(so), other);
                if (otherset == NULL)
                        return NULL;
        }
@@ -1672,7 +1684,7 @@ set_symmetric_difference(PySetObject *so, PyObject *other)
        PyObject *rv;
        PySetObject *otherset;
 
-       otherset = (PySetObject *)make_new_set(Py_TYPE(so), other);
+       otherset = (PySetObject *)make_new_set_basetype(Py_TYPE(so), other);
        if (otherset == NULL)
                return NULL;
        rv = set_symmetric_difference_update(otherset, (PyObject *)so);