]> granicus.if.org Git - python/commitdiff
Fix some set algebra methods of WeakSet objects.
authorAntoine Pitrou <solipsis@pitrou.net>
Sun, 4 Mar 2012 19:47:05 +0000 (20:47 +0100)
committerAntoine Pitrou <solipsis@pitrou.net>
Sun, 4 Mar 2012 19:47:05 +0000 (20:47 +0100)
Lib/_weakrefset.py
Lib/test/test_weakset.py

index b8d804301acb8be304352f9266b81a183f932068..ff613e609c66c58390c6aa1449539e8b9c836a62 100644 (file)
@@ -123,26 +123,14 @@ class WeakSet(object):
         self.update(other)
         return self
 
-    # Helper functions for simple delegating methods.
-    def _apply(self, other, method):
-        if not isinstance(other, self.__class__):
-            other = self.__class__(other)
-        newdata = method(other.data)
-        newset = self.__class__()
-        newset.data = newdata
-        return newset
-
     def difference(self, other):
-        return self._apply(other, self.data.difference)
+        newset = self.copy()
+        newset.difference_update(other)
+        return newset
     __sub__ = difference
 
     def difference_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        if self is other:
-            self.data.clear()
-        else:
-            self.data.difference_update(ref(item) for item in other)
+        self.__isub__(other)
     def __isub__(self, other):
         if self._pending_removals:
             self._commit_removals()
@@ -153,13 +141,11 @@ class WeakSet(object):
         return self
 
     def intersection(self, other):
-        return self._apply(other, self.data.intersection)
+        return self.__class__(item for item in other if item in self)
     __and__ = intersection
 
     def intersection_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        self.data.intersection_update(ref(item) for item in other)
+        self.__iand__(other)
     def __iand__(self, other):
         if self._pending_removals:
             self._commit_removals()
@@ -186,27 +172,24 @@ class WeakSet(object):
         return self.data == set(ref(item) for item in other)
 
     def symmetric_difference(self, other):
-        return self._apply(other, self.data.symmetric_difference)
+        newset = self.copy()
+        newset.symmetric_difference_update(other)
+        return newset
     __xor__ = symmetric_difference
 
     def symmetric_difference_update(self, other):
-        if self._pending_removals:
-            self._commit_removals()
-        if self is other:
-            self.data.clear()
-        else:
-            self.data.symmetric_difference_update(ref(item) for item in other)
+        self.__ixor__(other)
     def __ixor__(self, other):
         if self._pending_removals:
             self._commit_removals()
         if self is other:
             self.data.clear()
         else:
-            self.data.symmetric_difference_update(ref(item) for item in other)
+            self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
         return self
 
     def union(self, other):
-        return self._apply(other, self.data.union)
+        return self.__class__(e for s in (self, other) for e in s)
     __or__ = union
 
     def isdisjoint(self, other):
index f981bddb19337bfc15100be1b93ea57d8d10c493..1f82a7dda4dde81601dd086f21b85671a4a79fe7 100644 (file)
@@ -83,6 +83,11 @@ class TestWeakSet(unittest.TestCase):
             x = WeakSet(self.items + self.items2)
             c = C(self.items2)
             self.assertEqual(self.s.union(c), x)
+            del c
+        self.assertEqual(len(u), len(self.items) + len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(u), len(self.items) + len(self.items2))
 
     def test_or(self):
         i = self.s.union(self.items2)
@@ -90,14 +95,19 @@ class TestWeakSet(unittest.TestCase):
         self.assertEqual(self.s | frozenset(self.items2), i)
 
     def test_intersection(self):
-        i = self.s.intersection(self.items2)
+        s = WeakSet(self.letters)
+        i = s.intersection(self.items2)
         for c in self.letters:
-            self.assertEqual(c in i, c in self.d and c in self.items2)
-        self.assertEqual(self.s, WeakSet(self.items))
+            self.assertEqual(c in i, c in self.items2 and c in self.letters)
+        self.assertEqual(s, WeakSet(self.letters))
         self.assertEqual(type(i), WeakSet)
         for C in set, frozenset, dict.fromkeys, list, tuple:
             x = WeakSet([])
-            self.assertEqual(self.s.intersection(C(self.items2)), x)
+            self.assertEqual(i.intersection(C(self.items)), x)
+        self.assertEqual(len(i), len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(i), len(self.items2))
 
     def test_isdisjoint(self):
         self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
@@ -128,6 +138,10 @@ class TestWeakSet(unittest.TestCase):
         self.assertEqual(self.s, WeakSet(self.items))
         self.assertEqual(type(i), WeakSet)
         self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
+        self.assertEqual(len(i), len(self.items) + len(self.items2))
+        self.items2.pop()
+        gc.collect()
+        self.assertEqual(len(i), len(self.items) + len(self.items2))
 
     def test_xor(self):
         i = self.s.symmetric_difference(self.items2)