]> granicus.if.org Git - python/commitdiff
Issue #15836: assertRaises(), assertRaisesRegex(), assertWarns() and
authorSerhiy Storchaka <storchaka@gmail.com>
Thu, 21 May 2015 17:15:40 +0000 (20:15 +0300)
committerSerhiy Storchaka <storchaka@gmail.com>
Thu, 21 May 2015 17:15:40 +0000 (20:15 +0300)
assertWarnsRegex() assertments now check the type of the first argument
to prevent possible user error.  Based on patch by Daniel Wagner-Hall.

Lib/test/test_importlib/builtin/test_loader.py
Lib/unittest/case.py
Lib/unittest/test/test_case.py
Misc/ACKS
Misc/NEWS

index 1684ab6eb411a07f836de96fff0d69692e085ab6..b1349ec5da46781e6a047f4164a280c5beabe246 100644 (file)
@@ -97,7 +97,6 @@ class InspectLoaderTests:
             method = getattr(self.machinery.BuiltinImporter, meth_name)
         with self.assertRaises(ImportError) as cm:
             method(util.BUILTINS.bad_name)
-        self.assertRaises(util.BUILTINS.bad_name)
 
 
 (Frozen_InspectLoaderTests,
index befad61946558a5e26b3c5f37d452b0af6f117f1..7701ad3adcc22612f0c74e86b5550d83f584262f 100644 (file)
@@ -119,6 +119,10 @@ def expectedFailure(test_item):
     test_item.__unittest_expecting_failure__ = True
     return test_item
 
+def _is_subtype(expected, basetype):
+    if isinstance(expected, tuple):
+        return all(_is_subtype(e, basetype) for e in expected)
+    return isinstance(expected, type) and issubclass(expected, basetype)
 
 class _BaseTestCaseContext:
 
@@ -148,6 +152,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext):
         If args is not empty, call a callable passing positional and keyword
         arguments.
         """
+        if not _is_subtype(self.expected, self._base_type):
+            raise TypeError('%s() arg 1 must be %s' %
+                            (name, self._base_type_str))
         if args and args[0] is None:
             warnings.warn("callable is None",
                           DeprecationWarning, 3)
@@ -172,6 +179,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext):
 class _AssertRaisesContext(_AssertRaisesBaseContext):
     """A context manager used to implement TestCase.assertRaises* methods."""
 
+    _base_type = BaseException
+    _base_type_str = 'an exception type or tuple of exception types'
+
     def __enter__(self):
         return self
 
@@ -206,6 +216,9 @@ class _AssertRaisesContext(_AssertRaisesBaseContext):
 class _AssertWarnsContext(_AssertRaisesBaseContext):
     """A context manager used to implement TestCase.assertWarns* methods."""
 
+    _base_type = Warning
+    _base_type_str = 'a warning type or tuple of warning types'
+
     def __enter__(self):
         # The __warningregistry__'s need to be in a pristine state for tests
         # to work properly.
index a05cc648900c835866f467b114222bd4f44b532e..ada733b1ffb45482b7e9b4763d2bf6f3a0bca318 100644 (file)
@@ -1185,6 +1185,18 @@ test case
         with self.assertRaises(ExceptionMock):
             self.assertRaises(ValueError, Stub)
 
+    def testAssertRaisesNoExceptionType(self):
+        with self.assertRaises(TypeError):
+            self.assertRaises()
+        with self.assertRaises(TypeError):
+            self.assertRaises(1)
+        with self.assertRaises(TypeError):
+            self.assertRaises(object)
+        with self.assertRaises(TypeError):
+            self.assertRaises((ValueError, 1))
+        with self.assertRaises(TypeError):
+            self.assertRaises((ValueError, object))
+
     def testAssertRaisesRegex(self):
         class ExceptionMock(Exception):
             pass
@@ -1258,6 +1270,20 @@ test case
         self.assertIsInstance(e, ExceptionMock)
         self.assertEqual(e.args[0], v)
 
+    def testAssertRaisesRegexNoExceptionType(self):
+        with self.assertRaises(TypeError):
+            self.assertRaisesRegex()
+        with self.assertRaises(TypeError):
+            self.assertRaisesRegex(ValueError)
+        with self.assertRaises(TypeError):
+            self.assertRaisesRegex(1, 'expect')
+        with self.assertRaises(TypeError):
+            self.assertRaisesRegex(object, 'expect')
+        with self.assertRaises(TypeError):
+            self.assertRaisesRegex((ValueError, 1), 'expect')
+        with self.assertRaises(TypeError):
+            self.assertRaisesRegex((ValueError, object), 'expect')
+
     def testAssertWarnsCallable(self):
         def _runtime_warn():
             warnings.warn("foo", RuntimeWarning)
@@ -1336,6 +1362,20 @@ test case
                 with self.assertWarns(DeprecationWarning):
                     _runtime_warn()
 
+    def testAssertWarnsNoExceptionType(self):
+        with self.assertRaises(TypeError):
+            self.assertWarns()
+        with self.assertRaises(TypeError):
+            self.assertWarns(1)
+        with self.assertRaises(TypeError):
+            self.assertWarns(object)
+        with self.assertRaises(TypeError):
+            self.assertWarns((UserWarning, 1))
+        with self.assertRaises(TypeError):
+            self.assertWarns((UserWarning, object))
+        with self.assertRaises(TypeError):
+            self.assertWarns((UserWarning, Exception))
+
     def testAssertWarnsRegexCallable(self):
         def _runtime_warn(msg):
             warnings.warn(msg, RuntimeWarning)
@@ -1414,6 +1454,22 @@ test case
                 with self.assertWarnsRegex(RuntimeWarning, "o+"):
                     _runtime_warn("barz")
 
+    def testAssertWarnsRegexNoExceptionType(self):
+        with self.assertRaises(TypeError):
+            self.assertWarnsRegex()
+        with self.assertRaises(TypeError):
+            self.assertWarnsRegex(UserWarning)
+        with self.assertRaises(TypeError):
+            self.assertWarnsRegex(1, 'expect')
+        with self.assertRaises(TypeError):
+            self.assertWarnsRegex(object, 'expect')
+        with self.assertRaises(TypeError):
+            self.assertWarnsRegex((UserWarning, 1), 'expect')
+        with self.assertRaises(TypeError):
+            self.assertWarnsRegex((UserWarning, object), 'expect')
+        with self.assertRaises(TypeError):
+            self.assertWarnsRegex((UserWarning, Exception), 'expect')
+
     @contextlib.contextmanager
     def assertNoStderr(self):
         with captured_stderr() as buf:
index 807eeb27a6a180dea22d7d6c89a3676f886c8ed1..4d3e2900435632cd527673f509fdeb62dee62133 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -1472,6 +1472,7 @@ Alex Volkov
 Martijn Vries
 Sjoerd de Vries
 Guido Vranken
+Daniel Wagner-Hall
 Niki W. Waibel
 Wojtek Walczak
 Charles Waldman
index 5838a63046b477a8faee3a5014a4d04cbad18e73..0f35d3f97b0f61e28b8c8d12d923987a602f936b 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -52,6 +52,10 @@ Core and Builtins
 Library
 -------
 
+- Issue #15836: assertRaises(), assertRaisesRegex(), assertWarns() and
+  assertWarnsRegex() assertments now check the type of the first argument
+  to prevent possible user error.  Based on patch by Daniel Wagner-Hall.
+
 - Issue #9858: Add missing method stubs to _io.RawIOBase.  Patch by Laura
   Rupprecht.