]> granicus.if.org Git - python/commitdiff
Issue 14814: Eliminate bytes warnings from ipaddress by correctly throwing an excepti...
authorNick Coghlan <ncoghlan@gmail.com>
Fri, 6 Jul 2012 15:43:31 +0000 (01:43 +1000)
committerNick Coghlan <ncoghlan@gmail.com>
Fri, 6 Jul 2012 15:43:31 +0000 (01:43 +1000)
Lib/ipaddress.py
Lib/test/test_ipaddress.py

index 352c9b87c466095ed70c37b18396431bbc0f094b..9a1ba728f2bea609491d2918fe7b9f78e6570312 100644 (file)
@@ -1250,7 +1250,9 @@ class IPv4Address(_BaseV4, _BaseAddress):
             return
 
         # Constructing from a packed address
-        if isinstance(address, bytes) and len(address) == 4:
+        if isinstance(address, bytes):
+            if len(address) != 4:
+                raise AddressValueError(address)
             self._ip = struct.unpack('!I', address)[0]
             return
 
@@ -1379,7 +1381,9 @@ class IPv4Network(_BaseV4, _BaseNetwork):
         _BaseNetwork.__init__(self, address)
 
         # Constructing from a packed address
-        if isinstance(address, bytes) and len(address) == 4:
+        if isinstance(address, bytes):
+            if len(address) != 4:
+                raise AddressValueError(address)
             self.network_address = IPv4Address(
                 struct.unpack('!I', address)[0])
             self._prefixlen = self._max_prefixlen
@@ -1864,7 +1868,9 @@ class IPv6Address(_BaseV6, _BaseAddress):
             return
 
         # Constructing from a packed address
-        if isinstance(address, bytes) and len(address) == 16:
+        if isinstance(address, bytes):
+            if len(address) != 16:
+                raise AddressValueError(address)
             tmp = struct.unpack('!QQ', address)
             self._ip = (tmp[0] << 64) | tmp[1]
             return
@@ -1996,7 +2002,9 @@ class IPv6Network(_BaseV6, _BaseNetwork):
             return
 
         # Constructing from a packed address
-        if isinstance(address, bytes) and len(address) == 16:
+        if isinstance(address, bytes):
+            if len(address) != 16:
+                raise AddressValueError(address)
             tmp = struct.unpack('!QQ', address)
             self.network_address = IPv6Address((tmp[0] << 64) | tmp[1])
             self._prefixlen = self._max_prefixlen
index c9ced59c2348ad688a583889e2c4b950360dc47d..5cd2ad4d198accccb69c84acbe6c542122ad46cf 100644 (file)
@@ -8,10 +8,6 @@ import unittest
 import ipaddress
 
 
-# Compatibility function to cast str to bytes objects
-_cb = lambda bytestr: bytes(bytestr, 'charmap')
-
-
 class IpaddrUnitTest(unittest.TestCase):
 
     def setUp(self):
@@ -267,25 +263,36 @@ class IpaddrUnitTest(unittest.TestCase):
                          6)
 
     def testIpFromPacked(self):
-        ip = ipaddress.ip_network
-
+        address = ipaddress.ip_address
         self.assertEqual(self.ipv4_interface._ip,
-                         ipaddress.ip_interface(_cb('\x01\x02\x03\x04'))._ip)
-        self.assertEqual(ip('255.254.253.252'),
-                         ip(_cb('\xff\xfe\xfd\xfc')))
-        self.assertRaises(ValueError, ipaddress.ip_network, _cb('\x00' * 3))
-        self.assertRaises(ValueError, ipaddress.ip_network, _cb('\x00' * 5))
+                         ipaddress.ip_interface(b'\x01\x02\x03\x04')._ip)
+        self.assertEqual(address('255.254.253.252'),
+                         address(b'\xff\xfe\xfd\xfc'))
         self.assertEqual(self.ipv6_interface.ip,
                          ipaddress.ip_interface(
-                _cb('\x20\x01\x06\x58\x02\x2a\xca\xfe'
-                    '\x02\x00\x00\x00\x00\x00\x00\x01')).ip)
-        self.assertEqual(ip('ffff:2:3:4:ffff::'),
-                         ip(_cb('\xff\xff\x00\x02\x00\x03\x00\x04' +
-                                '\xff\xff' + '\x00' * 6)))
-        self.assertEqual(ip('::'),
-                         ip(_cb('\x00' * 16)))
-        self.assertRaises(ValueError, ip, _cb('\x00' * 15))
-        self.assertRaises(ValueError, ip, _cb('\x00' * 17))
+                    b'\x20\x01\x06\x58\x02\x2a\xca\xfe'
+                    b'\x02\x00\x00\x00\x00\x00\x00\x01').ip)
+        self.assertEqual(address('ffff:2:3:4:ffff::'),
+                         address(b'\xff\xff\x00\x02\x00\x03\x00\x04' +
+                            b'\xff\xff' + b'\x00' * 6))
+        self.assertEqual(address('::'),
+                         address(b'\x00' * 16))
+
+    def testIpFromPackedErrors(self):
+        def assertInvalidPackedAddress(f, length):
+            self.assertRaises(ValueError, f, b'\x00' * length)
+        assertInvalidPackedAddress(ipaddress.ip_address, 3)
+        assertInvalidPackedAddress(ipaddress.ip_address, 5)
+        assertInvalidPackedAddress(ipaddress.ip_address, 15)
+        assertInvalidPackedAddress(ipaddress.ip_address, 17)
+        assertInvalidPackedAddress(ipaddress.ip_interface, 3)
+        assertInvalidPackedAddress(ipaddress.ip_interface, 5)
+        assertInvalidPackedAddress(ipaddress.ip_interface, 15)
+        assertInvalidPackedAddress(ipaddress.ip_interface, 17)
+        assertInvalidPackedAddress(ipaddress.ip_network, 3)
+        assertInvalidPackedAddress(ipaddress.ip_network, 5)
+        assertInvalidPackedAddress(ipaddress.ip_network, 15)
+        assertInvalidPackedAddress(ipaddress.ip_network, 17)
 
     def testGetIp(self):
         self.assertEqual(int(self.ipv4_interface.ip), 16909060)
@@ -893,17 +900,17 @@ class IpaddrUnitTest(unittest.TestCase):
 
     def testPacked(self):
         self.assertEqual(self.ipv4_address.packed,
-                         _cb('\x01\x02\x03\x04'))
+                         b'\x01\x02\x03\x04')
         self.assertEqual(ipaddress.IPv4Interface('255.254.253.252').packed,
-                         _cb('\xff\xfe\xfd\xfc'))
+                         b'\xff\xfe\xfd\xfc')
         self.assertEqual(self.ipv6_address.packed,
-                         _cb('\x20\x01\x06\x58\x02\x2a\xca\xfe'
-                             '\x02\x00\x00\x00\x00\x00\x00\x01'))
+                         b'\x20\x01\x06\x58\x02\x2a\xca\xfe'
+                         b'\x02\x00\x00\x00\x00\x00\x00\x01')
         self.assertEqual(ipaddress.IPv6Interface('ffff:2:3:4:ffff::').packed,
-                         _cb('\xff\xff\x00\x02\x00\x03\x00\x04\xff\xff'
-                            + '\x00' * 6))
+                         b'\xff\xff\x00\x02\x00\x03\x00\x04\xff\xff'
+                            + b'\x00' * 6)
         self.assertEqual(ipaddress.IPv6Interface('::1:0:0:0:0').packed,
-                         _cb('\x00' * 6 + '\x00\x01' + '\x00' * 8))
+                         b'\x00' * 6 + b'\x00\x01' + b'\x00' * 8)
 
     def testIpStrFromPrefixlen(self):
         ipv4 = ipaddress.IPv4Interface('1.2.3.4/24')