]> granicus.if.org Git - python/commitdiff
bpo-33536: Validate make_dataclass() field names. (GH-6906)
authorEric V. Smith <ericvsmith@users.noreply.github.com>
Wed, 16 May 2018 15:31:29 +0000 (11:31 -0400)
committerGitHub <noreply@github.com>
Wed, 16 May 2018 15:31:29 +0000 (11:31 -0400)
Lib/dataclasses.py
Lib/test/test_dataclasses.py
Misc/NEWS.d/next/Library/2018-05-16-10-07-40.bpo-33536._s0TE8.rst [new file with mode: 0644]

index bb77d3b4052b2250cead4f2cb0abe1908454c173..2c5593bfc50dc060becb0aa49e2721de9758afb7 100644 (file)
@@ -3,6 +3,7 @@ import sys
 import copy
 import types
 import inspect
+import keyword
 
 __all__ = ['dataclass',
            'field',
@@ -1100,6 +1101,9 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
         # Copy namespace since we're going to mutate it.
         namespace = namespace.copy()
 
+    # While we're looking through the field names, validate that they
+    # are identifiers, are not keywords, and not duplicates.
+    seen = set()
     anns = {}
     for item in fields:
         if isinstance(item, str):
@@ -1110,6 +1114,17 @@ def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
         elif len(item) == 3:
             name, tp, spec = item
             namespace[name] = spec
+        else:
+            raise TypeError(f'Invalid field: {item!r}')
+
+        if not isinstance(name, str) or not name.isidentifier():
+            raise TypeError(f'Field names must be valid identifers: {name!r}')
+        if keyword.iskeyword(name):
+            raise TypeError(f'Field names must not be keywords: {name!r}')
+        if name in seen:
+            raise TypeError(f'Field name duplicated: {name!r}')
+
+        seen.add(name)
         anns[name] = tp
 
     namespace['__annotations__'] = anns
index b251c04cb9a1c7e0173e2c9ac5dbff564d76e15e..7c39b79142b29498195d6c2fed4b88f76f2de853 100755 (executable)
@@ -1826,114 +1826,6 @@ class TestCase(unittest.TestCase):
                     self.assertEqual(new_sample.x, another_new_sample.x)
                     self.assertEqual(sample.y, another_new_sample.y)
 
-    def test_helper_make_dataclass(self):
-        C = make_dataclass('C',
-                           [('x', int),
-                            ('y', int, field(default=5))],
-                           namespace={'add_one': lambda self: self.x + 1})
-        c = C(10)
-        self.assertEqual((c.x, c.y), (10, 5))
-        self.assertEqual(c.add_one(), 11)
-
-
-    def test_helper_make_dataclass_no_mutate_namespace(self):
-        # Make sure a provided namespace isn't mutated.
-        ns = {}
-        C = make_dataclass('C',
-                           [('x', int),
-                            ('y', int, field(default=5))],
-                           namespace=ns)
-        self.assertEqual(ns, {})
-
-    def test_helper_make_dataclass_base(self):
-        class Base1:
-            pass
-        class Base2:
-            pass
-        C = make_dataclass('C',
-                           [('x', int)],
-                           bases=(Base1, Base2))
-        c = C(2)
-        self.assertIsInstance(c, C)
-        self.assertIsInstance(c, Base1)
-        self.assertIsInstance(c, Base2)
-
-    def test_helper_make_dataclass_base_dataclass(self):
-        @dataclass
-        class Base1:
-            x: int
-        class Base2:
-            pass
-        C = make_dataclass('C',
-                           [('y', int)],
-                           bases=(Base1, Base2))
-        with self.assertRaisesRegex(TypeError, 'required positional'):
-            c = C(2)
-        c = C(1, 2)
-        self.assertIsInstance(c, C)
-        self.assertIsInstance(c, Base1)
-        self.assertIsInstance(c, Base2)
-
-        self.assertEqual((c.x, c.y), (1, 2))
-
-    def test_helper_make_dataclass_init_var(self):
-        def post_init(self, y):
-            self.x *= y
-
-        C = make_dataclass('C',
-                           [('x', int),
-                            ('y', InitVar[int]),
-                            ],
-                           namespace={'__post_init__': post_init},
-                           )
-        c = C(2, 3)
-        self.assertEqual(vars(c), {'x': 6})
-        self.assertEqual(len(fields(c)), 1)
-
-    def test_helper_make_dataclass_class_var(self):
-        C = make_dataclass('C',
-                           [('x', int),
-                            ('y', ClassVar[int], 10),
-                            ('z', ClassVar[int], field(default=20)),
-                            ])
-        c = C(1)
-        self.assertEqual(vars(c), {'x': 1})
-        self.assertEqual(len(fields(c)), 1)
-        self.assertEqual(C.y, 10)
-        self.assertEqual(C.z, 20)
-
-    def test_helper_make_dataclass_other_params(self):
-        C = make_dataclass('C',
-                           [('x', int),
-                            ('y', ClassVar[int], 10),
-                            ('z', ClassVar[int], field(default=20)),
-                            ],
-                           init=False)
-        # Make sure we have a repr, but no init.
-        self.assertNotIn('__init__', vars(C))
-        self.assertIn('__repr__', vars(C))
-
-        # Make sure random other params don't work.
-        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
-            C = make_dataclass('C',
-                               [],
-                               xxinit=False)
-
-    def test_helper_make_dataclass_no_types(self):
-        C = make_dataclass('Point', ['x', 'y', 'z'])
-        c = C(1, 2, 3)
-        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
-        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
-                                             'y': 'typing.Any',
-                                             'z': 'typing.Any'})
-
-        C = make_dataclass('Point', ['x', ('y', int), 'z'])
-        c = C(1, 2, 3)
-        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
-        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
-                                             'y': int,
-                                             'z': 'typing.Any'})
-
 
 class TestFieldNoAnnotation(unittest.TestCase):
     def test_field_without_annotation(self):
@@ -2947,5 +2839,170 @@ class TestStringAnnotations(unittest.TestCase):
                     self.assertNotIn('not_iv4', c.__dict__)
 
 
+class TestMakeDataclass(unittest.TestCase):
+    def test_simple(self):
+        C = make_dataclass('C',
+                           [('x', int),
+                            ('y', int, field(default=5))],
+                           namespace={'add_one': lambda self: self.x + 1})
+        c = C(10)
+        self.assertEqual((c.x, c.y), (10, 5))
+        self.assertEqual(c.add_one(), 11)
+
+
+    def test_no_mutate_namespace(self):
+        # Make sure a provided namespace isn't mutated.
+        ns = {}
+        C = make_dataclass('C',
+                           [('x', int),
+                            ('y', int, field(default=5))],
+                           namespace=ns)
+        self.assertEqual(ns, {})
+
+    def test_base(self):
+        class Base1:
+            pass
+        class Base2:
+            pass
+        C = make_dataclass('C',
+                           [('x', int)],
+                           bases=(Base1, Base2))
+        c = C(2)
+        self.assertIsInstance(c, C)
+        self.assertIsInstance(c, Base1)
+        self.assertIsInstance(c, Base2)
+
+    def test_base_dataclass(self):
+        @dataclass
+        class Base1:
+            x: int
+        class Base2:
+            pass
+        C = make_dataclass('C',
+                           [('y', int)],
+                           bases=(Base1, Base2))
+        with self.assertRaisesRegex(TypeError, 'required positional'):
+            c = C(2)
+        c = C(1, 2)
+        self.assertIsInstance(c, C)
+        self.assertIsInstance(c, Base1)
+        self.assertIsInstance(c, Base2)
+
+        self.assertEqual((c.x, c.y), (1, 2))
+
+    def test_init_var(self):
+        def post_init(self, y):
+            self.x *= y
+
+        C = make_dataclass('C',
+                           [('x', int),
+                            ('y', InitVar[int]),
+                            ],
+                           namespace={'__post_init__': post_init},
+                           )
+        c = C(2, 3)
+        self.assertEqual(vars(c), {'x': 6})
+        self.assertEqual(len(fields(c)), 1)
+
+    def test_class_var(self):
+        C = make_dataclass('C',
+                           [('x', int),
+                            ('y', ClassVar[int], 10),
+                            ('z', ClassVar[int], field(default=20)),
+                            ])
+        c = C(1)
+        self.assertEqual(vars(c), {'x': 1})
+        self.assertEqual(len(fields(c)), 1)
+        self.assertEqual(C.y, 10)
+        self.assertEqual(C.z, 20)
+
+    def test_other_params(self):
+        C = make_dataclass('C',
+                           [('x', int),
+                            ('y', ClassVar[int], 10),
+                            ('z', ClassVar[int], field(default=20)),
+                            ],
+                           init=False)
+        # Make sure we have a repr, but no init.
+        self.assertNotIn('__init__', vars(C))
+        self.assertIn('__repr__', vars(C))
+
+        # Make sure random other params don't work.
+        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
+            C = make_dataclass('C',
+                               [],
+                               xxinit=False)
+
+    def test_no_types(self):
+        C = make_dataclass('Point', ['x', 'y', 'z'])
+        c = C(1, 2, 3)
+        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
+        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+                                             'y': 'typing.Any',
+                                             'z': 'typing.Any'})
+
+        C = make_dataclass('Point', ['x', ('y', int), 'z'])
+        c = C(1, 2, 3)
+        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
+        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+                                             'y': int,
+                                             'z': 'typing.Any'})
+
+    def test_invalid_type_specification(self):
+        for bad_field in [(),
+                          (1, 2, 3, 4),
+                          ]:
+            with self.subTest(bad_field=bad_field):
+                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
+                    make_dataclass('C', ['a', bad_field])
+
+        # And test for things with no len().
+        for bad_field in [float,
+                          lambda x:x,
+                          ]:
+            with self.subTest(bad_field=bad_field):
+                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
+                    make_dataclass('C', ['a', bad_field])
+
+    def test_duplicate_field_names(self):
+        for field in ['a', 'ab']:
+            with self.subTest(field=field):
+                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
+                    make_dataclass('C', [field, 'a', field])
+
+    def test_keyword_field_names(self):
+        for field in ['for', 'async', 'await', 'as']:
+            with self.subTest(field=field):
+                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+                    make_dataclass('C', ['a', field])
+                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+                    make_dataclass('C', [field])
+                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+                    make_dataclass('C', [field, 'a'])
+
+    def test_non_identifier_field_names(self):
+        for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
+            with self.subTest(field=field):
+                with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
+                    make_dataclass('C', ['a', field])
+                with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
+                    make_dataclass('C', [field])
+                with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
+                    make_dataclass('C', [field, 'a'])
+
+    def test_underscore_field_names(self):
+        # Unlike namedtuple, it's okay if dataclass field names have
+        # an underscore.
+        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
+
+    def test_funny_class_names_names(self):
+        # No reason to prevent weird class names, since
+        # types.new_class allows them.
+        for classname in ['()', 'x,y', '*', '2@3', '']:
+            with self.subTest(classname=classname):
+                C = make_dataclass(classname, ['a', 'b'])
+                self.assertEqual(C.__name__, classname)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2018-05-16-10-07-40.bpo-33536._s0TE8.rst b/Misc/NEWS.d/next/Library/2018-05-16-10-07-40.bpo-33536._s0TE8.rst
new file mode 100644 (file)
index 0000000..2c10241
--- /dev/null
@@ -0,0 +1,2 @@
+dataclasses.make_dataclass now checks for invalid field names and duplicate
+fields. Also, added a check for invalid field specifications.