]> granicus.if.org Git - python/commitdiff
Issue 8743: Improve interoperability between sets and the collections.Set abstract...
authorRaymond Hettinger <python@rcn.com>
Mon, 26 May 2014 05:13:41 +0000 (22:13 -0700)
committerRaymond Hettinger <python@rcn.com>
Mon, 26 May 2014 05:13:41 +0000 (22:13 -0700)
Lib/_abcoll.py
Lib/test/test_collections.py
Lib/test/test_set.py
Misc/NEWS
Objects/setobject.c

index 8b650a7763f423fb425fbc78a5c0577482ee6809..a943263320d2448257c7af83e6e81cdf4a479285 100644 (file)
@@ -165,12 +165,17 @@ class Set(Sized, Iterable, Container):
     def __gt__(self, other):
         if not isinstance(other, Set):
             return NotImplemented
-        return other.__lt__(self)
+        return len(self) > len(other) and self.__ge__(other)
 
     def __ge__(self, other):
         if not isinstance(other, Set):
             return NotImplemented
-        return other.__le__(self)
+        if len(self) < len(other):
+            return False
+        for elem in other:
+            if elem not in self:
+                return False
+        return True
 
     def __eq__(self, other):
         if not isinstance(other, Set):
@@ -194,6 +199,8 @@ class Set(Sized, Iterable, Container):
             return NotImplemented
         return self._from_iterable(value for value in other if value in self)
 
+    __rand__ = __and__
+
     def isdisjoint(self, other):
         'Return True if two sets have a null intersection.'
         for value in other:
@@ -207,6 +214,8 @@ class Set(Sized, Iterable, Container):
         chain = (e for s in (self, other) for e in s)
         return self._from_iterable(chain)
 
+    __ror__ = __or__
+
     def __sub__(self, other):
         if not isinstance(other, Set):
             if not isinstance(other, Iterable):
@@ -215,6 +224,14 @@ class Set(Sized, Iterable, Container):
         return self._from_iterable(value for value in self
                                    if value not in other)
 
+    def __rsub__(self, other):
+        if not isinstance(other, Set):
+            if not isinstance(other, Iterable):
+                return NotImplemented
+            other = self._from_iterable(other)
+        return self._from_iterable(value for value in other
+                                   if value not in self)
+
     def __xor__(self, other):
         if not isinstance(other, Set):
             if not isinstance(other, Iterable):
@@ -222,6 +239,8 @@ class Set(Sized, Iterable, Container):
             other = self._from_iterable(other)
         return (self - other) | (other - self)
 
+    __rxor__ = __xor__
+
     # Sets are not hashable by default, but subclasses can change this
     __hash__ = None
 
index 784a31880e3bf1709609abb91b6573d35256850c..de4ba86df844f57cf5b900b9695bfd7d99f0d8ed 100644 (file)
@@ -8,6 +8,7 @@ import pickle, cPickle, copy
 from random import randrange, shuffle
 import keyword
 import re
+import sets
 import sys
 from collections import Hashable, Iterable, Iterator
 from collections import Sized, Container, Callable
@@ -618,10 +619,173 @@ class TestCollectionABCs(ABCTestCase):
 
         cs = MyComparableSet()
         ncs = MyNonComparableSet()
-        self.assertFalse(ncs < cs)
-        self.assertFalse(ncs <= cs)
-        self.assertFalse(cs > ncs)
-        self.assertFalse(cs >= ncs)
+
+        # Run all the variants to make sure they don't mutually recurse
+        ncs < cs
+        ncs <= cs
+        ncs > cs
+        ncs >= cs
+        cs < ncs
+        cs <= ncs
+        cs > ncs
+        cs >= ncs
+
+    def assertSameSet(self, s1, s2):
+        # coerce both to a real set then check equality
+        self.assertEqual(set(s1), set(s2))
+
+    def test_Set_interoperability_with_real_sets(self):
+        # Issue: 8743
+        class ListSet(Set):
+            def __init__(self, elements=()):
+                self.data = []
+                for elem in elements:
+                    if elem not in self.data:
+                        self.data.append(elem)
+            def __contains__(self, elem):
+                return elem in self.data
+            def __iter__(self):
+                return iter(self.data)
+            def __len__(self):
+                return len(self.data)
+            def __repr__(self):
+                return 'Set({!r})'.format(self.data)
+
+        r1 = set('abc')
+        r2 = set('bcd')
+        r3 = set('abcde')
+        f1 = ListSet('abc')
+        f2 = ListSet('bcd')
+        f3 = ListSet('abcde')
+        l1 = list('abccba')
+        l2 = list('bcddcb')
+        l3 = list('abcdeedcba')
+        p1 = sets.Set('abc')
+        p2 = sets.Set('bcd')
+        p3 = sets.Set('abcde')
+
+        target = r1 & r2
+        self.assertSameSet(f1 & f2, target)
+        self.assertSameSet(f1 & r2, target)
+        self.assertSameSet(r2 & f1, target)
+        self.assertSameSet(f1 & p2, target)
+        self.assertSameSet(p2 & f1, target)
+        self.assertSameSet(f1 & l2, target)
+
+        target = r1 | r2
+        self.assertSameSet(f1 | f2, target)
+        self.assertSameSet(f1 | r2, target)
+        self.assertSameSet(r2 | f1, target)
+        self.assertSameSet(f1 | p2, target)
+        self.assertSameSet(p2 | f1, target)
+        self.assertSameSet(f1 | l2, target)
+
+        fwd_target = r1 - r2
+        rev_target = r2 - r1
+        self.assertSameSet(f1 - f2, fwd_target)
+        self.assertSameSet(f2 - f1, rev_target)
+        self.assertSameSet(f1 - r2, fwd_target)
+        self.assertSameSet(f2 - r1, rev_target)
+        self.assertSameSet(r1 - f2, fwd_target)
+        self.assertSameSet(r2 - f1, rev_target)
+        self.assertSameSet(f1 - p2, fwd_target)
+        self.assertSameSet(f2 - p1, rev_target)
+        self.assertSameSet(p1 - f2, fwd_target)
+        self.assertSameSet(p2 - f1, rev_target)
+        self.assertSameSet(f1 - l2, fwd_target)
+        self.assertSameSet(f2 - l1, rev_target)
+
+        target = r1 ^ r2
+        self.assertSameSet(f1 ^ f2, target)
+        self.assertSameSet(f1 ^ r2, target)
+        self.assertSameSet(r2 ^ f1, target)
+        self.assertSameSet(f1 ^ p2, target)
+        self.assertSameSet(p2 ^ f1, target)
+        self.assertSameSet(f1 ^ l2, target)
+
+        # proper subset
+        self.assertTrue(f1 < f3)
+        self.assertFalse(f1 < f1)
+        self.assertFalse(f1 < f2)
+        self.assertTrue(r1 < f3)
+        self.assertFalse(r1 < f1)
+        self.assertFalse(r1 < f2)
+        self.assertTrue(r1 < r3)
+        self.assertFalse(r1 < r1)
+        self.assertFalse(r1 < r2)
+        # python 2 only, cross-type compares will succeed
+        f1 < l3
+        f1 < l1
+        f1 < l2
+
+        # any subset
+        self.assertTrue(f1 <= f3)
+        self.assertTrue(f1 <= f1)
+        self.assertFalse(f1 <= f2)
+        self.assertTrue(r1 <= f3)
+        self.assertTrue(r1 <= f1)
+        self.assertFalse(r1 <= f2)
+        self.assertTrue(r1 <= r3)
+        self.assertTrue(r1 <= r1)
+        self.assertFalse(r1 <= r2)
+        # python 2 only, cross-type compares will succeed
+        f1 <= l3
+        f1 <= l1
+        f1 <= l2
+
+        # proper superset
+        self.assertTrue(f3 > f1)
+        self.assertFalse(f1 > f1)
+        self.assertFalse(f2 > f1)
+        self.assertTrue(r3 > r1)
+        self.assertFalse(f1 > r1)
+        self.assertFalse(f2 > r1)
+        self.assertTrue(r3 > r1)
+        self.assertFalse(r1 > r1)
+        self.assertFalse(r2 > r1)
+        # python 2 only, cross-type compares will succeed
+        f1 > l3
+        f1 > l1
+        f1 > l2
+
+        # any superset
+        self.assertTrue(f3 >= f1)
+        self.assertTrue(f1 >= f1)
+        self.assertFalse(f2 >= f1)
+        self.assertTrue(r3 >= r1)
+        self.assertTrue(f1 >= r1)
+        self.assertFalse(f2 >= r1)
+        self.assertTrue(r3 >= r1)
+        self.assertTrue(r1 >= r1)
+        self.assertFalse(r2 >= r1)
+        # python 2 only, cross-type compares will succeed
+        f1 >= l3
+        f1 >=l1
+        f1 >= l2
+
+        # equality
+        self.assertTrue(f1 == f1)
+        self.assertTrue(r1 == f1)
+        self.assertTrue(f1 == r1)
+        self.assertFalse(f1 == f3)
+        self.assertFalse(r1 == f3)
+        self.assertFalse(f1 == r3)
+        # python 2 only, cross-type compares will succeed
+        f1 == l3
+        f1 == l1
+        f1 == l2
+
+        # inequality
+        self.assertFalse(f1 != f1)
+        self.assertFalse(r1 != f1)
+        self.assertFalse(f1 != r1)
+        self.assertTrue(f1 != f3)
+        self.assertTrue(r1 != f3)
+        self.assertTrue(f1 != r3)
+        # python 2 only, cross-type compares will succeed
+        f1 != l3
+        f1 != l1
+        f1 != l2
 
     def test_Mapping(self):
         for sample in [dict]:
index 82a08fec0b35ceb84b5b35192d37b9215b6e4ad0..610be7c4369ed850a18d0614de83de3c96213dbb 100644 (file)
@@ -1017,8 +1017,6 @@ class TestBinaryOps(unittest.TestCase):
         # without calling __cmp__.
         self.assertEqual(cmp(a, a), 0)
 
-        self.assertRaises(TypeError, cmp, a, 12)
-        self.assertRaises(TypeError, cmp, "abc", a)
 
 #==============================================================================
 
@@ -1269,17 +1267,6 @@ class TestOnlySetsInBinaryOps(unittest.TestCase):
         self.assertEqual(self.other != self.set, True)
         self.assertEqual(self.set != self.other, True)
 
-    def test_ge_gt_le_lt(self):
-        self.assertRaises(TypeError, lambda: self.set < self.other)
-        self.assertRaises(TypeError, lambda: self.set <= self.other)
-        self.assertRaises(TypeError, lambda: self.set > self.other)
-        self.assertRaises(TypeError, lambda: self.set >= self.other)
-
-        self.assertRaises(TypeError, lambda: self.other < self.set)
-        self.assertRaises(TypeError, lambda: self.other <= self.set)
-        self.assertRaises(TypeError, lambda: self.other > self.set)
-        self.assertRaises(TypeError, lambda: self.other >= self.set)
-
     def test_update_operator(self):
         try:
             self.set |= self.other
@@ -1392,18 +1379,6 @@ class TestOnlySetsDict(TestOnlySetsInBinaryOps):
 
 #------------------------------------------------------------------------------
 
-class TestOnlySetsOperator(TestOnlySetsInBinaryOps):
-    def setUp(self):
-        self.set   = set((1, 2, 3))
-        self.other = operator.add
-        self.otherIsIterable = False
-
-    def test_ge_gt_le_lt(self):
-        with test_support.check_py3k_warnings():
-            super(TestOnlySetsOperator, self).test_ge_gt_le_lt()
-
-#------------------------------------------------------------------------------
-
 class TestOnlySetsTuple(TestOnlySetsInBinaryOps):
     def setUp(self):
         self.set   = set((1, 2, 3))
@@ -1801,7 +1776,6 @@ def test_main(verbose=None):
         TestSubsetNonOverlap,
         TestOnlySetsNumeric,
         TestOnlySetsDict,
-        TestOnlySetsOperator,
         TestOnlySetsTuple,
         TestOnlySetsString,
         TestOnlySetsGenerator,
index 04e9221d944058931dd4870ae17b5c00c3ab4b65..34c007489dcd84e841aa1873eb92920338348278 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -18,6 +18,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #8743: Fix interoperability between set objects and the
+  collections.Set() abstract base class.
+
 Tests
 -----
 
index db5cee284037a3c8753348e8cea276925ddef81a..b4b117802158ccbc5ad4de0f2ab69d4374a89308 100644 (file)
@@ -1796,12 +1796,8 @@ set_richcompare(PySetObject *v, PyObject *w, int op)
     PyObject *r1, *r2;
 
     if(!PyAnySet_Check(w)) {
-        if (op == Py_EQ)
-            Py_RETURN_FALSE;
-        if (op == Py_NE)
-            Py_RETURN_TRUE;
-        PyErr_SetString(PyExc_TypeError, "can only compare to a set");
-        return NULL;
+        Py_INCREF(Py_NotImplemented);
+        return Py_NotImplemented;
     }
     switch (op) {
     case Py_EQ: