]> granicus.if.org Git - python/commitdiff
Backport recent typing updates (GH-6759)
authorIvan Levkivskyi <levkivskyi@gmail.com>
Fri, 11 May 2018 03:15:14 +0000 (23:15 -0400)
committerGitHub <noreply@github.com>
Fri, 11 May 2018 03:15:14 +0000 (23:15 -0400)
Lib/test/test_typing.py
Lib/typing.py
Misc/NEWS.d/next/Library/2018-05-10-14-51-19.bpo-28556.y3zK6I.rst [new file with mode: 0644]

index fe5247c260278a201afc310e437d873a89a14d92..4843d6faf70414901ecc7b40f59c9d2a1de7c504 100644 (file)
@@ -1310,6 +1310,74 @@ class GenericTests(BaseTestCase):
         with self.assertRaises(Exception):
             D[T]
 
+    def test_new_with_args(self):
+
+        class A(Generic[T]):
+            pass
+
+        class B:
+            def __new__(cls, arg):
+                # call object
+                obj = super().__new__(cls)
+                obj.arg = arg
+                return obj
+
+        # mro: C, A, Generic, B, object
+        class C(A, B):
+            pass
+
+        c = C('foo')
+        self.assertEqual(c.arg, 'foo')
+
+    def test_new_with_args2(self):
+
+        class A:
+            def __init__(self, arg):
+                self.from_a = arg
+                # call object
+                super().__init__()
+
+        # mro: C, Generic, A, object
+        class C(Generic[T], A):
+            def __init__(self, arg):
+                self.from_c = arg
+                # call Generic
+                super().__init__(arg)
+
+        c = C('foo')
+        self.assertEqual(c.from_a, 'foo')
+        self.assertEqual(c.from_c, 'foo')
+
+    def test_new_no_args(self):
+
+        class A(Generic[T]):
+            pass
+
+        with self.assertRaises(TypeError):
+            A('foo')
+
+        class B:
+            def __new__(cls):
+                # call object
+                obj = super().__new__(cls)
+                obj.from_b = 'b'
+                return obj
+
+        # mro: C, A, Generic, B, object
+        class C(A, B):
+            def __init__(self, arg):
+                self.arg = arg
+
+            def __new__(cls, arg):
+                # call A
+                obj = super().__new__(cls)
+                obj.from_c = 'c'
+                return obj
+
+        c = C('foo')
+        self.assertEqual(c.arg, 'foo')
+        self.assertEqual(c.from_b, 'b')
+        self.assertEqual(c.from_c, 'c')
 
 class ClassVarTests(BaseTestCase):
 
@@ -1739,6 +1807,8 @@ class GetTypeHintTests(BaseTestCase):
         self.assertEqual(gth(HasForeignBaseClass),
                          {'some_xrepr': XRepr, 'other_a': mod_generics_cache.A,
                           'some_b': mod_generics_cache.B})
+        self.assertEqual(gth(XRepr.__new__),
+                         {'x': int, 'y': int})
         self.assertEqual(gth(mod_generics_cache.B),
                          {'my_inner_a1': mod_generics_cache.B.A,
                           'my_inner_a2': mod_generics_cache.B.A,
index b5564cc29a2d830eb833564d37e332a7e7ee9b49..f2b6aaf3a9278ceb08978392a6c169e09a581db2 100644 (file)
@@ -1181,10 +1181,18 @@ def _generic_new(base_cls, cls, *args, **kwds):
     # Assure type is erased on instantiation,
     # but attempt to store it in __orig_class__
     if cls.__origin__ is None:
-        return base_cls.__new__(cls)
+        if (base_cls.__new__ is object.__new__ and
+                cls.__init__ is not object.__init__):
+            return base_cls.__new__(cls)
+        else:
+            return base_cls.__new__(cls, *args, **kwds)
     else:
         origin = cls._gorg
-        obj = base_cls.__new__(origin)
+        if (base_cls.__new__ is object.__new__ and
+                cls.__init__ is not object.__init__):
+            obj = base_cls.__new__(origin)
+        else:
+            obj = base_cls.__new__(origin, *args, **kwds)
         try:
             obj.__orig_class__ = cls
         except AttributeError:
@@ -2146,6 +2154,7 @@ class NamedTupleMeta(type):
                                 "follow default field(s) {default_names}"
                                 .format(field_name=field_name,
                                         default_names=', '.join(defaults_dict.keys())))
+        nm_tpl.__new__.__annotations__ = collections.OrderedDict(types)
         nm_tpl.__new__.__defaults__ = tuple(defaults)
         nm_tpl._field_defaults = defaults_dict
         # update from user namespace without overriding special namedtuple attributes
diff --git a/Misc/NEWS.d/next/Library/2018-05-10-14-51-19.bpo-28556.y3zK6I.rst b/Misc/NEWS.d/next/Library/2018-05-10-14-51-19.bpo-28556.y3zK6I.rst
new file mode 100644 (file)
index 0000000..8ed4658
--- /dev/null
@@ -0,0 +1,3 @@
+Minor fixes in typing module: add annotations to ``NamedTuple.__new__``,
+pass ``*args`` and ``**kwds`` in ``Generic.__new__``.  Original PRs by
+Paulius Ĺ arka and Chad Dombrova.