]> granicus.if.org Git - python/commitdiff
Issue #23268: Fixed bugs in the comparison of ipaddress classes.
authorSerhiy Storchaka <storchaka@gmail.com>
Mon, 26 Jan 2015 08:11:16 +0000 (10:11 +0200)
committerSerhiy Storchaka <storchaka@gmail.com>
Mon, 26 Jan 2015 08:11:16 +0000 (10:11 +0200)
Lib/ipaddress.py
Lib/test/test_ipaddress.py
Misc/NEWS

index ebc04bb145833081888b9bb6a73046886a4d967e..ac03c36ce08e2d4e78053fdc1e803fdce5ea6d9b 100644 (file)
@@ -388,40 +388,7 @@ def get_mixed_type_key(obj):
     return NotImplemented
 
 
-class _TotalOrderingMixin:
-    # Helper that derives the other comparison operations from
-    # __lt__ and __eq__
-    # We avoid functools.total_ordering because it doesn't handle
-    # NotImplemented correctly yet (http://bugs.python.org/issue10042)
-    def __eq__(self, other):
-        raise NotImplementedError
-    def __ne__(self, other):
-        equal = self.__eq__(other)
-        if equal is NotImplemented:
-            return NotImplemented
-        return not equal
-    def __lt__(self, other):
-        raise NotImplementedError
-    def __le__(self, other):
-        less = self.__lt__(other)
-        if less is NotImplemented or not less:
-            return self.__eq__(other)
-        return less
-    def __gt__(self, other):
-        less = self.__lt__(other)
-        if less is NotImplemented:
-            return NotImplemented
-        equal = self.__eq__(other)
-        if equal is NotImplemented:
-            return NotImplemented
-        return not (less or equal)
-    def __ge__(self, other):
-        less = self.__lt__(other)
-        if less is NotImplemented:
-            return NotImplemented
-        return not less
-
-class _IPAddressBase(_TotalOrderingMixin):
+class _IPAddressBase:
 
     """The mother class."""
 
@@ -554,6 +521,7 @@ class _IPAddressBase(_TotalOrderingMixin):
             self._report_invalid_netmask(ip_str)
 
 
+@functools.total_ordering
 class _BaseAddress(_IPAddressBase):
 
     """A generic IP object.
@@ -578,12 +546,11 @@ class _BaseAddress(_IPAddressBase):
             return NotImplemented
 
     def __lt__(self, other):
+        if not isinstance(other, _BaseAddress):
+            return NotImplemented
         if self._version != other._version:
             raise TypeError('%s and %s are not of the same version' % (
                              self, other))
-        if not isinstance(other, _BaseAddress):
-            raise TypeError('%s and %s are not of the same type' % (
-                             self, other))
         if self._ip != other._ip:
             return self._ip < other._ip
         return False
@@ -613,6 +580,7 @@ class _BaseAddress(_IPAddressBase):
         return (self._version, self)
 
 
+@functools.total_ordering
 class _BaseNetwork(_IPAddressBase):
 
     """A generic IP network object.
@@ -662,12 +630,11 @@ class _BaseNetwork(_IPAddressBase):
             return self._address_class(broadcast + n)
 
     def __lt__(self, other):
+        if not isinstance(other, _BaseNetwork):
+            return NotImplemented
         if self._version != other._version:
             raise TypeError('%s and %s are not of the same version' % (
                              self, other))
-        if not isinstance(other, _BaseNetwork):
-            raise TypeError('%s and %s are not of the same type' % (
-                             self, other))
         if self.network_address != other.network_address:
             return self.network_address < other.network_address
         if self.netmask != other.netmask:
index a0fe55c99f9edf6dea4ad6dedde1c24891186213..bfb569950f0d6e0fa6d45c53c4956d6cb2d5e4ac 100644 (file)
@@ -7,6 +7,7 @@
 import unittest
 import re
 import contextlib
+import functools
 import operator
 import ipaddress
 
@@ -528,6 +529,20 @@ class FactoryFunctionErrors(BaseTestCase):
         self.assertFactoryError(ipaddress.ip_network, "network")
 
 
+@functools.total_ordering
+class LargestObject:
+    def __eq__(self, other):
+        return isinstance(other, LargestObject)
+    def __lt__(self, other):
+        return False
+
+@functools.total_ordering
+class SmallestObject:
+    def __eq__(self, other):
+        return isinstance(other, SmallestObject)
+    def __gt__(self, other):
+        return False
+
 class ComparisonTests(unittest.TestCase):
 
     v4addr = ipaddress.IPv4Address(1)
@@ -581,6 +596,28 @@ class ComparisonTests(unittest.TestCase):
                 self.assertRaises(TypeError, lambda: lhs <= rhs)
                 self.assertRaises(TypeError, lambda: lhs >= rhs)
 
+    def test_foreign_type_ordering(self):
+        other = object()
+        smallest = SmallestObject()
+        largest = LargestObject()
+        for obj in self.objects:
+            with self.assertRaises(TypeError):
+                obj < other
+            with self.assertRaises(TypeError):
+                obj > other
+            with self.assertRaises(TypeError):
+                obj <= other
+            with self.assertRaises(TypeError):
+                obj >= other
+            self.assertTrue(obj < largest)
+            self.assertFalse(obj > largest)
+            self.assertTrue(obj <= largest)
+            self.assertFalse(obj >= largest)
+            self.assertFalse(obj < smallest)
+            self.assertTrue(obj > smallest)
+            self.assertFalse(obj <= smallest)
+            self.assertTrue(obj >= smallest)
+
     def test_mixed_type_key(self):
         # with get_mixed_type_key, you can sort addresses and network.
         v4_ordered = [self.v4addr, self.v4net, self.v4intf]
@@ -601,7 +638,7 @@ class ComparisonTests(unittest.TestCase):
         v4addr = ipaddress.ip_address('1.1.1.1')
         v4net = ipaddress.ip_network('1.1.1.1')
         v6addr = ipaddress.ip_address('::1')
-        v6net = ipaddress.ip_address('::1')
+        v6net = ipaddress.ip_network('::1')
 
         self.assertRaises(TypeError, v4addr.__lt__, v6addr)
         self.assertRaises(TypeError, v4addr.__gt__, v6addr)
@@ -1248,10 +1285,10 @@ class IpaddrUnitTest(unittest.TestCase):
         unsorted = [ip4, ip1, ip3, ip2]
         unsorted.sort()
         self.assertEqual(sorted, unsorted)
-        self.assertRaises(TypeError, ip1.__lt__,
-                          ipaddress.ip_address('10.10.10.0'))
-        self.assertRaises(TypeError, ip2.__lt__,
-                          ipaddress.ip_address('10.10.10.0'))
+        self.assertIs(ip1.__lt__(ipaddress.ip_address('10.10.10.0')),
+                      NotImplemented)
+        self.assertIs(ip2.__lt__(ipaddress.ip_address('10.10.10.0')),
+                      NotImplemented)
 
         # <=, >=
         self.assertTrue(ipaddress.ip_network('1.1.1.1') <=
index 6970e8043b2ef5173425b90ca1c38293b082a435..e9ac72f9e8a4b5c749f9ee53053b494f9c36011f 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -50,6 +50,8 @@ Core and Builtins
 Library
 -------
 
+- Issue #23268: Fixed bugs in the comparison of ipaddress classes.
+
 - Issue #21408: Removed incorrect implementations of __ne__() which didn't
   returned NotImplemented if __eq__() returned NotImplemented.  The default
   __ne__() now works correctly.