]> granicus.if.org Git - python/commitdiff
Make parameterized tests in email less hackish.
authorR David Murray <rdmurray@bitdance.com>
Thu, 31 May 2012 01:53:40 +0000 (21:53 -0400)
committerR David Murray <rdmurray@bitdance.com>
Thu, 31 May 2012 01:53:40 +0000 (21:53 -0400)
Or perhaps more hackish, depending on your perspective.  But at least this
way it is now possible to run the individual tests using the unittest CLI.

Lib/test/test_email/__init__.py
Lib/test/test_email/test_generator.py
Lib/test/test_email/test_headerregistry.py
Lib/test/test_email/test_pickleable.py

index 75dc64d244b7b2c5ea196d8c2de5445d7ad26edf..bd9d52c3c5ad0f9cf72097b77d74b94816356aba 100644 (file)
@@ -71,3 +71,82 @@ class TestEmailBase(unittest.TestCase):
         for i in range(len(actual)):
             self.assertIsInstance(actual[i], expected[i],
                                     'item {}'.format(i))
+
+
+# Metaclass to allow for parameterized tests
+class Parameterized(type):
+
+    """Provide a test method parameterization facility.
+
+    Parameters are specified as the value of a class attribute that ends with
+    the string '_params'.  Call the portion before '_params' the prefix.  Then
+    a method to be parameterized must have the same prefix, the string
+    '_as_', and an arbitrary suffix.
+
+    The value of the _params attribute may be either a dictionary or a list.
+    The values in the dictionary and the elements of the list may either be
+    single values, or a list.  If single values, they are turned into single
+    element tuples.  However derived, the resulting sequence is passed via
+    *args to the parameterized test function.
+
+    In a _params dictioanry, the keys become part of the name of the generated
+    tests.  In a _params list, the values in the list are converted into a
+    string by joining the string values of the elements of the tuple by '_' and
+    converting any blanks into '_'s, and this become part of the name.  The
+    full name of a generated test is the portion of the _params name before the
+    '_params' portion, plus an '_', plus the name derived as explained above.
+
+    For example, if we have:
+
+        count_params = range(2)
+
+        def count_as_foo_arg(self, foo):
+            self.assertEqual(foo+1, myfunc(foo))
+
+    we will get parameterized test methods named:
+        test_foo_arg_0
+        test_foo_arg_1
+        test_foo_arg_2
+
+    Or we could have:
+
+        example_params = {'foo': ('bar', 1), 'bing': ('bang', 2)}
+
+        def example_as_myfunc_input(self, name, count):
+            self.assertEqual(name+str(count), myfunc(name, count))
+
+    and get:
+        test_myfunc_input_foo
+        test_myfunc_input_bing
+
+    Note: if and only if the generated test name is a valid identifier can it
+    be used to select the test individually from the unittest command line.
+
+    """
+
+    def __new__(meta, classname, bases, classdict):
+        paramdicts = {}
+        for name, attr in classdict.items():
+            if name.endswith('_params'):
+                if not hasattr(attr, 'keys'):
+                    d = {}
+                    for x in attr:
+                        if not hasattr(x, '__iter__'):
+                            x = (x,)
+                        n = '_'.join(str(v) for v in x).replace(' ', '_')
+                        d[n] = x
+                    attr = d
+                paramdicts[name[:-7] + '_as_'] = attr
+        testfuncs = {}
+        for name, attr in classdict.items():
+            for paramsname, paramsdict in paramdicts.items():
+                if name.startswith(paramsname):
+                    testnameroot = 'test_' + name[len(paramsname):]
+                    for paramname, params in paramsdict.items():
+                        test = (lambda self, name=name, params=params:
+                                        getattr(self, name)(*params))
+                        testname = testnameroot + '_' + paramname
+                        test.__name__ = testname
+                        testfuncs[testname] = test
+        classdict.update(testfuncs)
+        return super().__new__(meta, classname, bases, classdict)
index c1af024f571c1a1b1e38ef201af1e83472f5e84c..7e1529be381dad26deafe3c2398892afeb68bb63 100644 (file)
@@ -4,10 +4,10 @@ import unittest
 from email import message_from_string, message_from_bytes
 from email.generator import Generator, BytesGenerator
 from email import policy
-from test.test_email import TestEmailBase
+from test.test_email import TestEmailBase, Parameterized
 
 
-class TestGeneratorBase:
+class TestGeneratorBase(metaclass=Parameterized):
 
     policy = policy.default
 
@@ -80,31 +80,23 @@ class TestGeneratorBase:
               "\n"
               "None\n")
 
-    def _test_maxheaderlen_parameter(self, n):
+    length_params = [n for n in refold_long_expected]
+
+    def length_as_maxheaderlen_parameter(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, maxheaderlen=n, policy=self.policy)
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_maxheaderlen_parameter_' + str(n)] = (
-            lambda self, n=n:
-                self._test_maxheaderlen_parameter(n))
-
-    def _test_max_line_length_policy(self, n):
+    def length_as_max_line_length_policy(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, policy=self.policy.clone(max_line_length=n))
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_max_line_length_policy' + str(n)] = (
-            lambda self, n=n:
-                self._test_max_line_length_policy(n))
-
-    def _test_maxheaderlen_parm_overrides_policy(self, n):
+    def length_as_maxheaderlen_parm_overrides_policy(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, maxheaderlen=n,
@@ -112,12 +104,7 @@ class TestGeneratorBase:
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_maxheaderlen_parm_overrides_policy' + str(n)] = (
-            lambda self, n=n:
-                self._test_maxheaderlen_parm_overrides_policy(n))
-
-    def _test_refold_none_does_not_fold(self, n):
+    def length_as_max_line_length_with_refold_none_does_not_fold(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, policy=self.policy.clone(refold_source='none',
@@ -125,12 +112,7 @@ class TestGeneratorBase:
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[0]))
 
-    for n in refold_long_expected:
-        locals()['test_refold_none_does_not_fold' + str(n)] = (
-            lambda self, n=n:
-                self._test_refold_none_does_not_fold(n))
-
-    def _test_refold_all(self, n):
+    def length_as_max_line_length_with_refold_all_folds(self, n):
         msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
         s = self.ioclass()
         g = self.genclass(s, policy=self.policy.clone(refold_source='all',
@@ -138,11 +120,6 @@ class TestGeneratorBase:
         g.flatten(msg)
         self.assertEqual(s.getvalue(), self.typ(self.refold_all_expected[n]))
 
-    for n in refold_long_expected:
-        locals()['test_refold_all' + str(n)] = (
-            lambda self, n=n:
-                self._test_refold_all(n))
-
     def test_crlf_control_via_policy(self):
         source = "Subject: test\r\n\r\ntest body\r\n"
         expected = source
index e6fe29b1de3befe5dec74beb3837a67cd4896373..4a57ff14a4f8602cafb12833bfc549a68fe700e8 100644 (file)
@@ -4,7 +4,7 @@ import unittest
 from email import errors
 from email import policy
 from email.message import Message
-from test.test_email import TestEmailBase
+from test.test_email import TestEmailBase, Parameterized
 from email import headerregistry
 from email.headerregistry import Address, Group
 
@@ -175,9 +175,9 @@ class TestDateHeader(TestHeaderBase):
         self.assertEqual(m['Date'].datetime, self.dt)
 
 
-class TestAddressHeader(TestHeaderBase):
+class TestAddressHeader(TestHeaderBase, metaclass=Parameterized):
 
-    examples = {
+    example_params = {
 
         'empty':
             ('<>',
@@ -305,8 +305,8 @@ class TestAddressHeader(TestHeaderBase):
         # trailing comments, which aren't currently handled.  comments in
         # general are not handled yet.
 
-    def _test_single_addr(self, source, defects, decoded, display_name,
-                          addr_spec, username, domain, comment):
+    def example_as_address(self, source, defects, decoded, display_name,
+                           addr_spec, username, domain, comment):
         h = self.make_header('sender', source)
         self.assertEqual(h, decoded)
         self.assertDefectsEqual(h.defects, defects)
@@ -322,13 +322,8 @@ class TestAddressHeader(TestHeaderBase):
         # XXX: we have no comment support yet.
         #self.assertEqual(a.comment, comment)
 
-    for name in examples:
-        locals()['test_'+name] = (
-            lambda self, name=name:
-                self._test_single_addr(*self.examples[name]))
-
-    def _test_group_single_addr(self, source, defects, decoded, display_name,
-                                addr_spec, username, domain, comment):
+    def example_as_group(self, source, defects, decoded, display_name,
+                         addr_spec, username, domain, comment):
         source = 'foo: {};'.format(source)
         gdecoded = 'foo: {};'.format(decoded) if decoded else 'foo:;'
         h = self.make_header('to', source)
@@ -344,11 +339,6 @@ class TestAddressHeader(TestHeaderBase):
         self.assertEqual(a.username, username)
         self.assertEqual(a.domain, domain)
 
-    for name in examples:
-        locals()['test_group_'+name] = (
-            lambda self, name=name:
-                self._test_group_single_addr(*self.examples[name]))
-
     def test_simple_address_list(self):
         value = ('Fred <dinsdale@python.org>, foo@example.com, '
                     '"Harry W. Hastings" <hasty@example.com>')
@@ -366,7 +356,7 @@ class TestAddressHeader(TestHeaderBase):
             'Harry W. Hastings')
 
     def test_complex_address_list(self):
-        examples = list(self.examples.values())
+        examples = list(self.example_params.values())
         source = ('dummy list:;, another: (empty);,' +
                  ', '.join([x[0] for x in examples[:4]]) + ', ' +
                  r'"A \"list\"": ' +
index 3a5bd11ffc20b135d5511d933a5e0011466746c6..09477e042172d9476604ff8ea5d2847691784644 100644 (file)
@@ -6,83 +6,66 @@ import email
 import email.message
 from email import policy
 from email.headerregistry import HeaderRegistry
-from test.test_email import TestEmailBase
+from test.test_email import TestEmailBase, Parameterized
 
-class TestPickleCopyHeader(TestEmailBase):
+class TestPickleCopyHeader(TestEmailBase, metaclass=Parameterized):
 
     header_factory = HeaderRegistry()
 
     unstructured = header_factory('subject', 'this is a test')
 
-    def _test_deepcopy(self, name, value):
+    header_params = {
+        'subject': ('subject', 'this is a test'),
+        'from':    ('from',    'frodo@mordor.net'),
+        'to':      ('to',      'a: k@b.com, y@z.com;, j@f.com'),
+        'date':    ('date',    'Tue, 29 May 2012 09:24:26 +1000'),
+        }
+
+    def header_as_deepcopy(self, name, value):
         header = self.header_factory(name, value)
         h = copy.deepcopy(header)
         self.assertEqual(str(h), str(header))
 
-    def _test_pickle(self, name, value):
+    def header_as_pickle(self, name, value):
         header = self.header_factory(name, value)
         p = pickle.dumps(header)
         h = pickle.loads(p)
         self.assertEqual(str(h), str(header))
 
-    headers = (
-        ('subject', 'this is a test'),
-        ('from',    'frodo@mordor.net'),
-        ('to',      'a: k@b.com, y@z.com;, j@f.com'),
-        ('date',    'Tue, 29 May 2012 09:24:26 +1000'),
-        )
-
-    for header in headers:
-        locals()['test_deepcopy_'+header[0]] = (
-            lambda self, header=header:
-                self._test_deepcopy(*header))
-
-    for header in headers:
-        locals()['test_pickle_'+header[0]] = (
-            lambda self, header=header:
-                self._test_pickle(*header))
 
+class TestPickleCopyMessage(TestEmailBase, metaclass=Parameterized):
 
-class TestPickleCopyMessage(TestEmailBase):
-
-    msgs = {}
+    # Message objects are a sequence, so we have to make them a one-tuple in
+    # msg_params so they get passed to the parameterized test method as a
+    # single argument instead of as a list of headers.
+    msg_params = {}
 
     # Note: there will be no custom header objects in the parsed message.
-    msgs['parsed'] = email.message_from_string(textwrap.dedent("""\
+    msg_params['parsed'] = (email.message_from_string(textwrap.dedent("""\
         Date: Tue, 29 May 2012 09:24:26 +1000
         From: frodo@mordor.net
         To: bilbo@underhill.org
         Subject: help
 
         I think I forgot the ring.
-        """), policy=policy.default)
+        """), policy=policy.default),)
 
-    msgs['created'] = email.message.Message(policy=policy.default)
-    msgs['created']['Date'] = 'Tue, 29 May 2012 09:24:26 +1000'
-    msgs['created']['From'] = 'frodo@mordor.net'
-    msgs['created']['To'] = 'bilbo@underhill.org'
-    msgs['created']['Subject'] = 'help'
-    msgs['created'].set_payload('I think I forgot the ring.')
+    msg_params['created'] = (email.message.Message(policy=policy.default),)
+    msg_params['created'][0]['Date'] = 'Tue, 29 May 2012 09:24:26 +1000'
+    msg_params['created'][0]['From'] = 'frodo@mordor.net'
+    msg_params['created'][0]['To'] = 'bilbo@underhill.org'
+    msg_params['created'][0]['Subject'] = 'help'
+    msg_params['created'][0].set_payload('I think I forgot the ring.')
 
-    def _test_deepcopy(self, msg):
+    def msg_as_deepcopy(self, msg):
         msg2 = copy.deepcopy(msg)
         self.assertEqual(msg2.as_string(), msg.as_string())
 
-    def _test_pickle(self, msg):
+    def msg_as_pickle(self, msg):
         p = pickle.dumps(msg)
         msg2 = pickle.loads(p)
         self.assertEqual(msg2.as_string(), msg.as_string())
 
-    for name, msg in msgs.items():
-        locals()['test_deepcopy_'+name] = (
-            lambda self, msg=msg:
-                self._test_deepcopy(msg))
-
-    for name, msg in msgs.items():
-        locals()['test_pickle_'+name] = (
-            lambda self, msg=msg:
-                self._test_pickle(msg))
-
 
 if __name__ == '__main__':
     unittest.main()