]> granicus.if.org Git - python/commitdiff
Issue #7830: Flatten nested functools.partial.
authorAlexander Belopolsky <alexander.belopolsky@gmail.com>
Sun, 1 Mar 2015 20:08:17 +0000 (15:08 -0500)
committerAlexander Belopolsky <alexander.belopolsky@gmail.com>
Sun, 1 Mar 2015 20:08:17 +0000 (15:08 -0500)
Lib/functools.py
Lib/test/test_functools.py
Misc/NEWS
Modules/_functoolsmodule.c

index 20a26f9a2c7583b3b1d5b420e0b65dcd5ee8f33d..91e9685b98471846fd1292a46dd255fa2265cd24 100644 (file)
@@ -241,6 +241,14 @@ def partial(func, *args, **keywords):
     """New function with partial application of the given arguments
     and keywords.
     """
+    if hasattr(func, 'func'):
+        args = func.args + args
+        tmpkw = func.keywords.copy()
+        tmpkw.update(keywords)
+        keywords = tmpkw
+        del tmpkw
+        func = func.func
+
     def newfunc(*fargs, **fkeywords):
         newkeywords = keywords.copy()
         newkeywords.update(fkeywords)
index 55b96b4b8f4d65adea32dfb97d26d34f19c7196f..c549ac4cc4255b9af91608108544c3d282e17b38 100644 (file)
@@ -131,6 +131,16 @@ class TestPartial:
         join = self.partial(''.join)
         self.assertEqual(join(data), '0123456789')
 
+    def test_nested_optimization(self):
+        partial = self.partial
+        # Only "true" partial is optimized
+        if partial.__name__ != 'partial':
+            return
+        inner = partial(signature, 'asdf')
+        nested = partial(inner, bar=True)
+        flat = partial(signature, 'asdf', bar=True)
+        self.assertEqual(signature(nested), signature(flat))
+
 
 @unittest.skipUnless(c_functools, 'requires the C _functools module')
 class TestPartialC(TestPartial, unittest.TestCase):
index 1c6a90ce33448bf1fe5f4ace802008b2e52ffb4e..211303985eb4167da9b7b9ccac3b36a8539820fd 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -13,6 +13,8 @@ Core and Builtins
 Library
 -------
 
+- Issue #7830: Flatten nested functools.partial.
+
 - Issue #20204: Added the __module__ attribute to _tkinter classes.
 
 - Issue #19980: Improved help() for non-recognized strings.  help('') now
index 57dfba041013b82ce0b4fa8ec2a013e2663ad41b..3413b12dfe9078368616f31d94c207976468f990 100644 (file)
@@ -25,7 +25,7 @@ static PyTypeObject partial_type;
 static PyObject *
 partial_new(PyTypeObject *type, PyObject *args, PyObject *kw)
 {
-    PyObject *func;
+    PyObject *func, *pargs, *nargs, *pkw;
     partialobject *pto;
 
     if (PyTuple_GET_SIZE(args) < 1) {
@@ -34,7 +34,16 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw)
         return NULL;
     }
 
+    pargs = pkw = Py_None;
     func = PyTuple_GET_ITEM(args, 0);
+    if (Py_TYPE(func) == &partial_type && type == &partial_type) {
+        partialobject *part = (partialobject *)func;
+        if (part->dict == NULL) {
+            pargs = part->args;
+            pkw = part->kw;
+            func = part->fn;
+        }
+    }
     if (!PyCallable_Check(func)) {
         PyErr_SetString(PyExc_TypeError,
                         "the first argument must be callable");
@@ -48,21 +57,53 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw)
 
     pto->fn = func;
     Py_INCREF(func);
-    pto->args = PyTuple_GetSlice(args, 1, PY_SSIZE_T_MAX);
-    if (pto->args == NULL) {
+
+    nargs = PyTuple_GetSlice(args, 1, PY_SSIZE_T_MAX);
+    if (nargs == NULL) {
+        pto->args = NULL;
         pto->kw = NULL;
         Py_DECREF(pto);
         return NULL;
     }
+    if (pargs == Py_None || PyTuple_GET_SIZE(pargs) == 0) {
+        pto->args = nargs;
+        Py_INCREF(nargs);
+    }
+    else if (PyTuple_GET_SIZE(nargs) == 0) {
+        pto->args = pargs;
+        Py_INCREF(pargs);
+    }
+    else {
+        pto->args = PySequence_Concat(pargs, nargs);
+        if (pto->args == NULL) {
+            pto->kw = NULL;
+            Py_DECREF(pto);
+            return NULL;
+        }
+    }
+    Py_DECREF(nargs);
+
     if (kw != NULL) {
-        pto->kw = PyDict_Copy(kw);
+        if (pkw == Py_None) {
+            pto->kw = PyDict_Copy(kw);
+        }
+        else {
+            pto->kw = PyDict_Copy(pkw);
+            if (pto->kw != NULL) {
+                if (PyDict_Merge(pto->kw, kw, 1) != 0) {
+                    Py_DECREF(pto);
+                    return NULL;
+                }
+            }
+        }
         if (pto->kw == NULL) {
             Py_DECREF(pto);
             return NULL;
         }
-    } else {
-        pto->kw = Py_None;
-        Py_INCREF(Py_None);
+    }
+    else {
+        pto->kw = pkw;
+        Py_INCREF(pkw);
     }
 
     pto->weakreflist = NULL;