]> granicus.if.org Git - python/commitdiff
Issue #3781: Final cleanup of warnings.catch_warnings and its usage in the test suite...
authorNick Coghlan <ncoghlan@gmail.com>
Thu, 11 Sep 2008 12:11:06 +0000 (12:11 +0000)
committerNick Coghlan <ncoghlan@gmail.com>
Thu, 11 Sep 2008 12:11:06 +0000 (12:11 +0000)
Doc/library/test.rst
Doc/library/warnings.rst
Lib/test/test_import.py
Lib/test/test_py3kwarn.py
Lib/test/test_structmembers.py
Lib/test/test_sundry.py
Lib/test/test_support.py
Lib/test/test_symtable.py
Lib/test/test_warnings.py
Lib/warnings.py
Misc/NEWS

index 4c957599e15c7d1eebc5cfb7dcc2ca43571ca297..0a2814bb0f816edf8c19b87dd76a7b2dd49f28cf 100644 (file)
@@ -291,18 +291,26 @@ The :mod:`test.test_support` module defines the following functions:
    This will run all tests defined in the named module.
 
 
-.. function:: catch_warning(module=warnings, record=True)
+.. function:: check_warnings()
 
-   Return a context manager that guards the warnings filter from being
-   permanently changed and optionally alters the :func:`showwarning`
-   function to record the details of any warnings that are issued in the
-   managed context. Details of the most recent call to :func:`showwarning`
-   are saved directly on the context manager, while details of previous
-   warnings can be retrieved from the ``warnings`` list.
+   A convenience wrapper for ``warnings.catch_warnings()`` that makes
+   it easier to test that a warning was correctly raised with a single
+   assertion. It is approximately equivalent to calling
+   ``warnings.catch_warnings(record=True)``.
+
+   The main difference is that on entry to the context manager, a
+   :class:`WarningRecorder` instance is returned instead of a simple list.
+   The underlying warnings list is available via the recorder object's
+   :attr:`warnings` attribute, while the attributes of the last raised
+   warning are also accessible directly on the object. If no warning has
+   been raised, then the latter attributes will all be :const:`None`.
+
+   A :meth:`reset` method is also provided on the recorder object. This
+   method simply clears the warning list.
 
    The context manager is used like this::
 
-      with catch_warning() as w:
+      with check_warnings() as w:
           warnings.simplefilter("always")
           warnings.warn("foo")
           assert str(w.message) == "foo"
@@ -310,15 +318,8 @@ The :mod:`test.test_support` module defines the following functions:
           assert str(w.message) == "bar"
           assert str(w.warnings[0].message) == "foo"
           assert str(w.warnings[1].message) == "bar"
-
-   By default, the real :mod:`warnings` module is affected - the ability
-   to select a different module is provided for the benefit of the
-   :mod:`warnings` module's  own unit tests.
-   The ``record`` argument specifies whether or not the :func:`showwarning`
-   function is replaced. Note that recording the warnings in this fashion
-   also prevents them from being written to sys.stderr. If set to ``False``,
-   the standard handling of warning messages is left in place (however, the
-   original handling is still restored at the end of the block).
+          w.reset()
+          assert len(w.warnings) == 0
 
    .. versionadded:: 2.6
 
@@ -366,4 +367,10 @@ The :mod:`test.test_support` module defines the following classes:
 
    Temporarily unset the environment variable ``envvar``.
 
+.. class:: WarningsRecorder()
+
+   Class used to record warnings for unit tests. See documentation of
+   :func:`check_warnings` above for more details.
+
+   .. versionadded:: 2.6
 
index e9d018280f5e914a0880a96e3dbada910ef232ee..bb3aa44da592ae68f9307327b8dd8315ba863f36 100644 (file)
@@ -163,9 +163,9 @@ ImportWarning can also be enabled explicitly in Python code using::
 Temporarily Suppressing Warnings
 --------------------------------
 
-If you are using code that you know will raise a warning, such some deprecated
-function, but do not want to see the warning, then suppress the warning using
-the :class:`catch_warnings` context manager::
+If you are using code that you know will raise a warning, such as a deprecated
+function, but do not want to see the warning, then it is possible to suppress
+the warning using the :class:`catch_warnings` context manager::
 
     import warnings
 
@@ -216,7 +216,15 @@ the warning has been cleared.
 Once the context manager exits, the warnings filter is restored to its state
 when the context was entered. This prevents tests from changing the warnings
 filter in unexpected ways between tests and leading to indeterminate test
-results.
+results. The :func:`showwarning` function in the module is also restored to
+its original value.
+
+When testing multiple operations that raise the same kind of warning, it
+is important to test them in a manner that confirms each operation is raising
+a new warning (e.g. set warnings to be raised as exceptions and check the
+operations raise exceptions, check that the length of the warning list
+continues to increase after each operation, or else delete the previous
+entries from the warnings list before each new operation).
 
 
 .. _warning-functions:
@@ -330,16 +338,18 @@ Available Context Managers
 
 .. class:: catch_warnings([\*, record=False, module=None])
 
-    A context manager that copies and, upon exit, restores the warnings filter.
-    If the *record* argument is False (the default) the context manager returns
-    :class:`None`. If *record* is true, a list is returned that is populated
-    with objects as seen by a custom :func:`showwarning` function (which also
-    suppresses output to ``sys.stdout``). Each object has attributes with the
-    same names as the arguments to :func:`showwarning`.
+    A context manager that copies and, upon exit, restores the warnings filter
+    and the :func:`showwarning` function.
+    If the *record* argument is :const:`False` (the default) the context manager
+    returns :class:`None` on entry. If *record* is :const:`True`, a list is
+    returned that is progressively populated with objects as seen by a custom
+    :func:`showwarning` function (which also suppresses output to ``sys.stdout``).
+    Each object in the list has attributes with the same names as the arguments to
+    :func:`showwarning`.
 
     The *module* argument takes a module that will be used instead of the
     module returned when you import :mod:`warnings` whose filter will be
-    protected. This arguments exists primarily for testing the :mod:`warnings`
+    protected. This argument exists primarily for testing the :mod:`warnings`
     module itself.
 
     .. note::
index 644a473b12b88e4010798b985c72ab1ae4293bea..13e8cc3bd49e8a096534772d562cdbb1c4df2c7f 100644 (file)
@@ -5,7 +5,7 @@ import shutil
 import sys
 import py_compile
 import warnings
-from test.test_support import unlink, TESTFN, unload, run_unittest
+from test.test_support import unlink, TESTFN, unload, run_unittest, check_warnings
 
 
 def remove_files(name):
@@ -279,17 +279,17 @@ class RelativeImport(unittest.TestCase):
         check_relative()
         # Check relative fails with only __package__ wrong
         ns = dict(__package__='foo', __name__='test.notarealmodule')
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             check_absolute()
-            self.assert_('foo' in str(w[-1].message))
-            self.assertEqual(w[-1].category, RuntimeWarning)
+            self.assert_('foo' in str(w.message))
+            self.assertEqual(w.category, RuntimeWarning)
         self.assertRaises(SystemError, check_relative)
         # Check relative fails with __package__ and __name__ wrong
         ns = dict(__package__='foo', __name__='notarealpkg.notarealmodule')
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             check_absolute()
-            self.assert_('foo' in str(w[-1].message))
-            self.assertEqual(w[-1].category, RuntimeWarning)
+            self.assert_('foo' in str(w.message))
+            self.assertEqual(w.category, RuntimeWarning)
         self.assertRaises(SystemError, check_relative)
         # Check both fail with package set to a non-string
         ns = dict(__package__=object())
index aa1ecbb588bebd80b0fba62bbd626c113e59faed..0afa8e72b18ee8316d316f8d5da29dc14596e059 100644 (file)
@@ -1,6 +1,7 @@
 import unittest
 import sys
-from test.test_support import CleanImport, TestSkipped, run_unittest
+from test.test_support import (check_warnings, CleanImport,
+                               TestSkipped, run_unittest)
 import warnings
 
 from contextlib import nested
@@ -8,15 +9,22 @@ from contextlib import nested
 if not sys.py3kwarning:
     raise TestSkipped('%s must be run with the -3 flag' % __name__)
 
+def reset_module_registry(module):
+    try:
+        registry = module.__warningregistry__
+    except AttributeError:
+        pass
+    else:
+        registry.clear()
 
 class TestPy3KWarnings(unittest.TestCase):
 
     def assertWarning(self, _, warning, expected_message):
-        self.assertEqual(str(warning[-1].message), expected_message)
+        self.assertEqual(str(warning.message), expected_message)
 
     def test_backquote(self):
         expected = 'backquote not supported in 3.x; use repr()'
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             exec "`2`" in {}
         self.assertWarning(None, w, expected)
 
@@ -27,55 +35,71 @@ class TestPy3KWarnings(unittest.TestCase):
             exec expr in {'f' : f}
 
         expected = "assignment to True or False is forbidden in 3.x"
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             safe_exec("True = False")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("False = True")
             self.assertWarning(None, w, expected)
+            w.reset()
             try:
                 safe_exec("obj.False = True")
             except NameError: pass
             self.assertWarning(None, w, expected)
+            w.reset()
             try:
                 safe_exec("obj.True = False")
             except NameError: pass
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("def False(): pass")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("def True(): pass")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("class False: pass")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("class True: pass")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("def f(True=43): pass")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("def f(False=None): pass")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("f(False=True)")
             self.assertWarning(None, w, expected)
+            w.reset()
             safe_exec("f(True=1)")
             self.assertWarning(None, w, expected)
 
 
     def test_type_inequality_comparisons(self):
         expected = 'type inequality comparisons not supported in 3.x'
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(int < str, w, expected)
+            w.reset()
             self.assertWarning(type < object, w, expected)
 
     def test_object_inequality_comparisons(self):
         expected = 'comparing unequal types not supported in 3.x'
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(str < [], w, expected)
+            w.reset()
             self.assertWarning(object() < (1, 2), w, expected)
 
     def test_dict_inequality_comparisons(self):
         expected = 'dict inequality comparisons not supported in 3.x'
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning({} < {2:3}, w, expected)
+            w.reset()
             self.assertWarning({} <= {}, w, expected)
+            w.reset()
             self.assertWarning({} > {2:3}, w, expected)
+            w.reset()
             self.assertWarning({2:3} >= {}, w, expected)
 
     def test_cell_inequality_comparisons(self):
@@ -86,8 +110,9 @@ class TestPy3KWarnings(unittest.TestCase):
             return g
         cell0, = f(0).func_closure
         cell1, = f(1).func_closure
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(cell0 == cell1, w, expected)
+            w.reset()
             self.assertWarning(cell0 < cell1, w, expected)
 
     def test_code_inequality_comparisons(self):
@@ -96,10 +121,13 @@ class TestPy3KWarnings(unittest.TestCase):
             pass
         def g(x):
             pass
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(f.func_code < g.func_code, w, expected)
+            w.reset()
             self.assertWarning(f.func_code <= g.func_code, w, expected)
+            w.reset()
             self.assertWarning(f.func_code >= g.func_code, w, expected)
+            w.reset()
             self.assertWarning(f.func_code > g.func_code, w, expected)
 
     def test_builtin_function_or_method_comparisons(self):
@@ -107,10 +135,13 @@ class TestPy3KWarnings(unittest.TestCase):
                     'inequality comparisons not supported in 3.x')
         func = eval
         meth = {}.get
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(func < meth, w, expected)
+            w.reset()
             self.assertWarning(func > meth, w, expected)
+            w.reset()
             self.assertWarning(meth <= func, w, expected)
+            w.reset()
             self.assertWarning(meth >= func, w, expected)
 
     def test_sort_cmp_arg(self):
@@ -118,15 +149,18 @@ class TestPy3KWarnings(unittest.TestCase):
         lst = range(5)
         cmp = lambda x,y: -1
 
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(lst.sort(cmp=cmp), w, expected)
+            w.reset()
             self.assertWarning(sorted(lst, cmp=cmp), w, expected)
+            w.reset()
             self.assertWarning(lst.sort(cmp), w, expected)
+            w.reset()
             self.assertWarning(sorted(lst, cmp), w, expected)
 
     def test_sys_exc_clear(self):
         expected = 'sys.exc_clear() not supported in 3.x; use except clauses'
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(sys.exc_clear(), w, expected)
 
     def test_methods_members(self):
@@ -135,17 +169,17 @@ class TestPy3KWarnings(unittest.TestCase):
             __methods__ = ['a']
             __members__ = ['b']
         c = C()
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(dir(c), w, expected)
 
     def test_softspace(self):
         expected = 'file.softspace not supported in 3.x'
         with file(__file__) as f:
-            with warnings.catch_warnings(record=True) as w:
+            with check_warnings() as w:
                 self.assertWarning(f.softspace, w, expected)
             def set():
                 f.softspace = 0
-            with warnings.catch_warnings(record=True) as w:
+            with check_warnings() as w:
                 self.assertWarning(set(), w, expected)
 
     def test_slice_methods(self):
@@ -161,59 +195,60 @@ class TestPy3KWarnings(unittest.TestCase):
         expected = "in 3.x, __{0}slice__ has been removed; use __{0}item__"
 
         for obj in (Spam(), Egg()):
-            with warnings.catch_warnings(record=True) as w:
+            with check_warnings() as w:
                 self.assertWarning(obj[1:2], w, expected.format('get'))
+                w.reset()
                 del obj[3:4]
                 self.assertWarning(None, w, expected.format('del'))
+                w.reset()
                 obj[4:5] = "eggs"
                 self.assertWarning(None, w, expected.format('set'))
 
     def test_tuple_parameter_unpacking(self):
         expected = "tuple parameter unpacking has been removed in 3.x"
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             exec "def f((a, b)): pass"
             self.assertWarning(None, w, expected)
 
     def test_buffer(self):
         expected = 'buffer() not supported in 3.x'
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             self.assertWarning(buffer('a'), w, expected)
 
     def test_file_xreadlines(self):
         expected = ("f.xreadlines() not supported in 3.x, "
                     "try 'for line in f' instead")
         with file(__file__) as f:
-            with warnings.catch_warnings(record=True) as w:
+            with check_warnings() as w:
                 self.assertWarning(f.xreadlines(), w, expected)
 
     def test_hash_inheritance(self):
-        with warnings.catch_warnings(record=True) as w:
+        with check_warnings() as w:
             # With object as the base class
             class WarnOnlyCmp(object):
                 def __cmp__(self, other): pass
-            self.assertEqual(len(w), 1)
+            self.assertEqual(len(w.warnings), 1)
             self.assertWarning(None, w,
                  "Overriding __cmp__ blocks inheritance of __hash__ in 3.x")
-            del w[:]
+            w.reset()
             class WarnOnlyEq(object):
                 def __eq__(self, other): pass
-            self.assertEqual(len(w), 1)
+            self.assertEqual(len(w.warnings), 1)
             self.assertWarning(None, w,
                  "Overriding __eq__ blocks inheritance of __hash__ in 3.x")
-            del w[:]
+            w.reset()
             class WarnCmpAndEq(object):
                 def __cmp__(self, other): pass
                 def __eq__(self, other): pass
-            self.assertEqual(len(w), 2)
-            self.assertWarning(None, w[:1],
+            self.assertEqual(len(w.warnings), 2)
+            self.assertWarning(None, w.warnings[0],
                  "Overriding __cmp__ blocks inheritance of __hash__ in 3.x")
             self.assertWarning(None, w,
                  "Overriding __eq__ blocks inheritance of __hash__ in 3.x")
-            del w[:]
+            w.reset()
             class NoWarningOnlyHash(object):
                 def __hash__(self): pass
-            self.assertEqual(len(w), 0)
-            del w[:]
+            self.assertEqual(len(w.warnings), 0)
             # With an intermediate class in the heirarchy
             class DefinesAllThree(object):
                 def __cmp__(self, other): pass
@@ -221,28 +256,28 @@ class TestPy3KWarnings(unittest.TestCase):
                 def __hash__(self): pass
             class WarnOnlyCmp(DefinesAllThree):
                 def __cmp__(self, other): pass
-            self.assertEqual(len(w), 1)
+            self.assertEqual(len(w.warnings), 1)
             self.assertWarning(None, w,
                  "Overriding __cmp__ blocks inheritance of __hash__ in 3.x")
-            del w[:]
+            w.reset()
             class WarnOnlyEq(DefinesAllThree):
                 def __eq__(self, other): pass
-            self.assertEqual(len(w), 1)
+            self.assertEqual(len(w.warnings), 1)
             self.assertWarning(None, w,
                  "Overriding __eq__ blocks inheritance of __hash__ in 3.x")
-            del w[:]
+            w.reset()
             class WarnCmpAndEq(DefinesAllThree):
                 def __cmp__(self, other): pass
                 def __eq__(self, other): pass
-            self.assertEqual(len(w), 2)
-            self.assertWarning(None, w[:1],
+            self.assertEqual(len(w.warnings), 2)
+            self.assertWarning(None, w.warnings[0],
                  "Overriding __cmp__ blocks inheritance of __hash__ in 3.x")
             self.assertWarning(None, w,
                  "Overriding __eq__ blocks inheritance of __hash__ in 3.x")
-            del w[:]
+            w.reset()
             class NoWarningOnlyHash(DefinesAllThree):
                 def __hash__(self): pass
-            self.assertEqual(len(w), 0)
+            self.assertEqual(len(w.warnings), 0)
 
 
 class TestStdlibRemovals(unittest.TestCase):
@@ -283,6 +318,9 @@ class TestStdlibRemovals(unittest.TestCase):
         """Make sure the specified module, when imported, raises a
         DeprecationWarning and specifies itself in the message."""
         with nested(CleanImport(module_name), warnings.catch_warnings()):
+            # XXX: This is not quite enough for extension modules - those
+            # won't rerun their init code even with CleanImport.
+            # You can see this easily by running the whole test suite with -3
             warnings.filterwarnings("error", ".+ removed",
                                     DeprecationWarning, __name__)
             try:
@@ -320,12 +358,15 @@ class TestStdlibRemovals(unittest.TestCase):
         def dumbo(where, names, args): pass
         for path_mod in ("ntpath", "macpath", "os2emxpath", "posixpath"):
             mod = __import__(path_mod)
-            with warnings.catch_warnings(record=True) as w:
+            reset_module_registry(mod)
+            with check_warnings() as w:
                 mod.walk("crashers", dumbo, None)
-            self.assertEquals(str(w[-1].message), msg)
+            self.assertEquals(str(w.message), msg)
 
     def test_commands_members(self):
         import commands
+        # commands module tests may have already triggered this warning
+        reset_module_registry(commands)
         members = {"mk2arg" : 2, "mkarg" : 1, "getstatus" : 1}
         for name, arg_count in members.items():
             with warnings.catch_warnings():
@@ -335,6 +376,8 @@ class TestStdlibRemovals(unittest.TestCase):
 
     def test_reduce_move(self):
         from operator import add
+        # reduce tests may have already triggered this warning
+        reset_module_registry(unittest)
         with warnings.catch_warnings():
             warnings.filterwarnings("error", "reduce")
             self.assertRaises(DeprecationWarning, reduce, add, range(10))
@@ -342,6 +385,8 @@ class TestStdlibRemovals(unittest.TestCase):
     def test_mutablestring_removal(self):
         # UserString.MutableString has been removed in 3.0.
         import UserString
+        # UserString tests may have already triggered this warning
+        reset_module_registry(UserString)
         with warnings.catch_warnings():
             warnings.filterwarnings("error", ".*MutableString",
                                     DeprecationWarning)
@@ -349,7 +394,7 @@ class TestStdlibRemovals(unittest.TestCase):
 
 
 def test_main():
-    with warnings.catch_warnings():
+    with check_warnings():
         warnings.simplefilter("always")
         run_unittest(TestPy3KWarnings,
                      TestStdlibRemovals)
index e0e7e5613001c85b70f723765f2a4f608aa03d59..c196cc5147483081ff044cceff135a7ded0c0571 100644 (file)
@@ -66,35 +66,35 @@ class ReadWriteTests(unittest.TestCase):
 
 class TestWarnings(unittest.TestCase):
     def has_warned(self, w):
-        self.assertEqual(w[-1].category, RuntimeWarning)
+        self.assertEqual(w.category, RuntimeWarning)
 
     def test_byte_max(self):
-        with warnings.catch_warnings(record=True) as w:
+        with test_support.check_warnings() as w:
             ts.T_BYTE = CHAR_MAX+1
             self.has_warned(w)
 
     def test_byte_min(self):
-        with warnings.catch_warnings(record=True) as w:
+        with test_support.check_warnings() as w:
             ts.T_BYTE = CHAR_MIN-1
             self.has_warned(w)
 
     def test_ubyte_max(self):
-        with warnings.catch_warnings(record=True) as w:
+        with test_support.check_warnings() as w:
             ts.T_UBYTE = UCHAR_MAX+1
             self.has_warned(w)
 
     def test_short_max(self):
-        with warnings.catch_warnings(record=True) as w:
+        with test_support.check_warnings() as w:
             ts.T_SHORT = SHRT_MAX+1
             self.has_warned(w)
 
     def test_short_min(self):
-        with warnings.catch_warnings(record=True) as w:
+        with test_support.check_warnings() as w:
             ts.T_SHORT = SHRT_MIN-1
             self.has_warned(w)
 
     def test_ushort_max(self):
-        with warnings.catch_warnings(record=True) as w:
+        with test_support.check_warnings() as w:
             ts.T_USHORT = USHRT_MAX+1
             self.has_warned(w)
 
index 49ec12c944ac0b4d0ea74bb2a733ea6e2da9f42e..e7499ae4b26f79d981ac54b47869bbeaac5fcd0a 100644 (file)
@@ -8,7 +8,7 @@ import warnings
 
 class TestUntestedModules(unittest.TestCase):
     def test_at_least_import_untested_modules(self):
-        with warnings.catch_warnings(record=True):
+        with warnings.catch_warnings():
             import CGIHTTPServer
             import aifc
             import audiodev
index 0bf22cf35dbb80dae694e3a65e48349bff57dbc1..abaf11734749c85e370ff351d4840287845c2c84 100644 (file)
@@ -18,7 +18,7 @@ __all__ = ["Error", "TestFailed", "TestSkipped", "ResourceDenied", "import_modul
            "is_resource_enabled", "requires", "find_unused_port", "bind_port",
            "fcmp", "have_unicode", "is_jython", "TESTFN", "HOST", "FUZZ",
            "findfile", "verify", "vereq", "sortdict", "check_syntax_error",
-           "open_urlresource", "CleanImport",
+           "open_urlresource", "check_warnings", "CleanImport",
            "EnvironmentVarGuard", "captured_output",
            "captured_stdout", "TransientResource", "transient_internet",
            "run_with_locale", "set_memlimit", "bigmemtest", "bigaddrspacetest",
@@ -381,6 +381,29 @@ def open_urlresource(url):
     return open(fn)
 
 
+class WarningsRecorder(object):
+    """Convenience wrapper for the warnings list returned on
+       entry to the warnings.catch_warnings() context manager.
+    """
+    def __init__(self, warnings_list):
+        self.warnings = warnings_list
+
+    def __getattr__(self, attr):
+        if self.warnings:
+            return getattr(self.warnings[-1], attr)
+        elif attr in warnings.WarningMessage._WARNING_DETAILS:
+            return None
+        raise AttributeError("%r has no attribute %r" % (self, attr))
+
+    def reset(self):
+        del self.warnings[:]
+
+@contextlib.contextmanager
+def check_warnings():
+    with warnings.catch_warnings(record=True) as w:
+        yield WarningsRecorder(w)
+
+
 class CleanImport(object):
     """Context manager to force import to return a new module reference.
 
index b20f2b4e0e92768c5276705e3dfa3cabdcc157f7..f1fa18ac92af5a5ad085c632d22b4b7397261edc 100644 (file)
@@ -60,16 +60,16 @@ class SymtableTest(unittest.TestCase):
         def check(w, msg):
             self.assertEqual(str(w.message), msg)
         sym = self.top.lookup("glob")
-        with warnings.catch_warnings(record=True) as w:
+        with test_support.check_warnings() as w:
             warnings.simplefilter("always", DeprecationWarning)
             self.assertFalse(sym.is_vararg())
-            check(w[-1].message, "is_vararg() is obsolete and will be removed")
+            check(w, "is_vararg() is obsolete and will be removed")
+            w.reset()
             self.assertFalse(sym.is_keywordarg())
-            check(w[-1].message,
-                    "is_keywordarg() is obsolete and will be removed")
+            check(w, "is_keywordarg() is obsolete and will be removed")
+            w.reset()
             self.assertFalse(sym.is_in_tuple())
-            check(w[-1].message,
-                    "is_in_tuple() is obsolete and will be removed")
+            check(w, "is_in_tuple() is obsolete and will be removed")
 
     def test_type(self):
         self.assertEqual(self.top.get_type(), "module")
index 388b5e9d3a6b0e870fbafe959be4b9c0a984b26e..b37cdf7788ec7143d6ab2b50ffa2db2b6d10f6c0 100644 (file)
@@ -517,10 +517,12 @@ class CatchWarningTests(BaseTest):
         wmod = self.module
         orig_filters = wmod.filters
         orig_showwarning = wmod.showwarning
-        with wmod.catch_warnings(record=True, module=wmod):
+        # Ensure both showwarning and filters are restored when recording
+        with wmod.catch_warnings(module=wmod, record=True):
             wmod.filters = wmod.showwarning = object()
         self.assert_(wmod.filters is orig_filters)
         self.assert_(wmod.showwarning is orig_showwarning)
+        # Same test, but with recording disabled
         with wmod.catch_warnings(module=wmod, record=False):
             wmod.filters = wmod.showwarning = object()
         self.assert_(wmod.filters is orig_filters)
@@ -528,9 +530,10 @@ class CatchWarningTests(BaseTest):
 
     def test_catch_warnings_recording(self):
         wmod = self.module
+        # Ensure warnings are recorded when requested
         with wmod.catch_warnings(module=wmod, record=True) as w:
             self.assertEqual(w, [])
-            self.assertRaises(AttributeError, getattr, w, 'message')
+            self.assert_(type(w) is list)
             wmod.simplefilter("always")
             wmod.warn("foo")
             self.assertEqual(str(w[-1].message), "foo")
@@ -540,11 +543,61 @@ class CatchWarningTests(BaseTest):
             self.assertEqual(str(w[1].message), "bar")
             del w[:]
             self.assertEqual(w, [])
+        # Ensure warnings are not recorded when not requested
         orig_showwarning = wmod.showwarning
         with wmod.catch_warnings(module=wmod, record=False) as w:
             self.assert_(w is None)
             self.assert_(wmod.showwarning is orig_showwarning)
 
+    def test_catch_warnings_reentry_guard(self):
+        wmod = self.module
+        # Ensure catch_warnings is protected against incorrect usage
+        x = wmod.catch_warnings(module=wmod, record=True)
+        self.assertRaises(RuntimeError, x.__exit__)
+        with x:
+            self.assertRaises(RuntimeError, x.__enter__)
+        # Same test, but with recording disabled
+        x = wmod.catch_warnings(module=wmod, record=False)
+        self.assertRaises(RuntimeError, x.__exit__)
+        with x:
+            self.assertRaises(RuntimeError, x.__enter__)
+
+    def test_catch_warnings_defaults(self):
+        wmod = self.module
+        orig_filters = wmod.filters
+        orig_showwarning = wmod.showwarning
+        # Ensure default behaviour is not to record warnings
+        with wmod.catch_warnings(module=wmod) as w:
+            self.assert_(w is None)
+            self.assert_(wmod.showwarning is orig_showwarning)
+            self.assert_(wmod.filters is not orig_filters)
+        self.assert_(wmod.filters is orig_filters)
+        if wmod is sys.modules['warnings']:
+            # Ensure the default module is this one
+            with wmod.catch_warnings() as w:
+                self.assert_(w is None)
+                self.assert_(wmod.showwarning is orig_showwarning)
+                self.assert_(wmod.filters is not orig_filters)
+            self.assert_(wmod.filters is orig_filters)
+
+    def test_check_warnings(self):
+        # Explicit tests for the test_support convenience wrapper
+        wmod = self.module
+        if wmod is sys.modules['warnings']:
+            with test_support.check_warnings() as w:
+                self.assertEqual(w.warnings, [])
+                wmod.simplefilter("always")
+                wmod.warn("foo")
+                self.assertEqual(str(w.message), "foo")
+                wmod.warn("bar")
+                self.assertEqual(str(w.message), "bar")
+                self.assertEqual(str(w.warnings[0].message), "foo")
+                self.assertEqual(str(w.warnings[1].message), "bar")
+                w.reset()
+                self.assertEqual(w.warnings, [])
+
+
+
 class CCatchWarningTests(CatchWarningTests):
     module = c_warnings
 
index 04e7b5878c6255bcfcd461da0734e5f1b1e298bc..59011caa46deef4ff80fdef2ed155e1585643fc0 100644 (file)
@@ -331,8 +331,21 @@ class catch_warnings(object):
         """
         self._record = record
         self._module = sys.modules['warnings'] if module is None else module
+        self._entered = False
+
+    def __repr__(self):
+        args = []
+        if self._record:
+            args.append("record=True")
+        if self._module is not sys.modules['warnings']:
+            args.append("module=%r" % self._module)
+        name = type(self).__name__
+        return "%s(%s)" % (name, ", ".join(args))
 
     def __enter__(self):
+        if self._entered:
+            raise RuntimeError("Cannot enter %r twice" % self)
+        self._entered = True
         self._filters = self._module.filters
         self._module.filters = self._filters[:]
         self._showwarning = self._module.showwarning
@@ -346,6 +359,8 @@ class catch_warnings(object):
             return None
 
     def __exit__(self, *exc_info):
+        if not self._entered:
+            raise RuntimeError("Cannot exit %r without entering first" % self)
         self._module.filters = self._filters
         self._module.showwarning = self._showwarning
 
index 55490b7f44a574cbdf7ef45bd461a1fc3bdf01a9..fba47232f7cf1f84eae7dc60e7cfd05b528e7e59 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -79,9 +79,13 @@ Library
 
 - Issue #3811: The Unicode database was updated to 5.1.
 
+- Issue #3781: Further warnings.catch_warnings() cleanup to prevent
+  silent misbehaviour when a single instance is nested in multiple
+  with statements, or when the methods are invoked in the wrong order.
+
 - Issue #3809: Fixed spurious 'test.blah' file left behind by test_logging.
 
-- Issue 3781: Clean up the API for warnings.catch_warnings() by having it
+- Issue #3781: Clean up the API for warnings.catch_warnings() by having it
   return a list or None rather than a custom object.
 
 - Issue #1638033: Cookie.Morsel gained the httponly attribute.
@@ -142,6 +146,10 @@ Extension Modules
 Tests
 -----
 
+- Issue #3781: Add test.test_support.check_warnings() as a convenience
+  wrapper for warnings.catch_warnings() that makes it easier to check
+  that expected warning messages are being reported.
+
 - Issue #3796: Some tests functions were not enabled in test_float.
 
 - Issue #3768: Move test_py3kwarn over to the new API for catch_warnings().