]> granicus.if.org Git - python/commitdiff
Add cPickle support for PROTO. Duplicated PROTO/LONG1/LONG4 code in
authorTim Peters <tim.peters@gmail.com>
Sun, 2 Feb 2003 16:09:05 +0000 (16:09 +0000)
committerTim Peters <tim.peters@gmail.com>
Sun, 2 Feb 2003 16:09:05 +0000 (16:09 +0000)
the hitherto unknown (to me) noload() cPickle function, which is (a)
something we don't test at all, and (b) pickle.py doesn't have.

Lib/test/pickletester.py
Modules/cPickle.c

index 2a1ca179c1b2bd33cff92d4b569c48a5786a21b4..5ef0cf2e2484b417853dfcba1f72b22616c57d89 100644 (file)
@@ -1,4 +1,6 @@
 import unittest
+import pickle
+
 from test.test_support import TestFailed, have_unicode, TESTFN
 
 # Tests that try a number of pickle protocols should have a
@@ -296,6 +298,25 @@ class AbstractPickleTests(unittest.TestCase):
 
     # Tests for protocol 2
 
+    def test_proto(self):
+        build_none = pickle.NONE + pickle.STOP
+        for proto in protocols:
+            expected = build_none
+            if proto >= 2:
+                expected = pickle.PROTO + chr(proto) + expected
+            p = self.dumps(None, proto)
+            self.assertEqual(p, expected)
+
+        oob = protocols[-1] + 1     # a future protocol
+        badpickle = pickle.PROTO + chr(oob) + build_none
+        try:
+            self.loads(badpickle)
+        except ValueError, detail:
+            self.failUnless(str(detail).startswith(
+                                            "unsupported pickle protocol"))
+        else:
+            self.fail("expected bad protocol number to raise ValueError")
+
     def test_long1(self):
         x = 12345678910111213141516178920L
         s = self.dumps(x, 2)
@@ -314,14 +335,14 @@ class AbstractPickleTests(unittest.TestCase):
         c = (1, 2)
         d = (1, 2, 3)
         e = (1, 2, 3, 4)
-        for proto in 0, 1, 2:
+        for proto in protocols:
             for x in a, b, c, d, e:
                 s = self.dumps(x, proto)
                 y = self.loads(s)
                 self.assertEqual(x, y, (proto, x, s, y))
 
     def test_singletons(self):
-        for proto in 0, 1, 2:
+        for proto in protocols:
             for x in None, False, True:
                 s = self.dumps(x, proto)
                 y = self.loads(s)
index b59f573dfee44ee1650ad1a71604c3a962355127..43a8d331029a0bcf1d1d1461bb676374f58d6a37 100644 (file)
@@ -2213,13 +2213,22 @@ dump(Picklerobject *self, PyObject *args)
 {
        static char stop = STOP;
 
+       if (self->proto >= 2) {
+               char bytes[2];
+
+               bytes[0] = PROTO;
+               bytes[1] = CURRENT_PROTOCOL_NUMBER;
+               if (self->write_func(self, bytes, 2) < 0)
+                       return -1;
+       }
+
        if (save(self, args, 0) < 0)
                return -1;
 
-       if ((*self->write_func)(self, &stop, 1) < 0)
+       if (self->write_func(self, &stop, 1) < 0)
                return -1;
 
-       if ((*self->write_func)(self, NULL, 0) < 0)
+       if (self->write_func(self, NULL, 0) < 0)
                return -1;
 
        return 0;
@@ -3870,6 +3879,31 @@ load_reduce(Unpicklerobject *self)
        return 0;
 }
 
+/* Just raises an error if we don't know the protocol specified.  PROTO
+ * is the first opcode for protocols >= 2.
+ */
+static int
+load_proto(Unpicklerobject *self)
+{
+       int i;
+       char *protobyte;
+
+       i = self->read_func(self, &protobyte, 1);
+       if (i < 0)
+               return -1;
+
+       i = calc_binint(protobyte, 1);
+       /* No point checking for < 0, since calc_binint returns an unsigned
+        * int when chewing on 1 byte.
+        */
+       assert(i >= 0);
+       if (i <= CURRENT_PROTOCOL_NUMBER)
+               return 0;
+
+       PyErr_Format(PyExc_ValueError, "unsupported pickle protocol: %d", i);
+       return -1;
+}
+
 static PyObject *
 load(Unpicklerobject *self)
 {
@@ -4099,6 +4133,11 @@ load(Unpicklerobject *self)
                                break;
                        continue;
 
+               case PROTO:
+                       if (load_proto(self) < 0)
+                               break;
+                       continue;
+
                case '\0':
                        /* end of file */
                        PyErr_SetNone(PyExc_EOFError);
@@ -4227,6 +4266,16 @@ noload(Unpicklerobject *self)
                                break;
                        continue;
 
+               case LONG1:
+                       if (load_counted_long(self, 1) < 0)
+                               break;
+                       continue;
+
+               case LONG4:
+                       if (load_counted_long(self, 4) < 0)
+                               break;
+                       continue;
+
                case FLOAT:
                        if (load_float(self) < 0)
                                break;
@@ -4402,6 +4451,11 @@ noload(Unpicklerobject *self)
                                break;
                        continue;
 
+               case PROTO:
+                       if (load_proto(self) < 0)
+                               break;
+                       continue;
+
                default:
                        cPickle_ErrFormat(UnpicklingError,
                                          "invalid load key, '%s'.",