]> granicus.if.org Git - python/commitdiff
Addition of -b command line option to unittest for buffering stdout and stderr during...
authorMichael Foord <fuzzyman@voidspace.org.uk>
Fri, 2 Apr 2010 21:42:47 +0000 (21:42 +0000)
committerMichael Foord <fuzzyman@voidspace.org.uk>
Fri, 2 Apr 2010 21:42:47 +0000 (21:42 +0000)
Lib/unittest/main.py
Lib/unittest/result.py
Lib/unittest/runner.py
Lib/unittest/test/test_break.py
Lib/unittest/test/test_result.py

index e2703ce727072d34e584804c98510feca27540df..7c1f48c2a71739bcba7b5df71f7f436bc4da4301 100644 (file)
@@ -68,12 +68,12 @@ class TestProgram(object):
     USAGE = USAGE_FROM_MODULE
 
     # defaults for testing
-    failfast = catchbreak = None
+    failfast = catchbreak = buffer = None
 
     def __init__(self, module='__main__', defaultTest=None,
                  argv=None, testRunner=None,
                  testLoader=loader.defaultTestLoader, exit=True,
-                 verbosity=1, failfast=None, catchbreak=None):
+                 verbosity=1, failfast=None, catchbreak=None, buffer=None):
         if isinstance(module, basestring):
             self.module = __import__(module)
             for part in module.split('.')[1:]:
@@ -87,6 +87,7 @@ class TestProgram(object):
         self.failfast = failfast
         self.catchbreak = catchbreak
         self.verbosity = verbosity
+        self.buffer = buffer
         self.defaultTest = defaultTest
         self.testRunner = testRunner
         self.testLoader = testLoader
@@ -111,9 +112,9 @@ class TestProgram(object):
             return
 
         import getopt
-        long_opts = ['help', 'verbose', 'quiet', 'failfast', 'catch']
+        long_opts = ['help', 'verbose', 'quiet', 'failfast', 'catch', 'buffer']
         try:
-            options, args = getopt.getopt(argv[1:], 'hHvqfc', long_opts)
+            options, args = getopt.getopt(argv[1:], 'hHvqfcb', long_opts)
             for opt, value in options:
                 if opt in ('-h','-H','--help'):
                     self.usageExit()
@@ -129,6 +130,10 @@ class TestProgram(object):
                     if self.catchbreak is None:
                         self.catchbreak = True
                     # Should this raise an exception if -c is not valid?
+                if opt in ('-b','--buffer'):
+                    if self.buffer is None:
+                        self.buffer = True
+                    # Should this raise an exception if -b is not valid?
             if len(args) == 0 and self.defaultTest is None:
                 # createTests will load tests from self.module
                 self.testNames = None
@@ -164,6 +169,10 @@ class TestProgram(object):
             parser.add_option('-c', '--catch', dest='catchbreak', default=False,
                               help='Catch ctrl-C and display results so far',
                               action='store_true')
+        if self.buffer != False:
+            parser.add_option('-b', '--buffer', dest='buffer', default=False,
+                              help='Buffer stdout and stderr during tests',
+                              action='store_true')
         parser.add_option('-s', '--start-directory', dest='start', default='.',
                           help="Directory to start discovery ('.' default)")
         parser.add_option('-p', '--pattern', dest='pattern', default='test*.py',
@@ -184,6 +193,8 @@ class TestProgram(object):
             self.failfast = options.failfast
         if self.catchbreak is None:
             self.catchbreak = options.catchbreak
+        if self.buffer is None:
+            self.buffer = options.buffer
 
         if options.verbose:
             self.verbosity = 2
@@ -203,9 +214,10 @@ class TestProgram(object):
         if isinstance(self.testRunner, (type, types.ClassType)):
             try:
                 testRunner = self.testRunner(verbosity=self.verbosity,
-                                             failfast=self.failfast)
+                                             failfast=self.failfast,
+                                             buffer=self.buffer)
             except TypeError:
-                # didn't accept the verbosity or failfast arguments
+                # didn't accept the verbosity, buffer or failfast arguments
                 testRunner = self.testRunner()
         else:
             # it is assumed to be a TestRunner instance
index ec80e8eea5871c4a6c9aa856f902ecfbba2a20db..46eba0401b929823b02887a71889b11529506a5d 100644 (file)
@@ -1,7 +1,11 @@
 """Test result object"""
 
+import os
+import sys
 import traceback
 
+from cStringIO import StringIO
+
 from . import util
 from functools import wraps
 
@@ -15,6 +19,15 @@ def failfast(method):
         return method(self, *args, **kw)
     return inner
 
+
+_std_out = sys.stdout
+_std_err = sys.stderr
+
+NEWLINE = os.linesep
+STDOUT_LINE = '%sStdout:%s%%s' % (NEWLINE, NEWLINE)
+STDERR_LINE = '%sStderr:%s%%s' % (NEWLINE, NEWLINE)
+
+
 class TestResult(object):
     """Holder for test result information.
 
@@ -37,6 +50,10 @@ class TestResult(object):
         self.expectedFailures = []
         self.unexpectedSuccesses = []
         self.shouldStop = False
+        self.buffer = False
+        self._stdout_buffer = StringIO()
+        self._stderr_buffer = StringIO()
+        self._mirrorOutput = False
 
     def printErrors(self):
         "Called by TestRunner after test run"
@@ -44,6 +61,10 @@ class TestResult(object):
     def startTest(self, test):
         "Called when the given test is about to be run"
         self.testsRun += 1
+        self._mirrorOutput = False
+        if self.buffer:
+            sys.stdout = self._stdout_buffer
+            sys.stderr = self._stderr_buffer
 
     def startTestRun(self):
         """Called once before any tests are executed.
@@ -53,6 +74,26 @@ class TestResult(object):
 
     def stopTest(self, test):
         """Called when the given test has been run"""
+        if self.buffer:
+            if self._mirrorOutput:
+                output = sys.stdout.getvalue()
+                error = sys.stderr.getvalue()
+                if output:
+                    if not output.endswith(NEWLINE):
+                        output += NEWLINE
+                    sys.__stdout__.write(STDOUT_LINE % output)
+                if error:
+                    if not error.endswith(NEWLINE):
+                        error += NEWLINE
+                    sys.__stderr__.write(STDERR_LINE % error)
+
+            sys.stdout = _std_out
+            sys.stderr = _std_err
+            self._stdout_buffer.seek(0)
+            self._stdout_buffer.truncate()
+            self._stderr_buffer.seek(0)
+            self._stderr_buffer.truncate()
+        self._mirrorOutput = False
 
     def stopTestRun(self):
         """Called once after all tests are executed.
@@ -66,12 +107,14 @@ class TestResult(object):
         returned by sys.exc_info().
         """
         self.errors.append((test, self._exc_info_to_string(err, test)))
+        self._mirrorOutput = True
 
     @failfast
     def addFailure(self, test, err):
         """Called when an error has occurred. 'err' is a tuple of values as
         returned by sys.exc_info()."""
         self.failures.append((test, self._exc_info_to_string(err, test)))
+        self._mirrorOutput = True
 
     def addSuccess(self, test):
         "Called when a test has completed successfully"
@@ -105,11 +148,27 @@ class TestResult(object):
         # Skip test runner traceback levels
         while tb and self._is_relevant_tb_level(tb):
             tb = tb.tb_next
+
         if exctype is test.failureException:
             # Skip assert*() traceback levels
             length = self._count_relevant_tb_levels(tb)
-            return ''.join(traceback.format_exception(exctype, value, tb, length))
-        return ''.join(traceback.format_exception(exctype, value, tb))
+            msgLines = traceback.format_exception(exctype, value, tb, length)
+        else:
+            msgLines = traceback.format_exception(exctype, value, tb)
+
+        if self.buffer:
+            output = sys.stdout.getvalue()
+            error = sys.stderr.getvalue()
+            if output:
+                if not output.endswith(NEWLINE):
+                    output += NEWLINE
+                msgLines.append(STDOUT_LINE % output)
+            if error:
+                if not error.endswith(NEWLINE):
+                    error += NEWLINE
+                msgLines.append(STDERR_LINE % error)
+        return ''.join(msgLines)
+
 
     def _is_relevant_tb_level(self, tb):
         return '__unittest' in tb.tb_frame.f_globals
index 5169d20ece02bb69b5d9de28b532991e9bf199cc..068aa5bf43f9a2152fe69ed8827fbab4b7486513 100644 (file)
@@ -125,11 +125,12 @@ class TextTestRunner(object):
     resultclass = TextTestResult
 
     def __init__(self, stream=sys.stderr, descriptions=True, verbosity=1,
-                 failfast=False, resultclass=None):
+                 failfast=False, buffer=False, resultclass=None):
         self.stream = _WritelnDecorator(stream)
         self.descriptions = descriptions
         self.verbosity = verbosity
         self.failfast = failfast
+        self.buffer = buffer
         if resultclass is not None:
             self.resultclass = resultclass
 
@@ -141,6 +142,7 @@ class TextTestRunner(object):
         result = self._makeResult()
         registerResult(result)
         result.failfast = self.failfast
+        result.buffer = self.buffer
         startTime = time.time()
         startTestRun = getattr(result, 'startTestRun', None)
         if startTestRun is not None:
index 0de31cd86435f6d967d80238167f05e5d525cbb7..1cd75b8d565889311d947f040c11c549cba3d044 100644 (file)
@@ -205,8 +205,9 @@ class TestBreak(unittest.TestCase):
         p = Program(False)
         p.runTests()
 
-        self.assertEqual(FakeRunner.initArgs, [((), {'verbosity': verbosity,
-                                                'failfast': failfast})])
+        self.assertEqual(FakeRunner.initArgs, [((), {'buffer': None,
+                                                     'verbosity': verbosity,
+                                                     'failfast': failfast})])
         self.assertEqual(FakeRunner.runArgs, [test])
         self.assertEqual(p.result, result)
 
@@ -217,8 +218,9 @@ class TestBreak(unittest.TestCase):
         p = Program(True)
         p.runTests()
 
-        self.assertEqual(FakeRunner.initArgs, [((), {'verbosity': verbosity,
-                                                'failfast': failfast})])
+        self.assertEqual(FakeRunner.initArgs, [((), {'buffer': None,
+                                                     'verbosity': verbosity,
+                                                     'failfast': failfast})])
         self.assertEqual(FakeRunner.runArgs, [test])
         self.assertEqual(p.result, result)
 
index c29fd9c00e44d39dc4131e3a861142ba1dbe621b..86376798eef801cf110bc10a974e1ebc9acbe91b 100644 (file)
@@ -1,5 +1,6 @@
 import sys
-from cStringIO import StringIO
+import textwrap
+from cStringIO import StringIO, OutputType
 from test import test_support
 
 import unittest
@@ -301,6 +302,8 @@ def __init__(self, stream=None, descriptions=None, verbosity=None):
     self.errors = []
     self.testsRun = 0
     self.shouldStop = False
+    self.buffer = False
+
 classDict['__init__'] = __init__
 OldResult = type('OldResult', (object,), classDict)
 
@@ -356,5 +359,130 @@ class Test_OldTestResult(unittest.TestCase):
         runner.run(Test('testFoo'))
 
 
+class TestOutputBuffering(unittest.TestCase):
+
+    def setUp(self):
+        self._real_out = sys.__stdout__
+        self._real_err = sys.__stderr__
+
+    def tearDown(self):
+        sys.stdout = sys.__stdout__ = self._real_out
+        sys.stderr = sys.__stderr__ = self._real_err
+
+    def testBufferOutputOff(self):
+        real_out = self._real_out
+        real_err = self._real_err
+
+        result = unittest.TestResult()
+        self.assertFalse(result.buffer)
+
+        self.assertIs(real_out, sys.stdout)
+        self.assertIs(real_err, sys.stderr)
+
+        result.startTest(self)
+
+        self.assertIs(real_out, sys.stdout)
+        self.assertIs(real_err, sys.stderr)
+
+    def testBufferOutputStartTestAddSuccess(self):
+        real_out = self._real_out
+        real_err = self._real_err
+
+        result = unittest.TestResult()
+        self.assertFalse(result.buffer)
+
+        result.buffer = True
+
+        self.assertIs(real_out, sys.stdout)
+        self.assertIs(real_err, sys.stderr)
+
+        result.startTest(self)
+
+        self.assertIsNot(real_out, sys.stdout)
+        self.assertIsNot(real_err, sys.stderr)
+        self.assertIsInstance(sys.stdout, OutputType)
+        self.assertIsInstance(sys.stderr, OutputType)
+        self.assertIsNot(sys.stdout, sys.stderr)
+
+        out_stream = sys.stdout
+        err_stream = sys.stderr
+
+        sys.__stdout__ = StringIO()
+        sys.__stderr__ = StringIO()
+
+        print 'foo'
+        print >> sys.stderr, 'bar'
+
+        self.assertEqual(out_stream.getvalue(), 'foo\n')
+        self.assertEqual(err_stream.getvalue(), 'bar\n')
+
+        self.assertEqual(sys.__stdout__.getvalue(), '')
+        self.assertEqual(sys.__stderr__.getvalue(), '')
+
+        result.addSuccess(self)
+        result.stopTest(self)
+
+        self.assertIs(real_out, sys.stdout)
+        self.assertIs(real_err, sys.stderr)
+
+        self.assertEqual(sys.__stdout__.getvalue(), '')
+        self.assertEqual(sys.__stderr__.getvalue(), '')
+
+        self.assertEqual(out_stream.getvalue(), '')
+        self.assertEqual(err_stream.getvalue(), '')
+
+
+    def getStartedResult(self):
+        result = unittest.TestResult()
+        result.buffer = True
+        result.startTest(self)
+        return result
+
+    def testBufferOutputAddErrorOrFailure(self):
+        def clear():
+            sys.__stdout__ = StringIO()
+            sys.__stderr__ = StringIO()
+
+        for message_attr, add_attr, include_error in [
+            ('errors', 'addError', True),
+            ('failures', 'addFailure', False),
+            ('errors', 'addError', True),
+            ('failures', 'addFailure', False)
+        ]:
+            clear()
+            result = self.getStartedResult()
+            buffered_out = sys.stdout
+            buffered_err = sys.stderr
+
+            print >> sys.stdout, 'foo'
+            if include_error:
+                print >> sys.stderr, 'bar'
+
+
+            addFunction = getattr(result, add_attr)
+            addFunction(self, (None, None, None))
+            result.stopTest(self)
+
+            result_list = getattr(result, message_attr)
+            self.assertEqual(len(result_list), 1)
+
+            test, message = result_list[0]
+            expectedOutMessage = textwrap.dedent("""
+                Stdout:
+                foo
+            """)
+            expectedErrMessage = ''
+            if include_error:
+                expectedErrMessage = textwrap.dedent("""
+                Stderr:
+                bar
+            """)
+            expectedFullMessage = 'None\n%s%s' % (expectedOutMessage, expectedErrMessage)
+
+            self.assertIs(test, self)
+            self.assertEqual(sys.__stdout__.getvalue(), expectedOutMessage)
+            self.assertEqual(sys.__stderr__.getvalue(), expectedErrMessage)
+            self.assertMultiLineEqual(message, expectedFullMessage)
+
 if __name__ == '__main__':
     unittest.main()