]> granicus.if.org Git - python/commitdiff
Patch 1171 by mfenniak -- allow subclassing of bytes.
authorGuido van Rossum <guido@python.org>
Sat, 3 Nov 2007 00:24:24 +0000 (00:24 +0000)
committerGuido van Rossum <guido@python.org>
Sat, 3 Nov 2007 00:24:24 +0000 (00:24 +0000)
I suspect this has some problems when the subclass is evil,
but that's for later.

Lib/test/test_bytes.py
Objects/bytesobject.c

index 112cb3251a19088684686f27b0c528dc81dca0e6..932fa448bea9131a978cec9ebabfdbbfdcbd70bc 100644 (file)
@@ -3,6 +3,7 @@
 import os
 import re
 import sys
+import copy
 import pickle
 import tempfile
 import unittest
@@ -782,11 +783,89 @@ class BytesAsStringTest(test.string_tests.BaseTest):
         pass
 
 
+class BytesSubclass(bytes):
+    pass
+
+class BytesSubclassTest(unittest.TestCase):
+
+    def test_basic(self):
+        self.assert_(issubclass(BytesSubclass, bytes))
+        self.assert_(isinstance(BytesSubclass(), bytes))
+
+        a, b = b"abcd", b"efgh"
+        _a, _b = BytesSubclass(a), BytesSubclass(b)
+
+        # test comparison operators with subclass instances
+        self.assert_(_a == _a)
+        self.assert_(_a != _b)
+        self.assert_(_a < _b)
+        self.assert_(_a <= _b)
+        self.assert_(_b >= _a)
+        self.assert_(_b > _a)
+        self.assert_(_a is not a)
+
+        # test concat of subclass instances
+        self.assertEqual(a + b, _a + _b)
+        self.assertEqual(a + b, a + _b)
+        self.assertEqual(a + b, _a + b)
+
+        # test repeat
+        self.assert_(a*5 == _a*5)
+
+    def test_join(self):
+        # Make sure join returns a NEW object for single item sequences
+        # involving a subclass.
+        # Make sure that it is of the appropriate type.
+        s1 = BytesSubclass(b"abcd")
+        s2 = b"".join([s1])
+        self.assert_(s1 is not s2)
+        self.assert_(type(s2) is bytes)
+
+        # Test reverse, calling join on subclass
+        s3 = s1.join([b"abcd"])
+        self.assert_(type(s3) is bytes)
+
+    def test_pickle(self):
+        a = BytesSubclass(b"abcd")
+        a.x = 10
+        a.y = BytesSubclass(b"efgh")
+        for proto in range(pickle.HIGHEST_PROTOCOL):
+            b = pickle.loads(pickle.dumps(a, proto))
+            self.assertNotEqual(id(a), id(b))
+            self.assertEqual(a, b)
+            self.assertEqual(a.x, b.x)
+            self.assertEqual(a.y, b.y)
+            self.assertEqual(type(a), type(b))
+            self.assertEqual(type(a.y), type(b.y))
+
+    def test_copy(self):
+        a = BytesSubclass(b"abcd")
+        a.x = 10
+        a.y = BytesSubclass(b"efgh")
+        for copy_method in (copy.copy, copy.deepcopy):
+            b = copy_method(a)
+            self.assertNotEqual(id(a), id(b))
+            self.assertEqual(a, b)
+            self.assertEqual(a.x, b.x)
+            self.assertEqual(a.y, b.y)
+            self.assertEqual(type(a), type(b))
+            self.assertEqual(type(a.y), type(b.y))
+
+    def test_init_override(self):
+        class subclass(bytes):
+            def __init__(self, newarg=1, *args, **kwargs):
+                bytes.__init__(self, *args, **kwargs)
+        x = subclass(4, source=b"abcd")
+        self.assertEqual(x, b"abcd")
+        x = subclass(newarg=4, source=b"abcd")
+        self.assertEqual(x, b"abcd")
+
+
 def test_main():
     test.test_support.run_unittest(BytesTest)
     test.test_support.run_unittest(BytesAsStringTest)
+    test.test_support.run_unittest(BytesSubclassTest)
     test.test_support.run_unittest(BufferPEP3137Test)
 
 if __name__ == "__main__":
-    ##test_main()
-    unittest.main()
+    test_main()
index 2595ff2de93e502c1789b59b5f203b8899b467a9..3f2dbc2934a972dc9905f2370b7bd4c5e1a26ce0 100644 (file)
@@ -2921,13 +2921,21 @@ PyDoc_STRVAR(reduce_doc, "Return state information for pickling.");
 static PyObject *
 bytes_reduce(PyBytesObject *self)
 {
-    PyObject *latin1;
+    PyObject *latin1, *dict;
     if (self->ob_bytes)
         latin1 = PyUnicode_DecodeLatin1(self->ob_bytes,
                                         Py_Size(self), NULL);
     else
         latin1 = PyUnicode_FromString("");
-    return Py_BuildValue("(O(Ns))", Py_Type(self), latin1, "latin-1");
+
+    dict = PyObject_GetAttrString((PyObject *)self, "__dict__");
+    if (dict == NULL) {
+        PyErr_Clear();
+        dict = Py_None;
+        Py_INCREF(dict);
+    }
+
+    return Py_BuildValue("(O(Ns)N)", Py_Type(self), latin1, "latin-1", dict);
 }
 
 static PySequenceMethods bytes_as_sequence = {
@@ -3045,8 +3053,7 @@ PyTypeObject PyBytes_Type = {
     PyObject_GenericGetAttr,            /* tp_getattro */
     0,                                  /* tp_setattro */
     &bytes_as_buffer,                   /* tp_as_buffer */
-    /* bytes is 'final' or 'sealed' */
-    Py_TPFLAGS_DEFAULT,                 /* tp_flags */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
     bytes_doc,                          /* tp_doc */
     0,                                  /* tp_traverse */
     0,                                  /* tp_clear */