granicus.if.org Git - python/commitdiff
Issue #17711: Fixed unpickling by the persistent ID with protocol 0.
author: Serhiy Storchaka <storchaka@gmail.com>
Sun, 17 Jul 2016 08:24:17 +0000 (11:24 +0300)
committer: Serhiy Storchaka <storchaka@gmail.com>
Sun, 17 Jul 2016 08:24:17 +0000 (11:24 +0300)
Original patch by Alexandre Vassalotti.

Lib/pickle.py
Lib/test/pickletester.py
Lib/test/test_pickle.py
Misc/NEWS
Modules/_pickle.c

index 7760425e76d631af828de543c0a2d8dac2646669..040ecb245fdfb3c7e0296d5318659574dbee5096 100644 (file)
@@ -529,7 +529,11 @@ class _Pickler:
             self.save(pid, save_persistent_id=False)
             self.write(BINPERSID)
         else:
-            self.write(PERSID + str(pid).encode("ascii") + b'\n')
+            try:
+                self.write(PERSID + str(pid).encode("ascii") + b'\n')
+            except UnicodeEncodeError:
+                raise PicklingError(
+                    "persistent IDs in protocol 0 must be ASCII strings")
 
     def save_reduce(self, func, args, state=None, listitems=None,
                     dictitems=None, obj=None):
@@ -1075,7 +1079,11 @@ class _Unpickler:
     dispatch[FRAME[0]] = load_frame
 
     def load_persid(self):
-        pid = self.readline()[:-1].decode("ascii")
+        try:
+            pid = self.readline()[:-1].decode("ascii")
+        except UnicodeDecodeError:
+            raise UnpicklingError(
+                "persistent IDs in protocol 0 must be ASCII strings")
         self.append(self.persistent_load(pid))
     dispatch[PERSID[0]] = load_persid
 
index f252a0a10f3344d97fc67c772e44a1f57e6ab590..7922b54f03e77fd52897308bebfa5ce9e65f5fd5 100644 (file)
@@ -2629,6 +2629,35 @@ class AbstractPersistentPicklerTests(unittest.TestCase):
             self.assertEqual(self.load_false_count, 1)
 
 
+class AbstractIdentityPersistentPicklerTests(unittest.TestCase):
+
+    def persistent_id(self, obj):
+        return obj
+
+    def persistent_load(self, pid):
+        return pid
+
+    def _check_return_correct_type(self, obj, proto):
+        unpickled = self.loads(self.dumps(obj, proto))
+        self.assertIsInstance(unpickled, type(obj))
+        self.assertEqual(unpickled, obj)
+
+    def test_return_correct_type(self):
+        for proto in protocols:
+            # Protocol 0 supports only ASCII strings.
+            if proto == 0:
+                self._check_return_correct_type("abc", 0)
+            else:
+                for obj in [b"abc\n", "abc\n", -1, -1.1 * 0.1, str]:
+                    self._check_return_correct_type(obj, proto)
+
+    def test_protocol0_is_ascii_only(self):
+        non_ascii_str = "\N{EMPTY SET}"
+        self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0)
+        pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.'
+        self.assertRaises(pickle.UnpicklingError, self.loads, pickled)
+
+
 class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
 
     pickler_class = None
index ee7a667d0614f2975c44e8f17dc0dd4ac99beda6..d467d52badb9dae27c1964be10178dc98c3feca5 100644 (file)
@@ -14,6 +14,7 @@ from test.pickletester import AbstractUnpickleTests
 from test.pickletester import AbstractPickleTests
 from test.pickletester import AbstractPickleModuleTests
 from test.pickletester import AbstractPersistentPicklerTests
+from test.pickletester import AbstractIdentityPersistentPicklerTests
 from test.pickletester import AbstractPicklerUnpicklerObjectTests
 from test.pickletester import AbstractDispatchTableTests
 from test.pickletester import BigmemPickleTests
@@ -82,10 +83,7 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
         return pickle.loads(buf, **kwds)
 
 
-class PyPersPicklerTests(AbstractPersistentPicklerTests):
-
-    pickler = pickle._Pickler
-    unpickler = pickle._Unpickler
+class PersistentPicklerUnpicklerMixin(object):
 
     def dumps(self, arg, proto=None):
         class PersPickler(self.pickler):
@@ -94,8 +92,7 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
         f = io.BytesIO()
         p = PersPickler(f, proto)
         p.dump(arg)
-        f.seek(0)
-        return f.read()
+        return f.getvalue()
 
     def loads(self, buf, **kwds):
         class PersUnpickler(self.unpickler):
@@ -106,6 +103,20 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
         return u.load()
 
 
+class PyPersPicklerTests(AbstractPersistentPicklerTests,
+                         PersistentPicklerUnpicklerMixin):
+
+    pickler = pickle._Pickler
+    unpickler = pickle._Unpickler
+
+
+class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
+                           PersistentPicklerUnpicklerMixin):
+
+    pickler = pickle._Pickler
+    unpickler = pickle._Unpickler
+
+
 class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
 
     pickler_class = pickle._Pickler
@@ -144,6 +155,10 @@ if has_c_implementation:
         pickler = _pickle.Pickler
         unpickler = _pickle.Unpickler
 
+    class CIdPersPicklerTests(PyIdPersPicklerTests):
+        pickler = _pickle.Pickler
+        unpickler = _pickle.Unpickler
+
     class CDumpPickle_LoadPickle(PyPicklerTests):
         pickler = _pickle.Pickler
         unpickler = pickle._Unpickler
@@ -409,11 +424,13 @@ class CompatPickleTests(unittest.TestCase):
 
 
 def test_main():
-    tests = [PickleTests, PyUnpicklerTests, PyPicklerTests, PyPersPicklerTests,
+    tests = [PickleTests, PyUnpicklerTests, PyPicklerTests,
+             PyPersPicklerTests, PyIdPersPicklerTests,
              PyDispatchTableTests, PyChainDispatchTableTests,
              CompatPickleTests]
     if has_c_implementation:
-        tests.extend([CUnpicklerTests, CPicklerTests, CPersPicklerTests,
+        tests.extend([CUnpicklerTests, CPicklerTests,
+                      CPersPicklerTests, CIdPersPicklerTests,
                       CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
                       PyPicklerUnpicklerObjectTests,
                       CPicklerUnpicklerObjectTests,
index 75ca9b0133c9213b17d53b3a5ce72403349486e3..a79cfaf88a7b45da62f38de1fc501596e9eae191 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -24,6 +24,9 @@ Core and Builtins
 Library
 -------
 
+- Issue #17711: Fixed unpickling by the persistent ID with protocol 0.
+  Original patch by Alexandre Vassalotti.
+
 - Issue #27522: Avoid an unintentional reference cycle in email.feedparser.
 
 - Issue #26844: Fix error message for imp.find_module() to refer to 'path'
index c5c963ef6b7640b0ce2faef6a0eb42cd878f8594..3c21b6aac415db845201182dd9f5903eae1433a9 100644 (file)
@@ -3406,26 +3406,30 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
                 goto error;
         }
         else {
-            PyObject *pid_str = NULL;
-            char *pid_ascii_bytes;
-            Py_ssize_t size;
+            PyObject *pid_str;
 
             pid_str = PyObject_Str(pid);
             if (pid_str == NULL)
                 goto error;
 
-            /* XXX: Should it check whether the persistent id only contains
-               ASCII characters? And what if the pid contains embedded
+            /* XXX: Should it check whether the pid contains embedded
                newlines? */
-            pid_ascii_bytes = _PyUnicode_AsStringAndSize(pid_str, &size);
-            Py_DECREF(pid_str);
-            if (pid_ascii_bytes == NULL)
+            if (!PyUnicode_IS_ASCII(pid_str)) {
+                PyErr_SetString(_Pickle_GetGlobalState()->PicklingError,
+                                "persistent IDs in protocol 0 must be "
+                                "ASCII strings");
+                Py_DECREF(pid_str);
                 goto error;
+            }
 
             if (_Pickler_Write(self, &persid_op, 1) < 0 ||
-                _Pickler_Write(self, pid_ascii_bytes, size) < 0 ||
-                _Pickler_Write(self, "\n", 1) < 0)
+                _Pickler_Write(self, PyUnicode_DATA(pid_str),
+                               PyUnicode_GET_LENGTH(pid_str)) < 0 ||
+                _Pickler_Write(self, "\n", 1) < 0) {
+                Py_DECREF(pid_str);
                 goto error;
+            }
+            Py_DECREF(pid_str);
         }
         status = 1;
     }
@@ -5389,9 +5393,15 @@ load_persid(UnpicklerObject *self)
         if (len < 1)
             return bad_readline();
 
-        pid = PyBytes_FromStringAndSize(s, len - 1);
-        if (pid == NULL)
+        pid = PyUnicode_DecodeASCII(s, len - 1, "strict");
+        if (pid == NULL) {
+            if (PyErr_ExceptionMatches(PyExc_UnicodeDecodeError)) {
+                PyErr_SetString(_Pickle_GetGlobalState()->UnpicklingError,
+                                "persistent IDs in protocol 0 must be "
+                                "ASCII strings");
+            }
             return -1;
+        }
 
         /* This does not leak since _Pickle_FastCall() steals the reference
            to pid first. */