]> granicus.if.org Git - python/commitdiff
Issue #16373: Prevent infinite recursion for ABC Set class operations.
authorAndrew Svetlov <andrew.svetlov@gmail.com>
Thu, 1 Nov 2012 11:28:54 +0000 (13:28 +0200)
committerAndrew Svetlov <andrew.svetlov@gmail.com>
Thu, 1 Nov 2012 11:28:54 +0000 (13:28 +0200)
Lib/_abcoll.py
Lib/test/test_collections.py

index 2417d187cd3a8cbca2a6d735ada741a8a0a54945..5ddcea3a3018ac0b89031335dda276df4682930d 100644 (file)
@@ -184,12 +184,12 @@ class Set(Sized, Iterable, Container):
     def __gt__(self, other):
         if not isinstance(other, Set):
             return NotImplemented
-        return other < self
+        return other.__lt__(self)
 
     def __ge__(self, other):
         if not isinstance(other, Set):
             return NotImplemented
-        return other <= self
+        return other.__le__(self)
 
     def __eq__(self, other):
         if not isinstance(other, Set):
index 8dc5559cae3aebef445fce6a7e1a930b0df02ae5..b2a5f052604d92c3864abd2652f6f31850372f41 100644 (file)
@@ -651,6 +651,39 @@ class TestCollectionABCs(ABCTestCase):
         s |= s
         self.assertEqual(s, full)
 
+    def test_issue16373(self):
+        # Recursion error comparing comparable and noncomparable
+        # Set instances
+        class MyComparableSet(Set):
+            def __contains__(self, x):
+                return False
+            def __len__(self):
+                return 0
+            def __iter__(self):
+                return iter([])
+        class MyNonComparableSet(Set):
+            def __contains__(self, x):
+                return False
+            def __len__(self):
+                return 0
+            def __iter__(self):
+                return iter([])
+            def __le__(self, x):
+                return NotImplemented
+            def __lt__(self, x):
+                return NotImplemented
+
+        cs = MyComparableSet()
+        ncs = MyNonComparableSet()
+        with self.assertRaises(TypeError):
+            ncs < cs
+        with self.assertRaises(TypeError):
+            ncs <= cs
+        with self.assertRaises(TypeError):
+            cs > ncs
+        with self.assertRaises(TypeError):
+            cs >= ncs
+
     def test_Mapping(self):
         for sample in [dict]:
             self.assertIsInstance(sample(), Mapping)