]> granicus.if.org Git - python/commitdiff
Removed implicit convertions of str object to bytes from base64.
authorAlexandre Vassalotti <alexandre@peadrop.com>
Sat, 3 May 2008 04:39:38 +0000 (04:39 +0000)
committerAlexandre Vassalotti <alexandre@peadrop.com>
Sat, 3 May 2008 04:39:38 +0000 (04:39 +0000)
This also exposed some bugs in urlib2 and email.base64mime, which I
tried my best to fix. However, someone will probably have to double
check.

Lib/base64.py
Lib/email/base64mime.py
Lib/test/test_base64.py
Lib/urllib2.py

index fc05ea9f81efb93bda0cda787fec0b9af596a4d5..fc80835add0ef7b59a2dfd3b4f52cc9091e77134 100755 (executable)
@@ -53,12 +53,13 @@ def b64encode(s, altchars=None):
     The encoded byte string is returned.
     """
     if not isinstance(s, bytes_types):
-        s = bytes(s, "ascii")
+        raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     # Strip off the trailing newline
     encoded = binascii.b2a_base64(s)[:-1]
     if altchars is not None:
         if not isinstance(altchars, bytes_types):
-            altchars = bytes(altchars, "ascii")
+            altchars = TypeError("expected bytes, not %s"
+                                 % altchars.__class__.__name__)
         assert len(altchars) == 2, repr(altchars)
         return _translate(encoded, {'+': altchars[0:1], '/': altchars[1:2]})
     return encoded
@@ -76,10 +77,11 @@ def b64decode(s, altchars=None):
     present in the string.
     """
     if not isinstance(s, bytes_types):
-        s = bytes(s)
+        raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     if altchars is not None:
         if not isinstance(altchars, bytes_types):
-            altchars = bytes(altchars, "ascii")
+            raise TypeError("expected bytes, not %s"
+                            % altchars.__class__.__name__)
         assert len(altchars) == 2, repr(altchars)
         s = _translate(s, {chr(altchars[0]): b'+', chr(altchars[1]): b'/'})
     return binascii.a2b_base64(s)
@@ -148,7 +150,7 @@ def b32encode(s):
     s is the byte string to encode.  The encoded byte string is returned.
     """
     if not isinstance(s, bytes_types):
-        s = bytes(s)
+        raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     quanta, leftover = divmod(len(s), 5)
     # Pad the last quantum with zero bits if necessary
     if leftover:
@@ -205,16 +207,16 @@ def b32decode(s, casefold=False, map01=None):
     characters present in the input.
     """
     if not isinstance(s, bytes_types):
-        s = bytes(s)
+        raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     quanta, leftover = divmod(len(s), 8)
     if leftover:
         raise binascii.Error('Incorrect padding')
     # Handle section 2.4 zero and one mapping.  The flag map01 will be either
     # False, or the character to map the digit 1 (one) to.  It should be
     # either L (el) or I (eye).
-    if map01:
+    if map01 is not None:
         if not isinstance(map01, bytes_types):
-            map01 = bytes(map01)
+            raise TypeError("expected bytes, not %s" % map01.__class__.__name__)
         assert len(map01) == 1, repr(map01)
         s = _translate(s, {b'0': b'O', b'1': map01})
     if casefold:
@@ -269,6 +271,8 @@ def b16encode(s):
 
     s is the byte string to encode.  The encoded byte string is returned.
     """
+    if not isinstance(s, bytes_types):
+        raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     return binascii.hexlify(s).upper()
 
 
@@ -284,7 +288,7 @@ def b16decode(s, casefold=False):
     present in the string.
     """
     if not isinstance(s, bytes_types):
-        s = bytes(s)
+        raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     if casefold:
         s = s.upper()
     if re.search('[^0-9A-F]', s):
index c60f8dbe253a2c525a483a29f3d84e8ab3b08d95..6db007dc1978bd329bfe05cdbbc4efe564be0ede 100644 (file)
@@ -66,9 +66,10 @@ def header_encode(header_bytes, charset='iso-8859-1'):
     charset names the character set to use to encode the header.  It defaults
     to iso-8859-1.  Base64 encoding is defined in RFC 2045.
     """
-    # Return empty headers unchanged
     if not header_bytes:
-        return str(header_bytes)
+        return ""
+    if isinstance(header_bytes, str):
+        header_bytes = header_bytes.encode(charset)
     encoded = b64encode(header_bytes).decode("ascii")
     return '=?%s?b?%s?=' % (charset, encoded)
 
index 6f886957d8285a4d641b9d8a3b9b60d88e9b377a..c50652cac50e51cde382b526be5d7c7ef5af9e4d 100644 (file)
@@ -19,6 +19,7 @@ class LegacyBase64TestCase(unittest.TestCase):
            b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
            b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT"
            b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n")
+        self.assertRaises(TypeError, base64.encodestring, "")
 
     def test_decodestring(self):
         eq = self.assertEqual
@@ -33,6 +34,7 @@ class LegacyBase64TestCase(unittest.TestCase):
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            b"0123456789!@#0^&*();:<>,. []{}")
         eq(base64.decodestring(b''), b'')
+        self.assertRaises(TypeError, base64.decodestring, "")
 
     def test_encode(self):
         eq = self.assertEqual
@@ -54,7 +56,6 @@ class LegacyBase64TestCase(unittest.TestCase):
         base64.decode(infp, outfp)
         self.assertEqual(outfp.getvalue(), b'www.python.org')
 
-
 \f
 class BaseXYTestCase(unittest.TestCase):
     def test_b64encode(self):
@@ -73,7 +74,10 @@ class BaseXYTestCase(unittest.TestCase):
            b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
            b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==")
         # Test with arbitrary alternative characters
-        eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars='*$'), b'01a*b$cd')
+        eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars=b'*$'), b'01a*b$cd')
+        # Check if passing a str object raises an error
+        self.assertRaises(TypeError, base64.b64encode, "")
+        self.assertRaises(TypeError, base64.b64encode, b"", altchars="")
         # Test standard alphabet
         eq(base64.standard_b64encode(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=")
         eq(base64.standard_b64encode(b"a"), b"YQ==")
@@ -86,8 +90,13 @@ class BaseXYTestCase(unittest.TestCase):
            b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
            b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
            b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==")
+        # Check if passing a str object raises an error
+        self.assertRaises(TypeError, base64.standard_b64encode, "")
+        self.assertRaises(TypeError, base64.standard_b64encode, b"", altchars="")
         # Test with 'URL safe' alternative characters
         eq(base64.urlsafe_b64encode(b'\xd3V\xbeo\xf7\x1d'), b'01a-b_cd')
+        # Check if passing a str object raises an error
+        self.assertRaises(TypeError, base64.urlsafe_b64encode, "")
 
     def test_b64decode(self):
         eq = self.assertEqual
@@ -104,7 +113,10 @@ class BaseXYTestCase(unittest.TestCase):
            b"0123456789!@#0^&*();:<>,. []{}")
         eq(base64.b64decode(b''), b'')
         # Test with arbitrary alternative characters
-        eq(base64.b64decode(b'01a*b$cd', altchars='*$'), b'\xd3V\xbeo\xf7\x1d')
+        eq(base64.b64decode(b'01a*b$cd', altchars=b'*$'), b'\xd3V\xbeo\xf7\x1d')
+        # Check if passing a str object raises an error
+        self.assertRaises(TypeError, base64.b64decode, "")
+        self.assertRaises(TypeError, base64.b64decode, b"", altchars="")
         # Test standard alphabet
         eq(base64.standard_b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org")
         eq(base64.standard_b64decode(b"YQ=="), b"a")
@@ -117,8 +129,12 @@ class BaseXYTestCase(unittest.TestCase):
            b"abcdefghijklmnopqrstuvwxyz"
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            b"0123456789!@#0^&*();:<>,. []{}")
+        # Check if passing a str object raises an error
+        self.assertRaises(TypeError, base64.standard_b64decode, "")
+        self.assertRaises(TypeError, base64.standard_b64decode, b"", altchars="")
         # Test with 'URL safe' alternative characters
         eq(base64.urlsafe_b64decode(b'01a-b_cd'), b'\xd3V\xbeo\xf7\x1d')
+        self.assertRaises(TypeError, base64.urlsafe_b64decode, "")
 
     def test_b64decode_error(self):
         self.assertRaises(binascii.Error, base64.b64decode, b'abc')
@@ -132,6 +148,7 @@ class BaseXYTestCase(unittest.TestCase):
         eq(base64.b32encode(b'abc'), b'MFRGG===')
         eq(base64.b32encode(b'abcd'), b'MFRGGZA=')
         eq(base64.b32encode(b'abcde'), b'MFRGGZDF')
+        self.assertRaises(TypeError, base64.b32encode, "")
 
     def test_b32decode(self):
         eq = self.assertEqual
@@ -142,6 +159,7 @@ class BaseXYTestCase(unittest.TestCase):
         eq(base64.b32decode(b'MFRGG==='), b'abc')
         eq(base64.b32decode(b'MFRGGZA='), b'abcd')
         eq(base64.b32decode(b'MFRGGZDF'), b'abcde')
+        self.assertRaises(TypeError, base64.b32decode, "")
 
     def test_b32decode_casefold(self):
         eq = self.assertEqual
@@ -163,6 +181,7 @@ class BaseXYTestCase(unittest.TestCase):
         eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe')
         eq(base64.b32decode(b'M1023456', map01=b'L'), b'b\xdd\xad\xf3\xbe')
         eq(base64.b32decode(b'M1023456', map01=b'I'), b'b\x1d\xad\xf3\xbe')
+        self.assertRaises(TypeError, base64.b32decode, b"", map01="")
 
     def test_b32decode_error(self):
         self.assertRaises(binascii.Error, base64.b32decode, b'abc')
@@ -172,6 +191,7 @@ class BaseXYTestCase(unittest.TestCase):
         eq = self.assertEqual
         eq(base64.b16encode(b'\x01\x02\xab\xcd\xef'), b'0102ABCDEF')
         eq(base64.b16encode(b'\x00'), b'00')
+        self.assertRaises(TypeError, base64.b16encode, "")
 
     def test_b16decode(self):
         eq = self.assertEqual
@@ -181,6 +201,7 @@ class BaseXYTestCase(unittest.TestCase):
         self.assertRaises(binascii.Error, base64.b16decode, b'0102abcdef')
         # Case fold
         eq(base64.b16decode(b'0102abcdef', True), b'\x01\x02\xab\xcd\xef')
+        self.assertRaises(TypeError, base64.b16decode, "")
 
     def test_ErrorHeritage(self):
         self.assert_(issubclass(binascii.Error, ValueError))
index 145882673ae13fbdb6191cc6e81fd34516bcc637..76035a3cffaa56f5bc50e7978162748aea4d904e 100644 (file)
@@ -682,7 +682,7 @@ class ProxyHandler(BaseHandler):
             proxy_type = orig_type
         if user and password:
             user_pass = '%s:%s' % (unquote(user), unquote(password))
-            creds = str(base64.b64encode(user_pass)).strip()
+            creds = base64.b64encode(user_pass.encode()).decode("ascii")
             req.add_header('Proxy-authorization', 'Basic ' + creds)
         hostport = unquote(hostport)
         req.set_proxy(hostport, proxy_type)
@@ -808,7 +808,7 @@ class AbstractBasicAuthHandler:
         user, pw = self.passwd.find_user_password(realm, host)
         if pw is not None:
             raw = "%s:%s" % (user, pw)
-            auth = 'Basic %s' % base64.b64encode(raw).strip().decode()
+            auth = "Basic " + base64.b64encode(raw.encode()).decode("ascii")
             if req.headers.get(self.auth_header, None) == auth:
                 return None
             req.add_header(self.auth_header, auth)