]> granicus.if.org Git - python/commitdiff
cPickle produces NEWOBJ appropriately now. It still doesn't know
authorTim Peters <tim.peters@gmail.com>
Fri, 14 Feb 2003 23:05:28 +0000 (23:05 +0000)
committerTim Peters <tim.peters@gmail.com>
Fri, 14 Feb 2003 23:05:28 +0000 (23:05 +0000)
how to unpickle the new slot-full state tuples.

Lib/test/pickletester.py
Modules/cPickle.c

index e1a333b1d103898302e0854ce436a94bdb57546a..57e051c8128cd23c1f679a1fc91593b41d36bf63 100644 (file)
@@ -718,12 +718,6 @@ class AbstractPickleTests(unittest.TestCase):
             else:
                 self.failUnless(num_setitems >= 2)
 
-# XXX Temporary hack, so long as the C implementation of pickle protocol
-# XXX 2 isn't ready.  When it is, move the methods in TempAbstractPickleTests
-# XXX into AbstractPickleTests above, and get rid of TempAbstractPickleTests
-# XXX along with the references to it in test_pickle.py.
-class TempAbstractPickleTests(unittest.TestCase):
-
     def test_simple_newobj(self):
         x = object.__new__(SimpleNewObj)  # avoid __init__
         x.abc = 666
@@ -734,6 +728,12 @@ class TempAbstractPickleTests(unittest.TestCase):
             self.assertEqual(y.abc, 666)
             self.assertEqual(x.__dict__, y.__dict__)
 
+# XXX Temporary hack, so long as the C implementation of pickle protocol
+# XXX 2 isn't ready.  When it is, move the methods in TempAbstractPickleTests
+# XXX into AbstractPickleTests above, and get rid of TempAbstractPickleTests
+# XXX along with the references to it in test_pickle.py.
+class TempAbstractPickleTests(unittest.TestCase):
+
     def test_newobj_list_slots(self):
         x = SlotList([1, 2, 3])
         x.foo = 42
index 09485b9ba54feeaaba73434a86961f7eefdd091a..f09e502b10f4d006e83fa2b13d7f72db9c67541a 100644 (file)
@@ -119,6 +119,12 @@ static PyObject *extension_cache;
 /* For looking up name pairs in copy_reg._extension_registry. */
 static PyObject *two_tuple;
 
+/* object.__reduce__, the default reduce callable. */
+PyObject *object_reduce;
+
+/* copy_reg._better_reduce, the protocol 2 reduction function. */
+PyObject *better_reduce;
+
 static PyObject *__class___str, *__getinitargs___str, *__dict___str,
   *__getstate___str, *__setstate___str, *__name___str, *__reduce___str,
   *write_str, *append_str,
@@ -2181,38 +2187,142 @@ save_pers(Picklerobject *self, PyObject *args, PyObject *f)
        return res;
 }
 
-
+/* We're saving ob, and args is the 2-thru-5 tuple returned by the
+ * appropriate __reduce__ method for ob.
+ */
 static int
-save_reduce(Picklerobject *self, PyObject *callable,
-            PyObject *tup, PyObject *state, PyObject *ob)
-{
-       static char reduce = REDUCE, build = BUILD;
-
-       if (save(self, callable, 0) < 0)
+save_reduce(Picklerobject *self, PyObject *args, PyObject *ob)
+{
+       PyObject *callable;
+       PyObject *argtup;
+        PyObject *state = NULL;
+        PyObject *listitems = NULL;
+        PyObject *dictitems = NULL;
+
+       int use_newobj = self->proto >= 2;
+
+       static char reduce = REDUCE;
+       static char build = BUILD;
+       static char newobj = NEWOBJ;
+
+       if (! PyArg_UnpackTuple(args, "save_reduce", 2, 5,
+                               &callable,
+                               &argtup,
+                               &state,
+                               &listitems,
+                               &dictitems))
                return -1;
 
-       if (save(self, tup, 0) < 0)
-               return -1;
+       if (state == Py_None)
+               state = NULL;
+       if (listitems == Py_None)
+               listitems = NULL;
+       if (dictitems == Py_None)
+               dictitems = NULL;
 
-       if (self->write_func(self, &reduce, 1) < 0)
-               return -1;
+        /* Protocol 2 special case: if callable's name is __newobj__, use
+         * NEWOBJ.  This consumes a lot of code.
+         */
+        if (use_newobj) {
+               PyObject *temp = PyObject_GetAttr(callable, __name___str);
+
+               if (temp == NULL) {
+                       PyErr_Clear();
+                       use_newobj = 0;
+               }
+               else {
+                       use_newobj = PyString_Check(temp) &&
+                                    strcmp(PyString_AS_STRING(temp),
+                                           "__newobj__") == 0;
+                       Py_DECREF(temp);
+               }
+       }
+       if (use_newobj) {
+               PyObject *cls;
+               PyObject *newargtup;
+               int n, i;
+
+               /* Sanity checks. */
+               n = PyTuple_Size(argtup);
+               if (n < 1) {
+                       PyErr_SetString(PicklingError, "__newobj__ arglist "
+                               "is empty");
+                       return -1;
+               }
+
+               cls = PyTuple_GET_ITEM(argtup, 0);
+               if (! PyObject_HasAttrString(cls, "__new__")) {
+                       PyErr_SetString(PicklingError, "args[0] from "
+                               "__newobj__ args has no __new__");
+                       return -1;
+               }
+
+               /* XXX How could ob be NULL? */
+               if (ob != NULL) {
+                       PyObject *ob_dot_class;
 
+                       ob_dot_class = PyObject_GetAttr(ob, __class___str);
+                       if (ob_dot_class == NULL)
+                               PyErr_Clear();
+                       i = ob_dot_class != cls; /* true iff a problem */
+                       Py_XDECREF(ob_dot_class);
+                       if (i) {
+                               PyErr_SetString(PicklingError, "args[0] from "
+                                       "__newobj__ args has the wrong class");
+                               return -1;
+                       }
+               }
+
+               /* Save the class and its __new__ arguments. */
+               if (save(self, cls, 0) < 0)
+                       return -1;
+
+               newargtup = PyTuple_New(n-1);  /* argtup[1:] */
+               if (newargtup == NULL)
+                       return -1;
+               for (i = 1; i < n; ++i) {
+                       PyObject *temp = PyTuple_GET_ITEM(argtup, i);
+                       Py_INCREF(temp);
+                       PyTuple_SET_ITEM(newargtup, i-1, temp);
+               }
+               i = save(self, newargtup, 0) < 0;
+               Py_DECREF(newargtup);
+               if (i < 0)
+                       return -1;
+
+               /* Add NEWOBJ opcode. */
+               if (self->write_func(self, &newobj, 1) < 0)
+                       return -1;
+       }
+       else {
+               /* Not using NEWOBJ. */
+               if (save(self, callable, 0) < 0 ||
+                   save(self, argtup, 0) < 0 ||
+                   self->write_func(self, &reduce, 1) < 0)
+                       return -1;
+       }
+
+       /* Memoize. */
+       /* XXX How can ob be NULL? */
        if (ob != NULL) {
                if (state && !PyDict_Check(state)) {
                        if (put2(self, ob) < 0)
                                return -1;
                }
-               else {
-                       if (put(self, ob) < 0)
+               else if (put(self, ob) < 0)
                                return -1;
-               }
        }
 
-       if (state) {
-               if (save(self, state, 0) < 0)
-                       return -1;
 
-               if (self->write_func(self, &build, 1) < 0)
+        if (listitems && batch_list(self, listitems) < 0)
+               return -1;
+
+        if (dictitems && batch_dict(self, dictitems) < 0)
+               return -1;
+
+       if (state) {
+               if (save(self, state, 0) < 0 ||
+                   self->write_func(self, &build, 1) < 0)
                        return -1;
        }
 
@@ -2223,9 +2333,10 @@ static int
 save(Picklerobject *self, PyObject *args, int pers_save)
 {
        PyTypeObject *type;
-       PyObject *py_ob_id = 0, *__reduce__ = 0, *t = 0, *arg_tup = 0,
-               *callable = 0, *state = 0;
-       int res = -1, tmp, size;
+       PyObject *py_ob_id = 0, *__reduce__ = 0, *t = 0;
+       PyObject *arg_tup;
+       int res = -1;
+       int tmp, size;
 
         if (self->nesting++ > Py_GetRecursionLimit()){
                PyErr_SetString(PyExc_RuntimeError,
@@ -2392,72 +2503,80 @@ save(Picklerobject *self, PyObject *args, int pers_save)
                goto finally;
        }
 
-       assert(t == NULL);      /* just a reminder */
+       /* Get a reduction callable.  This may come from
+        * copy_reg.dispatch_table, the object's __reduce__ method,
+        * the default object.__reduce__, or copy_reg._better_reduce.
+        */
        __reduce__ = PyDict_GetItem(dispatch_table, (PyObject *)type);
        if (__reduce__ != NULL) {
                Py_INCREF(__reduce__);
-               Py_INCREF(args);
-               ARG_TUP(self, args);
-               if (self->arg) {
-                       t = PyObject_Call(__reduce__, self->arg, NULL);
-                       FREE_ARG_TUP(self);
-               }
-               if (! t) goto finally;
        }
        else {
-               __reduce__ = PyObject_GetAttr(args, __reduce___str);
-               if (__reduce__ == NULL)
+               /* Check for a __reduce__ method.
+                * Subtle: get the unbound method from the class, so that
+                * protocol 2 can override the default __reduce__ that all
+                * classes inherit from object.
+                * XXX object.__reduce__ should really be rewritten so that
+                * XXX we don't need to call back into Python code here
+                * XXX (better_reduce), but no time to do that.
+                */
+               __reduce__ = PyObject_GetAttr((PyObject *)type,
+                                             __reduce___str);
+               if (__reduce__ == NULL) {
                        PyErr_Clear();
-               else {
-                       t = PyObject_Call(__reduce__, empty_tuple, NULL);
-                       if (!t)
-                               goto finally;
-               }
-       }
-
-       if (t) {
-               if (PyString_Check(t)) {
-                       res = save_global(self, args, t);
-                       goto finally;
-               }
-
-               if (!PyTuple_Check(t)) {
-                       cPickle_ErrFormat(PicklingError, "Value returned by "
-                                       "%s must be a tuple",
-                                       "O", __reduce__);
+                       PyErr_SetObject(UnpickleableError, args);
                        goto finally;
                }
 
-               size = PyTuple_Size(t);
-
-               if (size != 3 && size != 2) {
-                       cPickle_ErrFormat(PicklingError, "tuple returned by "
-                               "%s must contain only two or three elements",
-                               "O", __reduce__);
-                       goto finally;
+               if (self->proto >= 2 && __reduce__ == object_reduce) {
+                       /* Proto 2 can do better than the default. */
+                       Py_DECREF(__reduce__);
+                       Py_INCREF(better_reduce);
+                       __reduce__ = better_reduce;
                }
+       }
 
-               callable = PyTuple_GET_ITEM(t, 0);
-               arg_tup = PyTuple_GET_ITEM(t, 1);
+       /* Call the reduction callable, setting t to the result. */
+       assert(__reduce__ != NULL);
+       assert(t == NULL);
+       Py_INCREF(args);
+       ARG_TUP(self, args);
+       if (self->arg) {
+               t = PyObject_Call(__reduce__, self->arg, NULL);
+               FREE_ARG_TUP(self);
+       }
+       if (t == NULL)
+               goto finally;
 
-               if (size > 2) {
-                       state = PyTuple_GET_ITEM(t, 2);
-                       if (state == Py_None)
-                               state = NULL;
-               }
+       if (PyString_Check(t)) {
+               res = save_global(self, args, t);
+               goto finally;
+       }
 
-               if (!( PyTuple_Check(arg_tup) || arg_tup==Py_None ))  {
-                       cPickle_ErrFormat(PicklingError, "Second element of "
-                               "tuple returned by %s must be a tuple",
+       if (! PyTuple_Check(t)) {
+               cPickle_ErrFormat(PicklingError, "Value returned by "
+                               "%s must be string or tuple",
                                "O", __reduce__);
-                       goto finally;
-               }
+               goto finally;
+       }
 
-               res = save_reduce(self, callable, arg_tup, state, args);
+       size = PyTuple_Size(t);
+       if (size < 2 || size > 5) {
+               cPickle_ErrFormat(PicklingError, "tuple returned by "
+                       "%s must contain 2 through 5 elements",
+                       "O", __reduce__);
                goto finally;
        }
 
-       PyErr_SetObject(UnpickleableError, args);
+       arg_tup = PyTuple_GET_ITEM(t, 1);
+       if (!(PyTuple_Check(arg_tup) || arg_tup == Py_None))  {
+               cPickle_ErrFormat(PicklingError, "Second element of "
+                       "tuple returned by %s must be a tuple",
+                       "O", __reduce__);
+               goto finally;
+       }
+
+       res = save_reduce(self, t, args);
 
   finally:
        self->nesting--;
@@ -5447,8 +5566,15 @@ init_stuff(PyObject *module_dict)
                                "_extension_cache");
        if (!extension_cache) return -1;
 
+       better_reduce = PyObject_GetAttrString(copy_reg, "_better_reduce");
+       if (!better_reduce) return -1;
+
        Py_DECREF(copy_reg);
 
+       object_reduce = PyObject_GetAttrString((PyObject *)&PyBaseObject_Type,
+                                              "__reduce__");
+       if (object_reduce == NULL) return -1;
+
        if (!(empty_tuple = PyTuple_New(0)))
                return -1;