]> granicus.if.org Git - python/commitdiff
Issue #10042: Fixed the total_ordering decorator to handle cross-type
authorRaymond Hettinger <python@rcn.com>
Sat, 8 Jan 2011 07:01:56 +0000 (07:01 +0000)
committerRaymond Hettinger <python@rcn.com>
Sat, 8 Jan 2011 07:01:56 +0000 (07:01 +0000)
comparisons that could lead to infinite recursion.

Lib/functools.py
Lib/test/test_functools.py
Misc/ACKS
Misc/NEWS

index d4506344859e311e7eddceed17111fb6965fd051..1062a452e6bf691ca7ebdd3bb7575cc3457cf749 100644 (file)
@@ -68,17 +68,17 @@ def wraps(wrapped,
 def total_ordering(cls):
     """Class decorator that fills in missing ordering methods"""
     convert = {
-        '__lt__': [('__gt__', lambda self, other: other < self),
-                   ('__le__', lambda self, other: not other < self),
+        '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
+                   ('__le__', lambda self, other: self < other or self == other),
                    ('__ge__', lambda self, other: not self < other)],
-        '__le__': [('__ge__', lambda self, other: other <= self),
-                   ('__lt__', lambda self, other: not other <= self),
+        '__le__': [('__ge__', lambda self, other: not self <= other or self == other),
+                   ('__lt__', lambda self, other: self <= other and not self == other),
                    ('__gt__', lambda self, other: not self <= other)],
-        '__gt__': [('__lt__', lambda self, other: other > self),
-                   ('__ge__', lambda self, other: not other > self),
+        '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
+                   ('__ge__', lambda self, other: self > other or self == other),
                    ('__le__', lambda self, other: not self > other)],
-        '__ge__': [('__le__', lambda self, other: other >= self),
-                   ('__gt__', lambda self, other: not other >= self),
+        '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
+                   ('__gt__', lambda self, other: self >= other and not self == other),
                    ('__lt__', lambda self, other: not self >= other)]
     }
     # Find user-defined comparisons (not those inherited from object).
index f41a144a1415558e6df9624974c03b67461cbaf0..73a77d63f27b5f1484c22a2f617349ed795a9fcb 100644 (file)
@@ -457,6 +457,8 @@ class TestTotalOrdering(unittest.TestCase):
                 self.value = value
             def __lt__(self, other):
                 return self.value < other.value
+            def __eq__(self, other):
+                return self.value == other.value
         self.assertTrue(A(1) < A(2))
         self.assertTrue(A(2) > A(1))
         self.assertTrue(A(1) <= A(2))
@@ -471,6 +473,8 @@ class TestTotalOrdering(unittest.TestCase):
                 self.value = value
             def __le__(self, other):
                 return self.value <= other.value
+            def __eq__(self, other):
+                return self.value == other.value
         self.assertTrue(A(1) < A(2))
         self.assertTrue(A(2) > A(1))
         self.assertTrue(A(1) <= A(2))
@@ -485,6 +489,8 @@ class TestTotalOrdering(unittest.TestCase):
                 self.value = value
             def __gt__(self, other):
                 return self.value > other.value
+            def __eq__(self, other):
+                return self.value == other.value
         self.assertTrue(A(1) < A(2))
         self.assertTrue(A(2) > A(1))
         self.assertTrue(A(1) <= A(2))
@@ -499,6 +505,8 @@ class TestTotalOrdering(unittest.TestCase):
                 self.value = value
             def __ge__(self, other):
                 return self.value >= other.value
+            def __eq__(self, other):
+                return self.value == other.value
         self.assertTrue(A(1) < A(2))
         self.assertTrue(A(2) > A(1))
         self.assertTrue(A(1) <= A(2))
@@ -524,6 +532,22 @@ class TestTotalOrdering(unittest.TestCase):
             class A:
                 pass
 
+    def test_bug_10042(self):
+        @functools.total_ordering
+        class TestTO:
+            def __init__(self, value):
+                self.value = value
+            def __eq__(self, other):
+                if isinstance(other, TestTO):
+                    return self.value == other.value
+                return False
+            def __lt__(self, other):
+                if isinstance(other, TestTO):
+                    return self.value < other.value
+                raise TypeError
+        with self.assertRaises(TypeError):
+            TestTO(8) <= ()
+
 class TestLRU(unittest.TestCase):
 
     def test_lru(self):
index 6010ace7118780645b50dbcab98e162c0f671841..6929e36cd392acd11c286dbbb1d9c0fa2d3b77f3 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -700,6 +700,7 @@ Bernhard Reiter
 Steven Reiz
 Roeland Rengelink
 Tim Rice
+Francesco Ricciardi
 Jan Pieter Riegel
 Armin Rigo
 Nicholas Riley
index 8c7c7fd73a4f4f396f9364743730065bb9af162d..06832452f2bb0b6fe5171d625363aa03a926ab07 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -40,6 +40,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #10042: Fixed the total_ordering decorator to handle cross-type
+  comparisons that could lead to infinite recursion.
+
 - Issue #10686: the email package now :rfc:`2047`\ -encodes headers with
   non-ASCII bytes (parsed by a Bytes Parser) when doing conversion to
   7bit-clean presentation, instead of replacing them with ?s.