]> granicus.if.org Git - python/commitdiff
Add support for copy_reg.dispatch_table.
authorGuido van Rossum <guido@python.org>
Fri, 7 Feb 2003 17:30:18 +0000 (17:30 +0000)
committerGuido van Rossum <guido@python.org>
Fri, 7 Feb 2003 17:30:18 +0000 (17:30 +0000)
Rewrote copy() and deepcopy() without avoidable try/except statements;
getattr(x, name, None) or dict.get() are much faster than try/except.

Lib/copy.py
Lib/test/test_copy.py

index 4133a1f1dc324f43ae632947414e123236432689..c1c0ec0cbdc7afdd683b6a8ee2e356572b82ceb9 100644 (file)
@@ -48,10 +48,8 @@ __getstate__() and __setstate__().  See the documentation for module
 "pickle" for information on these methods.
 """
 
-# XXX need to support copy_reg here too...
-
 import types
-from copy_reg import _better_reduce
+from copy_reg import _better_reduce, dispatch_table
 
 class Error(Exception):
     pass
@@ -70,25 +68,25 @@ def copy(x):
     See the module's __doc__ string for more info.
     """
 
-    try:
-        copierfunction = _copy_dispatch[type(x)]
-    except KeyError:
-        try:
-            copier = x.__copy__
-        except AttributeError:
-            try:
-                reductor = x.__class__.__reduce__
-                if reductor == object.__reduce__:
-                    reductor = _better_reduce
-            except AttributeError:
-                raise Error("un(shallow)copyable object of type %s" % type(x))
-            else:
-                y = _reconstruct(x, reductor(x), 0)
-        else:
-            y = copier()
-    else:
-        y = copierfunction(x)
-    return y
+    cls = type(x)
+
+    copier = _copy_dispatch.get(cls)
+    if copier:
+        return copier(x)
+
+    copier = getattr(cls, "__copy__", None)
+    if copier:
+        return copier(x)
+
+    reductor = dispatch_table.get(cls)
+    if not reductor:
+        reductor = getattr(cls, "__reduce__", None)
+        if reductor == object.__reduce__:
+            reductor = _better_reduce
+        elif not reductor:
+            raise Error("un(shallow)copyable object of type %s" % cls)
+
+    return _reconstruct(x, reductor(x), 0)
     
 
 _copy_dispatch = d = {}
@@ -153,7 +151,7 @@ d[types.InstanceType] = _copy_inst
 
 del d
 
-def deepcopy(x, memo = None):
+def deepcopy(x, memo=None, _nil=[]):
     """Deep copy operation on arbitrary Python objects.
 
     See the module's __doc__ string for more info.
@@ -161,35 +159,39 @@ def deepcopy(x, memo = None):
 
     if memo is None:
         memo = {}
+
     d = id(x)
-    if d in memo:
-        return memo[d]
-    try:
-        copierfunction = _deepcopy_dispatch[type(x)]
-    except KeyError:
+    y = memo.get(d, _nil)
+    if y is not _nil:
+        return y
+
+    cls = type(x)
+
+    copier = _deepcopy_dispatch.get(cls)
+    if copier:
+        y = copier(x, memo)
+    else:
         try:
-            issc = issubclass(type(x), type)
-        except TypeError:
+            issc = issubclass(cls, type)
+        except TypeError: # cls is not a class (old Boost; see SF #502085)
             issc = 0
         if issc:
-            y = _deepcopy_dispatch[type](x, memo)
+            copier = _deepcopy_atomic
         else:
-            try:
-                copier = x.__deepcopy__
-            except AttributeError:
-                try:
-                    reductor = x.__class__.__reduce__
-                    if reductor == object.__reduce__:
-                        reductor = _better_reduce
-                except AttributeError:
-                    raise Error("un(shallow)copyable object of type %s" %
-                                type(x))
-                else:
-                    y = _reconstruct(x, reductor(x), 1, memo)
-            else:
-                y = copier(memo)
-    else:
-        y = copierfunction(x, memo)
+            copier = getattr(cls, "__deepcopy__", None)
+
+        if copier:
+            y = copier(x, memo)
+        else:
+            reductor = dispatch_table.get(cls)
+            if not reductor:
+                reductor = getattr(cls, "__reduce__", None)
+                if reductor == object.__reduce__:
+                    reductor = _better_reduce
+                elif not reductor:
+                    raise Error("un(deep)copyable object of type %s" % cls)
+            y = _reconstruct(x, reductor(x), 1, memo)
+
     memo[d] = y
     _keep_alive(x, memo) # Make sure x lives at least as long as d
     return y
@@ -380,7 +382,7 @@ def _test():
         def __setstate__(self, state):
             for key, value in state.iteritems():
                 setattr(self, key, value)
-        def __deepcopy__(self, memo = None):
+        def __deepcopy__(self, memo=None):
             new = self.__class__(deepcopy(self.arg, memo))
             new.a = self.a
             return new
index c97d54d749102b5a01185094aa825948a823cc93..35ce46a5232813f3fef4f6869d6100dbe9f3c0d1 100644 (file)
@@ -2,6 +2,7 @@
 
 import sys
 import copy
+import copy_reg
 
 import unittest
 from test import test_support
@@ -32,6 +33,19 @@ class TestCopy(unittest.TestCase):
         self.assertEqual(y.__class__, x.__class__)
         self.assertEqual(y.foo, x.foo)
 
+    def test_copy_registry(self):
+        class C(object):
+            def __new__(cls, foo):
+                obj = object.__new__(cls)
+                obj.foo = foo
+                return obj
+        def pickle_C(obj):
+            return (C, (obj.foo,))
+        x = C(42)
+        self.assertRaises(TypeError, copy.copy, x)
+        copy_reg.pickle(C, pickle_C, C)
+        y = copy.copy(x)
+
     def test_copy_reduce(self):
         class C(object):
             def __reduce__(self):
@@ -182,6 +196,19 @@ class TestCopy(unittest.TestCase):
         self.assertEqual(y.__class__, x.__class__)
         self.assertEqual(y.foo, x.foo)
 
+    def test_deepcopy_registry(self):
+        class C(object):
+            def __new__(cls, foo):
+                obj = object.__new__(cls)
+                obj.foo = foo
+                return obj
+        def pickle_C(obj):
+            return (C, (obj.foo,))
+        x = C(42)
+        self.assertRaises(TypeError, copy.deepcopy, x)
+        copy_reg.pickle(C, pickle_C, C)
+        y = copy.deepcopy(x)
+
     def test_deepcopy_reduce(self):
         class C(object):
             def __reduce__(self):