]> granicus.if.org Git - python/commitdiff
Issue #5322: Fixed setting __new__ to a PyCFunction inside Python code.
authorSerhiy Storchaka <storchaka@gmail.com>
Fri, 2 Dec 2016 06:42:43 +0000 (08:42 +0200)
committerSerhiy Storchaka <storchaka@gmail.com>
Fri, 2 Dec 2016 06:42:43 +0000 (08:42 +0200)
Original patch by Andreas Stührk.

Lib/test/test_descr.py
Misc/NEWS
Objects/typeobject.c

index fc965f7e5ea137e982188302ec7aa3477c7bd23e..e02f6914a0c64ae49c1a6a251de46ea32fa9345d 100644 (file)
@@ -3,6 +3,7 @@ import gc
 import sys
 import types
 import unittest
+import warnings
 import weakref
 
 from copy import deepcopy
@@ -1550,6 +1551,84 @@ order (MRO) for bases """
         self.assertEqual(b.foo, 3)
         self.assertEqual(b.__class__, D)
 
+    def test_bad_new(self):
+        self.assertRaises(TypeError, object.__new__)
+        self.assertRaises(TypeError, object.__new__, '')
+        self.assertRaises(TypeError, list.__new__, object)
+        self.assertRaises(TypeError, object.__new__, list)
+        class C(object):
+            __new__ = list.__new__
+        self.assertRaises(TypeError, C)
+        class C(list):
+            __new__ = object.__new__
+        self.assertRaises(TypeError, C)
+
+    def test_object_new(self):
+        class A(object):
+            pass
+        object.__new__(A)
+        self.assertRaises(TypeError, object.__new__, A, 5)
+        object.__init__(A())
+        self.assertRaises(TypeError, object.__init__, A(), 5)
+
+        class A(object):
+            def __init__(self, foo):
+                self.foo = foo
+        object.__new__(A)
+        object.__new__(A, 5)
+        object.__init__(A(3))
+        self.assertRaises(TypeError, object.__init__, A(3), 5)
+
+        class A(object):
+            def __new__(cls, foo):
+                return object.__new__(cls)
+        object.__new__(A)
+        self.assertRaises(TypeError, object.__new__, A, 5)
+        object.__init__(A(3))
+        object.__init__(A(3), 5)
+
+        class A(object):
+            def __new__(cls, foo):
+                return object.__new__(cls)
+            def __init__(self, foo):
+                self.foo = foo
+        object.__new__(A)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter('always', DeprecationWarning)
+            a = object.__new__(A, 5)
+        self.assertEqual(type(a), A)
+        self.assertEqual(len(w), 1)
+        object.__init__(A(3))
+        a = A(3)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter('always', DeprecationWarning)
+            object.__init__(a, 5)
+        self.assertEqual(a.foo, 3)
+        self.assertEqual(len(w), 1)
+
+    def test_restored_object_new(self):
+        class A(object):
+            def __new__(cls, *args, **kwargs):
+                raise AssertionError
+        self.assertRaises(AssertionError, A)
+        class B(A):
+            __new__ = object.__new__
+            def __init__(self, foo):
+                self.foo = foo
+        with warnings.catch_warnings():
+            warnings.simplefilter('error', DeprecationWarning)
+            b = B(3)
+        self.assertEqual(b.foo, 3)
+        self.assertEqual(b.__class__, B)
+        del B.__new__
+        self.assertRaises(AssertionError, B)
+        del A.__new__
+        with warnings.catch_warnings():
+            warnings.simplefilter('error', DeprecationWarning)
+            b = B(3)
+        self.assertEqual(b.foo, 3)
+        self.assertEqual(b.__class__, B)
+
     def test_altmro(self):
         # Testing mro() and overriding it...
         class A(object):
@@ -3756,6 +3835,24 @@ order (MRO) for bases """
         self.assertEqual(isinstance(d, D), True)
         self.assertEqual(d.foo, 1)
 
+        class C(object):
+            @staticmethod
+            def __new__(*args):
+                return args
+        self.assertEqual(C(1, 2), (C, 1, 2))
+        class D(C):
+            pass
+        self.assertEqual(D(1, 2), (D, 1, 2))
+
+        class C(object):
+            @classmethod
+            def __new__(*args):
+                return args
+        self.assertEqual(C(1, 2), (C, C, 1, 2))
+        class D(C):
+            pass
+        self.assertEqual(D(1, 2), (D, D, 1, 2))
+
     def test_imul_bug(self):
         # Testing for __imul__ problems...
         # SF bug 544647
index cc2ab92604eb274a93a96443ec39846b0a3fcfc4..c0f27ca6eea1192ef08b927accb6ae2dde7134b6 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -10,6 +10,9 @@ What's New in Python 2.7.13?
 Core and Builtins
 -----------------
 
+- Issue #5322: Fixed setting __new__ to a PyCFunction inside Python code.
+  Original patch by Andreas Stührk.
+
 - Issue #28847: dubmdbm no longer writes the index file in when it is not
   changed and supports reading read-only files.
 
index 932f9e994a50c0d21ea8d41823be24b57f090314..69a996a8fd8c303afcce35ca8d723c1581e38250 100644 (file)
@@ -6304,7 +6304,33 @@ update_one_slot(PyTypeObject *type, slotdef *p)
                sanity checks and constructing a new argument
                list.  Cut all that nonsense short -- this speeds
                up instance creation tremendously. */
-            specific = (void *)type->tp_new;
+            PyObject *self = PyCFunction_GET_SELF(descr);
+            if (!self || !PyType_Check(self)) {
+                /* This should never happen because
+                   tp_new_wrapper expects a type for self.
+                   Use slot_tp_new which will call
+                   tp_new_wrapper which will raise an
+                   exception. */
+                specific = (void *)slot_tp_new;
+            }
+            else {
+                specific = ((PyTypeObject *)self)->tp_new;
+                /* Check that the user does not do anything
+                   silly and unsafe like object.__new__(dict).
+                   To do this, we check that the most derived
+                   base that's not a heap type is this type. */
+                PyTypeObject *staticbase = type->tp_base;
+                while (staticbase &&
+                       (staticbase->tp_flags & Py_TPFLAGS_HEAPTYPE))
+                    staticbase = staticbase->tp_base;
+                if (staticbase &&
+                    staticbase->tp_new != specific)
+                    /* Seems to be unsafe, better use
+                       slot_tp_new which will call
+                       tp_new_wrapper which will raise an
+                       exception if it is unsafe. */
+                    specific = (void *)slot_tp_new;
+            }
             /* XXX I'm not 100% sure that there isn't a hole
                in this reasoning that requires additional
                sanity checks.  I'll buy the first person to