]> granicus.if.org Git - python/commitdiff
refactor traceback.py to reduce code duplication (closes #17646)
authorBenjamin Peterson <benjamin@python.org>
Mon, 29 Apr 2013 20:09:39 +0000 (16:09 -0400)
committerBenjamin Peterson <benjamin@python.org>
Mon, 29 Apr 2013 20:09:39 +0000 (16:09 -0400)
Patch by Martin Morrison.

Lib/test/test_traceback.py
Lib/traceback.py

index 5bce2af68a960df8fb5e126d3c7e86bbe5483598..24753a8321639d5d7b5467b06d1fa56efa593b34 100644 (file)
@@ -160,11 +160,26 @@ class TracebackFormatTests(unittest.TestCase):
             file_ = StringIO()
             traceback_print(tb, file_)
             python_fmt  = file_.getvalue()
+            # Call all _tb and _exc functions
+            with captured_output("stderr") as tbstderr:
+                traceback.print_tb(tb)
+            tbfile = StringIO()
+            traceback.print_tb(tb, file=tbfile)
+            with captured_output("stderr") as excstderr:
+                traceback.print_exc()
+            excfmt = traceback.format_exc()
+            excfile = StringIO()
+            traceback.print_exc(file=excfile)
         else:
             raise Error("unable to create test traceback string")
 
         # Make sure that Python and the traceback module format the same thing
         self.assertEqual(traceback_fmt, python_fmt)
+        # Now verify the _tb func output
+        self.assertEqual(tbstderr.getvalue(), tbfile.getvalue())
+        # Now verify the _exc func output
+        self.assertEqual(excstderr.getvalue(), excfile.getvalue())
+        self.assertEqual(excfmt, excfile.getvalue())
 
         # Make sure that the traceback is properly indented.
         tb_lines = python_fmt.splitlines()
@@ -174,6 +189,19 @@ class TracebackFormatTests(unittest.TestCase):
         self.assertTrue(location.startswith('  File'))
         self.assertTrue(source_line.startswith('    raise'))
 
+    def test_stack_format(self):
+        # Verify _stack functions. Note we have to use _getframe(1) to
+        # compare them without this frame appearing in the output
+        with captured_output("stderr") as ststderr:
+            traceback.print_stack(sys._getframe(1))
+        stfile = StringIO()
+        traceback.print_stack(sys._getframe(1), file=stfile)
+        self.assertEqual(ststderr.getvalue(), stfile.getvalue())
+
+        stfmt = traceback.format_stack(sys._getframe(1))
+
+        self.assertEqual(ststderr.getvalue(), "".join(stfmt))
+
 
 cause_message = (
     "\nThe above exception was the direct cause "
index 33b86c7e6e5042d5cb3e51ef5000a6e66c4eaf6a..3aa1578f4e425c9523e9f2b23c35779322571e26 100644 (file)
@@ -2,26 +2,31 @@
 
 import linecache
 import sys
+import operator
 
 __all__ = ['extract_stack', 'extract_tb', 'format_exception',
            'format_exception_only', 'format_list', 'format_stack',
            'format_tb', 'print_exc', 'format_exc', 'print_exception',
            'print_last', 'print_stack', 'print_tb']
 
-def _print(file, str='', terminator='\n'):
-    file.write(str+terminator)
+#
+# Formatting and printing lists of traceback lines.
+#
 
+def _format_list_iter(extracted_list):
+    for filename, lineno, name, line in extracted_list:
+        item = '  File "{}", line {}, in {}\n'.format(filename, lineno, name)
+        if line:
+            item = item + '    {}\n'.format(line.strip())
+        yield item
 
 def print_list(extracted_list, file=None):
     """Print the list of tuples as returned by extract_tb() or
     extract_stack() as a formatted stack trace to the given file."""
     if file is None:
         file = sys.stderr
-    for filename, lineno, name, line in extracted_list:
-        _print(file,
-               '  File "%s", line %d, in %s' % (filename,lineno,name))
-        if line:
-            _print(file, '    %s' % line.strip())
+    for item in _format_list_iter(extracted_list):
+        print(item, file=file, end="")
 
 def format_list(extracted_list):
     """Format a list of traceback entry tuples for printing.
@@ -33,14 +38,44 @@ def format_list(extracted_list):
     the strings may contain internal newlines as well, for those items
     whose source text line is not None.
     """
-    list = []
-    for filename, lineno, name, line in extracted_list:
-        item = '  File "%s", line %d, in %s\n' % (filename,lineno,name)
+    return list(_format_list_iter(extracted_list))
+
+#
+# Printing and Extracting Tracebacks.
+#
+
+# extractor takes curr and needs to return a tuple of:
+# - Frame object
+# - Line number
+# - Next item (same type as curr)
+# In practice, curr is either a traceback or a frame.
+def _extract_tb_or_stack_iter(curr, limit, extractor):
+    if limit is None:
+        limit = getattr(sys, 'tracebacklimit', None)
+
+    n = 0
+    while curr is not None and (limit is None or n < limit):
+        f, lineno, next_item = extractor(curr)
+        co = f.f_code
+        filename = co.co_filename
+        name = co.co_name
+
+        linecache.checkcache(filename)
+        line = linecache.getline(filename, lineno, f.f_globals)
+
         if line:
-            item = item + '    %s\n' % line.strip()
-        list.append(item)
-    return list
+            line = line.strip()
+        else:
+            line = None
+
+        yield (filename, lineno, name, line)
+        curr = next_item
+        n += 1
 
+def _extract_tb_iter(tb, limit):
+    return _extract_tb_or_stack_iter(
+                tb, limit,
+                operator.attrgetter("tb_frame", "tb_lineno", "tb_next"))
 
 def print_tb(tb, limit=None, file=None):
     """Print up to 'limit' stack trace entries from the traceback 'tb'.
@@ -50,29 +85,11 @@ def print_tb(tb, limit=None, file=None):
     'file' should be an open file or file-like object with a write()
     method.
     """
-    if file is None:
-        file = sys.stderr
-    if limit is None:
-        if hasattr(sys, 'tracebacklimit'):
-            limit = sys.tracebacklimit
-    n = 0
-    while tb is not None and (limit is None or n < limit):
-        f = tb.tb_frame
-        lineno = tb.tb_lineno
-        co = f.f_code
-        filename = co.co_filename
-        name = co.co_name
-        _print(file,
-               '  File "%s", line %d, in %s' % (filename, lineno, name))
-        linecache.checkcache(filename)
-        line = linecache.getline(filename, lineno, f.f_globals)
-        if line: _print(file, '    ' + line.strip())
-        tb = tb.tb_next
-        n = n+1
+    print_list(extract_tb(tb, limit=limit), file=file)
 
 def format_tb(tb, limit=None):
-    """A shorthand for 'format_list(extract_stack(f, limit))."""
-    return format_list(extract_tb(tb, limit))
+    """A shorthand for 'format_list(extract_tb(tb, limit))."""
+    return format_list(extract_tb(tb, limit=limit))
 
 def extract_tb(tb, limit=None):
     """Return list of up to limit pre-processed entries from traceback.
@@ -85,26 +102,11 @@ def extract_tb(tb, limit=None):
     leading and trailing whitespace stripped; if the source is not
     available it is None.
     """
-    if limit is None:
-        if hasattr(sys, 'tracebacklimit'):
-            limit = sys.tracebacklimit
-    list = []
-    n = 0
-    while tb is not None and (limit is None or n < limit):
-        f = tb.tb_frame
-        lineno = tb.tb_lineno
-        co = f.f_code
-        filename = co.co_filename
-        name = co.co_name
-        linecache.checkcache(filename)
-        line = linecache.getline(filename, lineno, f.f_globals)
-        if line: line = line.strip()
-        else: line = None
-        list.append((filename, lineno, name, line))
-        tb = tb.tb_next
-        n = n+1
-    return list
+    return list(_extract_tb_iter(tb, limit=limit))
 
+#
+# Exception formatting and output.
+#
 
 _cause_message = (
     "\nThe above exception was the direct cause "
@@ -134,6 +136,21 @@ def _iter_chain(exc, custom_tb=None, seen=None):
     for it in its:
         yield from it
 
+def _format_exception_iter(etype, value, tb, limit, chain):
+    if chain:
+        values = _iter_chain(value, tb)
+    else:
+        values = [(value, tb)]
+
+    for value, tb in values:
+        if isinstance(value, str):
+            # This is a cause/context message line
+            yield value + '\n'
+            continue
+        if tb:
+            yield 'Traceback (most recent call last):\n'
+            yield from _format_list_iter(_extract_tb_iter(tb, limit=limit))
+        yield from _format_exception_only_iter(type(value), value)
 
 def print_exception(etype, value, tb, limit=None, file=None, chain=True):
     """Print exception up to 'limit' stack trace entries from 'tb' to 'file'.
@@ -148,20 +165,8 @@ def print_exception(etype, value, tb, limit=None, file=None, chain=True):
     """
     if file is None:
         file = sys.stderr
-    if chain:
-        values = _iter_chain(value, tb)
-    else:
-        values = [(value, tb)]
-    for value, tb in values:
-        if isinstance(value, str):
-            _print(file, value)
-            continue
-        if tb:
-            _print(file, 'Traceback (most recent call last):')
-            print_tb(tb, limit, file)
-        lines = format_exception_only(type(value), value)
-        for line in lines:
-            _print(file, line, '')
+    for line in _format_exception_iter(etype, value, tb, limit, chain):
+        print(line, file=file, end="")
 
 def format_exception(etype, value, tb, limit=None, chain=True):
     """Format a stack trace and the exception information.
@@ -172,20 +177,7 @@ def format_exception(etype, value, tb, limit=None, chain=True):
     these lines are concatenated and printed, exactly the same text is
     printed as does print_exception().
     """
-    list = []
-    if chain:
-        values = _iter_chain(value, tb)
-    else:
-        values = [(value, tb)]
-    for value, tb in values:
-        if isinstance(value, str):
-            list.append(value + '\n')
-            continue
-        if tb:
-            list.append('Traceback (most recent call last):\n')
-            list.extend(format_tb(tb, limit))
-        list.extend(format_exception_only(type(value), value))
-    return list
+    return list(_format_exception_iter(etype, value, tb, limit, chain))
 
 def format_exception_only(etype, value):
     """Format the exception part of a traceback.
@@ -203,10 +195,14 @@ def format_exception_only(etype, value):
     string in the list.
 
     """
+    return list(_format_exception_only_iter(etype, value))
+
+def _format_exception_only_iter(etype, value):
     # Gracefully handle (the way Python 2.4 and earlier did) the case of
     # being called with (None, None).
     if etype is None:
-        return [_format_final_exc_line(etype, value)]
+        yield _format_final_exc_line(etype, value)
+        return
 
     stype = etype.__name__
     smod = etype.__module__
@@ -214,26 +210,26 @@ def format_exception_only(etype, value):
         stype = smod + '.' + stype
 
     if not issubclass(etype, SyntaxError):
-        return [_format_final_exc_line(stype, value)]
+        yield _format_final_exc_line(stype, value)
+        return
 
     # It was a syntax error; show exactly where the problem was found.
-    lines = []
     filename = value.filename or "<string>"
     lineno = str(value.lineno) or '?'
-    lines.append('  File "%s", line %s\n' % (filename, lineno))
+    yield '  File "{}", line {}\n'.format(filename, lineno)
+
     badline = value.text
     offset = value.offset
     if badline is not None:
-        lines.append('    %s\n' % badline.strip())
+        yield '    {}\n'.format(badline.strip())
         if offset is not None:
             caretspace = badline.rstrip('\n')[:offset].lstrip()
             # non-space whitespace (likes tabs) must be kept for alignment
             caretspace = ((c.isspace() and c or ' ') for c in caretspace)
             # only three spaces to account for offset1 == pos 0
-            lines.append('   %s^\n' % ''.join(caretspace))
+            yield '   {}^\n'.format(''.join(caretspace))
     msg = value.msg or "<no detail available>"
-    lines.append("%s: %s\n" % (stype, msg))
-    return lines
+    yield "{}: {}\n".format(stype, msg)
 
 def _format_final_exc_line(etype, value):
     valuestr = _some_str(value)
@@ -249,38 +245,34 @@ def _some_str(value):
     except:
         return '<unprintable %s object>' % type(value).__name__
 
-
 def print_exc(limit=None, file=None, chain=True):
     """Shorthand for 'print_exception(*sys.exc_info(), limit, file)'."""
-    if file is None:
-        file = sys.stderr
-    try:
-        etype, value, tb = sys.exc_info()
-        print_exception(etype, value, tb, limit, file, chain)
-    finally:
-        etype = value = tb = None
-
+    print_exception(*sys.exc_info(), limit=limit, file=file, chain=chain)
 
 def format_exc(limit=None, chain=True):
     """Like print_exc() but return a string."""
-    try:
-        etype, value, tb = sys.exc_info()
-        return ''.join(
-            format_exception(etype, value, tb, limit, chain))
-    finally:
-        etype = value = tb = None
-
+    return "".join(format_exception(*sys.exc_info(), limit=limit, chain=chain))
 
 def print_last(limit=None, file=None, chain=True):
     """This is a shorthand for 'print_exception(sys.last_type,
     sys.last_value, sys.last_traceback, limit, file)'."""
     if not hasattr(sys, "last_type"):
         raise ValueError("no last exception")
-    if file is None:
-        file = sys.stderr
     print_exception(sys.last_type, sys.last_value, sys.last_traceback,
                     limit, file, chain)
 
+#
+# Printing and Extracting Stacks.
+#
+
+def _extract_stack_iter(f, limit=None):
+    return _extract_tb_or_stack_iter(
+                f, limit, lambda f: (f, f.f_lineno, f.f_back))
+
+def _get_stack(f):
+    if f is None:
+        f = sys._getframe().f_back.f_back
+    return f
 
 def print_stack(f=None, limit=None, file=None):
     """Print a stack trace from its invocation point.
@@ -289,21 +281,11 @@ def print_stack(f=None, limit=None, file=None):
     stack frame at which to start. The optional 'limit' and 'file'
     arguments have the same meaning as for print_exception().
     """
-    if f is None:
-        try:
-            raise ZeroDivisionError
-        except ZeroDivisionError:
-            f = sys.exc_info()[2].tb_frame.f_back
-    print_list(extract_stack(f, limit), file)
+    print_list(extract_stack(_get_stack(f), limit=limit), file=file)
 
 def format_stack(f=None, limit=None):
     """Shorthand for 'format_list(extract_stack(f, limit))'."""
-    if f is None:
-        try:
-            raise ZeroDivisionError
-        except ZeroDivisionError:
-            f = sys.exc_info()[2].tb_frame.f_back
-    return format_list(extract_stack(f, limit))
+    return format_list(extract_stack(_get_stack(f), limit=limit))
 
 def extract_stack(f=None, limit=None):
     """Extract the raw traceback from the current stack frame.
@@ -314,27 +296,6 @@ def extract_stack(f=None, limit=None):
     line number, function name, text), and the entries are in order
     from oldest to newest stack frame.
     """
-    if f is None:
-        try:
-            raise ZeroDivisionError
-        except ZeroDivisionError:
-            f = sys.exc_info()[2].tb_frame.f_back
-    if limit is None:
-        if hasattr(sys, 'tracebacklimit'):
-            limit = sys.tracebacklimit
-    list = []
-    n = 0
-    while f is not None and (limit is None or n < limit):
-        lineno = f.f_lineno
-        co = f.f_code
-        filename = co.co_filename
-        name = co.co_name
-        linecache.checkcache(filename)
-        line = linecache.getline(filename, lineno, f.f_globals)
-        if line: line = line.strip()
-        else: line = None
-        list.append((filename, lineno, name, line))
-        f = f.f_back
-        n = n+1
-    list.reverse()
-    return list
+    stack = list(_extract_stack_iter(_get_stack(f), limit=limit))
+    stack.reverse()
+    return stack