]> granicus.if.org Git - python/commitdiff
bpo-28416: Break reference cycles in Pickler and Unpickler subclasses (#4080)
authorSerhiy Storchaka <storchaka@gmail.com>
Thu, 30 Nov 2017 20:48:31 +0000 (22:48 +0200)
committerGitHub <noreply@github.com>
Thu, 30 Nov 2017 20:48:31 +0000 (22:48 +0200)
with the persistent_id() and persistent_load() methods.

Lib/test/test_pickle.py
Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst [new file with mode: 0644]
Modules/_pickle.c

index ee71c632733563846f29ad2e5032d0e9d958339b..895ed48df1d8d20c2ac6b1c3f8f68e994143c483 100644 (file)
@@ -6,6 +6,7 @@ import io
 import collections
 import struct
 import sys
+import weakref
 
 import unittest
 from test import support
@@ -117,6 +118,66 @@ class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
     pickler = pickle._Pickler
     unpickler = pickle._Unpickler
 
+    @support.cpython_only
+    def test_pickler_reference_cycle(self):
+        def check(Pickler):
+            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                f = io.BytesIO()
+                pickler = Pickler(f, proto)
+                pickler.dump('abc')
+                self.assertEqual(self.loads(f.getvalue()), 'abc')
+            pickler = Pickler(io.BytesIO())
+            self.assertEqual(pickler.persistent_id('def'), 'def')
+            r = weakref.ref(pickler)
+            del pickler
+            self.assertIsNone(r())
+
+        class PersPickler(self.pickler):
+            def persistent_id(subself, obj):
+                return obj
+        check(PersPickler)
+
+        class PersPickler(self.pickler):
+            @classmethod
+            def persistent_id(cls, obj):
+                return obj
+        check(PersPickler)
+
+        class PersPickler(self.pickler):
+            @staticmethod
+            def persistent_id(obj):
+                return obj
+        check(PersPickler)
+
+    @support.cpython_only
+    def test_unpickler_reference_cycle(self):
+        def check(Unpickler):
+            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                unpickler = Unpickler(io.BytesIO(self.dumps('abc', proto)))
+                self.assertEqual(unpickler.load(), 'abc')
+            unpickler = Unpickler(io.BytesIO())
+            self.assertEqual(unpickler.persistent_load('def'), 'def')
+            r = weakref.ref(unpickler)
+            del unpickler
+            self.assertIsNone(r())
+
+        class PersUnpickler(self.unpickler):
+            def persistent_load(subself, pid):
+                return pid
+        check(PersUnpickler)
+
+        class PersUnpickler(self.unpickler):
+            @classmethod
+            def persistent_load(cls, pid):
+                return pid
+        check(PersUnpickler)
+
+        class PersUnpickler(self.unpickler):
+            @staticmethod
+            def persistent_load(pid):
+                return pid
+        check(PersUnpickler)
+
 
 class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
 
@@ -197,7 +258,7 @@ if has_c_implementation:
         check_sizeof = support.check_sizeof
 
         def test_pickler(self):
-            basesize = support.calcobjsize('5P2n3i2n3iP')
+            basesize = support.calcobjsize('6P2n3i2n3iP')
             p = _pickle.Pickler(io.BytesIO())
             self.assertEqual(object.__sizeof__(p), basesize)
             MT_size = struct.calcsize('3nP0n')
@@ -214,7 +275,7 @@ if has_c_implementation:
                 0)  # Write buffer is cleared after every dump().
 
         def test_unpickler(self):
-            basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i')
+            basesize = support.calcobjsize('2P2n2P 2P2n2i5P 2P3n6P2n2i')
             unpickler = _pickle.Unpickler
             P = struct.calcsize('P')  # Size of memo table entry.
             n = struct.calcsize('n')  # Size of mark table entry.
diff --git a/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst b/Misc/NEWS.d/next/Library/2017-10-23-12-05-33.bpo-28416.Ldnw8X.rst
new file mode 100644 (file)
index 0000000..b101482
--- /dev/null
@@ -0,0 +1,3 @@
+Instances of pickle.Pickler subclass with the persistent_id() method and
+pickle.Unpickler subclass with the persistent_load() method no longer create
+reference cycles.
index 943c70112b79175acc30d87457a1c7f70d1ecf06..da915efd831595bacfb6c7b175fb8c750bab4ce6 100644 (file)
@@ -360,6 +360,69 @@ _Pickle_FastCall(PyObject *func, PyObject *obj)
 
 /*************************************************************************/
 
+/* Retrieve and deconstruct a method for avoiding a reference cycle
+   (pickler -> bound method of pickler -> pickler) */
+static int
+init_method_ref(PyObject *self, _Py_Identifier *name,
+                PyObject **method_func, PyObject **method_self)
+{
+    PyObject *func, *func2;
+
+    /* *method_func and *method_self should be consistent.  All refcount decrements
+       should be occurred after setting *method_self and *method_func. */
+    func = _PyObject_GetAttrId(self, name);
+    if (func == NULL) {
+        *method_self = NULL;
+        Py_CLEAR(*method_func);
+        if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+            return -1;
+        }
+        PyErr_Clear();
+        return 0;
+    }
+
+    if (PyMethod_Check(func) && PyMethod_GET_SELF(func) == self) {
+        /* Deconstruct a bound Python method */
+        func2 = PyMethod_GET_FUNCTION(func);
+        Py_INCREF(func2);
+        *method_self = self; /* borrowed */
+        Py_XSETREF(*method_func, func2);
+        Py_DECREF(func);
+        return 0;
+    }
+    else {
+        *method_self = NULL;
+        Py_XSETREF(*method_func, func);
+        return 0;
+    }
+}
+
+/* Bind a method if it was deconstructed */
+static PyObject *
+reconstruct_method(PyObject *func, PyObject *self)
+{
+    if (self) {
+        return PyMethod_New(func, self);
+    }
+    else {
+        Py_INCREF(func);
+        return func;
+    }
+}
+
+static PyObject *
+call_method(PyObject *func, PyObject *self, PyObject *obj)
+{
+    if (self) {
+        return PyObject_CallFunctionObjArgs(func, self, obj, NULL);
+    }
+    else {
+        return PyObject_CallFunctionObjArgs(func, obj, NULL);
+    }
+}
+
+/*************************************************************************/
+
 /* Internal data type used as the unpickling stack. */
 typedef struct {
     PyObject_VAR_HEAD
@@ -552,6 +615,8 @@ typedef struct PicklerObject {
                                    objects to support self-referential objects
                                    pickling. */
     PyObject *pers_func;        /* persistent_id() method, can be NULL */
+    PyObject *pers_func_self;   /* borrowed reference to self if pers_func
+                                   is an unbound method, NULL otherwise */
     PyObject *dispatch_table;   /* private dispatch_table, can be NULL */
 
     PyObject *write;            /* write() method of the output stream. */
@@ -590,6 +655,8 @@ typedef struct UnpicklerObject {
     Py_ssize_t memo_len;        /* Number of objects in the memo */
 
     PyObject *pers_func;        /* persistent_load() method, can be NULL. */
+    PyObject *pers_func_self;   /* borrowed reference to self if pers_func
+                                   is an unbound method, NULL otherwise */
 
     Py_buffer buffer;
     char *input_buffer;
@@ -3444,7 +3511,7 @@ save_type(PicklerObject *self, PyObject *obj)
 }
 
 static int
-save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
+save_pers(PicklerObject *self, PyObject *obj)
 {
     PyObject *pid = NULL;
     int status = 0;
@@ -3452,8 +3519,7 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
     const char persid_op = PERSID;
     const char binpersid_op = BINPERSID;
 
-    Py_INCREF(obj);
-    pid = _Pickle_FastCall(func, obj);
+    pid = call_method(self->pers_func, self->pers_func_self, obj);
     if (pid == NULL)
         return -1;
 
@@ -3831,7 +3897,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
              0   if it did nothing successfully;
              1   if a persistent id was saved.
          */
-        if ((status = save_pers(self, obj, self->pers_func)) != 0)
+        if ((status = save_pers(self, obj)) != 0)
             goto done;
     }
 
@@ -4246,13 +4312,10 @@ _pickle_Pickler___init___impl(PicklerObject *self, PyObject *file,
     self->fast_nesting = 0;
     self->fast_memo = NULL;
 
-    self->pers_func = _PyObject_GetAttrId((PyObject *)self,
-                                          &PyId_persistent_id);
-    if (self->pers_func == NULL) {
-        if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
-            return -1;
-        }
-        PyErr_Clear();
+    if (init_method_ref((PyObject *)self, &PyId_persistent_id,
+                        &self->pers_func, &self->pers_func_self) < 0)
+    {
+        return -1;
     }
 
     self->dispatch_table = _PyObject_GetAttrId((PyObject *)self,
@@ -4519,11 +4582,11 @@ Pickler_set_memo(PicklerObject *self, PyObject *obj)
 static PyObject *
 Pickler_get_persid(PicklerObject *self)
 {
-    if (self->pers_func == NULL)
+    if (self->pers_func == NULL) {
         PyErr_SetString(PyExc_AttributeError, "persistent_id");
-    else
-        Py_INCREF(self->pers_func);
-    return self->pers_func;
+        return NULL;
+    }
+    return reconstruct_method(self->pers_func, self->pers_func_self);
 }
 
 static int
@@ -4540,6 +4603,7 @@ Pickler_set_persid(PicklerObject *self, PyObject *value)
         return -1;
     }
 
+    self->pers_func_self = NULL;
     Py_INCREF(value);
     Py_XSETREF(self->pers_func, value);
 
@@ -5489,7 +5553,7 @@ load_stack_global(UnpicklerObject *self)
 static int
 load_persid(UnpicklerObject *self)
 {
-    PyObject *pid;
+    PyObject *pid, *obj;
     Py_ssize_t len;
     char *s;
 
@@ -5509,13 +5573,12 @@ load_persid(UnpicklerObject *self)
             return -1;
         }
 
-        /* This does not leak since _Pickle_FastCall() steals the reference
-           to pid first. */
-        pid = _Pickle_FastCall(self->pers_func, pid);
-        if (pid == NULL)
+        obj = call_method(self->pers_func, self->pers_func_self, pid);
+        Py_DECREF(pid);
+        if (obj == NULL)
             return -1;
 
-        PDATA_PUSH(self->stack, pid, -1);
+        PDATA_PUSH(self->stack, obj, -1);
         return 0;
     }
     else {
@@ -5530,20 +5593,19 @@ load_persid(UnpicklerObject *self)
 static int
 load_binpersid(UnpicklerObject *self)
 {
-    PyObject *pid;
+    PyObject *pid, *obj;
 
     if (self->pers_func) {
         PDATA_POP(self->stack, pid);
         if (pid == NULL)
             return -1;
 
-        /* This does not leak since _Pickle_FastCall() steals the
-           reference to pid first. */
-        pid = _Pickle_FastCall(self->pers_func, pid);
-        if (pid == NULL)
+        obj = call_method(self->pers_func, self->pers_func_self, pid);
+        Py_DECREF(pid);
+        if (obj == NULL)
             return -1;
 
-        PDATA_PUSH(self->stack, pid, -1);
+        PDATA_PUSH(self->stack, obj, -1);
         return 0;
     }
     else {
@@ -6690,13 +6752,10 @@ _pickle_Unpickler___init___impl(UnpicklerObject *self, PyObject *file,
 
     self->fix_imports = fix_imports;
 
-    self->pers_func = _PyObject_GetAttrId((PyObject *)self,
-                                          &PyId_persistent_load);
-    if (self->pers_func == NULL) {
-        if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
-            return -1;
-        }
-        PyErr_Clear();
+    if (init_method_ref((PyObject *)self, &PyId_persistent_load,
+                        &self->pers_func, &self->pers_func_self) < 0)
+    {
+        return -1;
     }
 
     self->stack = (Pdata *)Pdata_New();
@@ -6983,11 +7042,11 @@ Unpickler_set_memo(UnpicklerObject *self, PyObject *obj)
 static PyObject *
 Unpickler_get_persload(UnpicklerObject *self)
 {
-    if (self->pers_func == NULL)
+    if (self->pers_func == NULL) {
         PyErr_SetString(PyExc_AttributeError, "persistent_load");
-    else
-        Py_INCREF(self->pers_func);
-    return self->pers_func;
+        return NULL;
+    }
+    return reconstruct_method(self->pers_func, self->pers_func_self);
 }
 
 static int
@@ -7005,6 +7064,7 @@ Unpickler_set_persload(UnpicklerObject *self, PyObject *value)
         return -1;
     }
 
+    self->pers_func_self = NULL;
     Py_INCREF(value);
     Py_XSETREF(self->pers_func, value);