]> granicus.if.org Git - python/commitdiff
Issue #27137: align Python & C implementations of functools.partial
authorNick Coghlan <ncoghlan@gmail.com>
Sat, 10 Sep 2016 10:00:02 +0000 (20:00 +1000)
committerNick Coghlan <ncoghlan@gmail.com>
Sat, 10 Sep 2016 10:00:02 +0000 (20:00 +1000)
The pure Python fallback implementation of functools.partial
now matches the behaviour of its accelerated C counterpart for
subclassing, pickling and text representation purposes.

Patch by Emanuel Barry and Serhiy Storchaka.

Lib/functools.py
Lib/test/test_functools.py
Misc/NEWS
Modules/_functoolsmodule.c

index 214523cbc24cfccc9fc74116153ec87cce097491..9845df224df4f2c7e4550f3a0b0c5d805036d138 100644 (file)
@@ -21,6 +21,7 @@ from abc import get_cache_token
 from collections import namedtuple
 from types import MappingProxyType
 from weakref import WeakKeyDictionary
+from reprlib import recursive_repr
 try:
     from _thread import RLock
 except ImportError:
@@ -237,26 +238,83 @@ except ImportError:
 ################################################################################
 
 # Purely functional, no descriptor behaviour
-def partial(func, *args, **keywords):
+class partial:
     """New function with partial application of the given arguments
     and keywords.
     """
-    if hasattr(func, 'func'):
-        args = func.args + args
-        tmpkw = func.keywords.copy()
-        tmpkw.update(keywords)
-        keywords = tmpkw
-        del tmpkw
-        func = func.func
-
-    def newfunc(*fargs, **fkeywords):
-        newkeywords = keywords.copy()
-        newkeywords.update(fkeywords)
-        return func(*(args + fargs), **newkeywords)
-    newfunc.func = func
-    newfunc.args = args
-    newfunc.keywords = keywords
-    return newfunc
+
+    __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
+
+    def __new__(*args, **keywords):
+        if not args:
+            raise TypeError("descriptor '__new__' of partial needs an argument")
+        if len(args) < 2:
+            raise TypeError("type 'partial' takes at least one argument")
+        cls, func, *args = args
+        if not callable(func):
+            raise TypeError("the first argument must be callable")
+        args = tuple(args)
+
+        if hasattr(func, "func"):
+            args = func.args + args
+            tmpkw = func.keywords.copy()
+            tmpkw.update(keywords)
+            keywords = tmpkw
+            del tmpkw
+            func = func.func
+
+        self = super(partial, cls).__new__(cls)
+
+        self.func = func
+        self.args = args
+        self.keywords = keywords
+        return self
+
+    def __call__(*args, **keywords):
+        if not args:
+            raise TypeError("descriptor '__call__' of partial needs an argument")
+        self, *args = args
+        newkeywords = self.keywords.copy()
+        newkeywords.update(keywords)
+        return self.func(*self.args, *args, **newkeywords)
+
+    @recursive_repr()
+    def __repr__(self):
+        qualname = type(self).__qualname__
+        args = [repr(self.func)]
+        args.extend(repr(x) for x in self.args)
+        args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
+        if type(self).__module__ == "functools":
+            return f"functools.{qualname}({', '.join(args)})"
+        return f"{qualname}({', '.join(args)})"
+
+    def __reduce__(self):
+        return type(self), (self.func,), (self.func, self.args,
+               self.keywords or None, self.__dict__ or None)
+
+    def __setstate__(self, state):
+        if not isinstance(state, tuple):
+            raise TypeError("argument to __setstate__ must be a tuple")
+        if len(state) != 4:
+            raise TypeError(f"expected 4 items in state, got {len(state)}")
+        func, args, kwds, namespace = state
+        if (not callable(func) or not isinstance(args, tuple) or
+           (kwds is not None and not isinstance(kwds, dict)) or
+           (namespace is not None and not isinstance(namespace, dict))):
+            raise TypeError("invalid partial state")
+
+        args = tuple(args) # just in case it's a subclass
+        if kwds is None:
+            kwds = {}
+        elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
+            kwds = dict(kwds)
+        if namespace is None:
+            namespace = {}
+
+        self.__dict__ = namespace
+        self.func = func
+        self.args = args
+        self.keywords = kwds
 
 try:
     from _functools import partial
index 40f2234a7f2e05a5c78f50a208875d316c20ceef..fa66510bf1c8c038da1e7c6c547ad6e37eb1ba29 100644 (file)
@@ -8,6 +8,7 @@ import sys
 from test import support
 import unittest
 from weakref import proxy
+import contextlib
 try:
     import threading
 except ImportError:
@@ -20,6 +21,14 @@ c_functools = support.import_fresh_module('functools', fresh=['_functools'])
 
 decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
 
+@contextlib.contextmanager
+def replaced_module(name, replacement):
+    original_module = sys.modules[name]
+    sys.modules[name] = replacement
+    try:
+        yield
+    finally:
+        sys.modules[name] = original_module
 
 def capture(*args, **kw):
     """capture all positional and keyword arguments"""
@@ -167,58 +176,35 @@ class TestPartial:
         p2.new_attr = 'spam'
         self.assertEqual(p2.new_attr, 'spam')
 
-
-@unittest.skipUnless(c_functools, 'requires the C _functools module')
-class TestPartialC(TestPartial, unittest.TestCase):
-    if c_functools:
-        partial = c_functools.partial
-
-    def test_attributes_unwritable(self):
-        # attributes should not be writable
-        p = self.partial(capture, 1, 2, a=10, b=20)
-        self.assertRaises(AttributeError, setattr, p, 'func', map)
-        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
-        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
-
-        p = self.partial(hex)
-        try:
-            del p.__dict__
-        except TypeError:
-            pass
-        else:
-            self.fail('partial object allowed __dict__ to be deleted')
-
     def test_repr(self):
         args = (object(), object())
         args_repr = ', '.join(repr(a) for a in args)
         kwargs = {'a': object(), 'b': object()}
         kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                         'b={b!r}, a={a!r}'.format_map(kwargs)]
-        if self.partial is c_functools.partial:
+        if self.partial in (c_functools.partial, py_functools.partial):
             name = 'functools.partial'
         else:
             name = self.partial.__name__
 
         f = self.partial(capture)
-        self.assertEqual('{}({!r})'.format(name, capture),
-                         repr(f))
+        self.assertEqual(f'{name}({capture!r})', repr(f))
 
         f = self.partial(capture, *args)
-        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
-                         repr(f))
+        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
 
         f = self.partial(capture, **kwargs)
         self.assertIn(repr(f),
-                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
+                      [f'{name}({capture!r}, {kwargs_repr})'
                        for kwargs_repr in kwargs_reprs])
 
         f = self.partial(capture, *args, **kwargs)
         self.assertIn(repr(f),
-                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
+                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                        for kwargs_repr in kwargs_reprs])
 
     def test_recursive_repr(self):
-        if self.partial is c_functools.partial:
+        if self.partial in (c_functools.partial, py_functools.partial):
             name = 'functools.partial'
         else:
             name = self.partial.__name__
@@ -226,30 +212,31 @@ class TestPartialC(TestPartial, unittest.TestCase):
         f = self.partial(capture)
         f.__setstate__((f, (), {}, {}))
         try:
-            self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
+            self.assertEqual(repr(f), '%s(...)' % (name,))
         finally:
             f.__setstate__((capture, (), {}, {}))
 
         f = self.partial(capture)
         f.__setstate__((capture, (f,), {}, {}))
         try:
-            self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
+            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
         finally:
             f.__setstate__((capture, (), {}, {}))
 
         f = self.partial(capture)
         f.__setstate__((capture, (), {'a': f}, {}))
         try:
-            self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
+            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
         finally:
             f.__setstate__((capture, (), {}, {}))
 
     def test_pickle(self):
-        f = self.partial(signature, ['asdf'], bar=[True])
-        f.attr = []
-        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
-            f_copy = pickle.loads(pickle.dumps(f, proto))
-            self.assertEqual(signature(f_copy), signature(f))
+        with self.AllowPickle():
+            f = self.partial(signature, ['asdf'], bar=[True])
+            f.attr = []
+            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                f_copy = pickle.loads(pickle.dumps(f, proto))
+                self.assertEqual(signature(f_copy), signature(f))
 
     def test_copy(self):
         f = self.partial(signature, ['asdf'], bar=[True])
@@ -274,11 +261,13 @@ class TestPartialC(TestPartial, unittest.TestCase):
     def test_setstate(self):
         f = self.partial(signature)
         f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
+
         self.assertEqual(signature(f),
                          (capture, (1,), dict(a=10), dict(attr=[])))
         self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
 
         f.__setstate__((capture, (1,), dict(a=10), None))
+
         self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
         self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
 
@@ -325,38 +314,39 @@ class TestPartialC(TestPartial, unittest.TestCase):
         self.assertIs(type(r[0]), tuple)
 
     def test_recursive_pickle(self):
-        f = self.partial(capture)
-        f.__setstate__((f, (), {}, {}))
-        try:
-            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
-                with self.assertRaises(RecursionError):
-                    pickle.dumps(f, proto)
-        finally:
-            f.__setstate__((capture, (), {}, {}))
-
-        f = self.partial(capture)
-        f.__setstate__((capture, (f,), {}, {}))
-        try:
-            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
-                f_copy = pickle.loads(pickle.dumps(f, proto))
-                try:
-                    self.assertIs(f_copy.args[0], f_copy)
-                finally:
-                    f_copy.__setstate__((capture, (), {}, {}))
-        finally:
-            f.__setstate__((capture, (), {}, {}))
-
-        f = self.partial(capture)
-        f.__setstate__((capture, (), {'a': f}, {}))
-        try:
-            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
-                f_copy = pickle.loads(pickle.dumps(f, proto))
-                try:
-                    self.assertIs(f_copy.keywords['a'], f_copy)
-                finally:
-                    f_copy.__setstate__((capture, (), {}, {}))
-        finally:
-            f.__setstate__((capture, (), {}, {}))
+        with self.AllowPickle():
+            f = self.partial(capture)
+            f.__setstate__((f, (), {}, {}))
+            try:
+                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                    with self.assertRaises(RecursionError):
+                        pickle.dumps(f, proto)
+            finally:
+                f.__setstate__((capture, (), {}, {}))
+
+            f = self.partial(capture)
+            f.__setstate__((capture, (f,), {}, {}))
+            try:
+                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                    f_copy = pickle.loads(pickle.dumps(f, proto))
+                    try:
+                        self.assertIs(f_copy.args[0], f_copy)
+                    finally:
+                        f_copy.__setstate__((capture, (), {}, {}))
+            finally:
+                f.__setstate__((capture, (), {}, {}))
+
+            f = self.partial(capture)
+            f.__setstate__((capture, (), {'a': f}, {}))
+            try:
+                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+                    f_copy = pickle.loads(pickle.dumps(f, proto))
+                    try:
+                        self.assertIs(f_copy.keywords['a'], f_copy)
+                    finally:
+                        f_copy.__setstate__((capture, (), {}, {}))
+            finally:
+                f.__setstate__((capture, (), {}, {}))
 
     # Issue 6083: Reference counting bug
     def test_setstate_refcount(self):
@@ -375,24 +365,60 @@ class TestPartialC(TestPartial, unittest.TestCase):
         f = self.partial(object)
         self.assertRaises(TypeError, f.__setstate__, BadSequence())
 
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestPartialC(TestPartial, unittest.TestCase):
+    if c_functools:
+        partial = c_functools.partial
+
+    class AllowPickle:
+        def __enter__(self):
+            return self
+        def __exit__(self, type, value, tb):
+            return False
+
+    def test_attributes_unwritable(self):
+        # attributes should not be writable
+        p = self.partial(capture, 1, 2, a=10, b=20)
+        self.assertRaises(AttributeError, setattr, p, 'func', map)
+        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
+        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
+
+        p = self.partial(hex)
+        try:
+            del p.__dict__
+        except TypeError:
+            pass
+        else:
+            self.fail('partial object allowed __dict__ to be deleted')
 
 class TestPartialPy(TestPartial, unittest.TestCase):
-    partial = staticmethod(py_functools.partial)
+    partial = py_functools.partial
 
+    class AllowPickle:
+        def __init__(self):
+            self._cm = replaced_module("functools", py_functools)
+        def __enter__(self):
+            return self._cm.__enter__()
+        def __exit__(self, type, value, tb):
+            return self._cm.__exit__(type, value, tb)
 
 if c_functools:
-    class PartialSubclass(c_functools.partial):
+    class CPartialSubclass(c_functools.partial):
         pass
 
+class PyPartialSubclass(py_functools.partial):
+    pass
 
 @unittest.skipUnless(c_functools, 'requires the C _functools module')
 class TestPartialCSubclass(TestPartialC):
     if c_functools:
-        partial = PartialSubclass
+        partial = CPartialSubclass
 
     # partial subclasses are not optimized for nested calls
     test_nested_optimization = None
 
+class TestPartialPySubclass(TestPartialPy):
+    partial = PyPartialSubclass
 
 class TestPartialMethod(unittest.TestCase):
 
@@ -683,9 +709,10 @@ class TestWraps(TestUpdateWrapper):
         self.assertEqual(wrapper.attr, 'This is a different test')
         self.assertEqual(wrapper.dict_attr, f.dict_attr)
 
-
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
 class TestReduce(unittest.TestCase):
-    func = functools.reduce
+    if c_functools:
+        func = c_functools.reduce
 
     def test_reduce(self):
         class Squares:
index fd69eb777b7df6fa662cccb9297afad2fb7b99bf..a7a91046b143bcbd83d343d828a128cab00eb77a 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -135,6 +135,11 @@ Core and Builtins
 Library
 -------
 
+- Issue #27137: the pure Python fallback implementation of ``functools.partial``
+  now matches the behaviour of its accelerated C counterpart for subclassing,
+  pickling and text representation purposes. Patch by Emanuel Barry and
+  Serhiy Storchaka.
+
 - Issue #28019: itertools.count() no longer rounds non-integer step in range
   between 1.0 and 2.0 to 1.
 
index 848a03cb07ae9573248cc16e77ffb8f93557d5f3..fa5fad3e754e79d5a2ad981b00c4fcbfd7a15e14 100644 (file)
@@ -229,7 +229,7 @@ partial_repr(partialobject *pto)
     if (status != 0) {
         if (status < 0)
             return NULL;
-        return PyUnicode_FromFormat("%s(...)", Py_TYPE(pto)->tp_name);
+        return PyUnicode_FromString("...");
     }
 
     arglist = PyUnicode_FromString("");