From c7e07d9318d4362bcc1cb274d4793cb48a061211 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Pinard?= <pinard@iro.umontreal.ca>
Date: Sun, 24 Feb 2008 22:06:50 -0500
Subject: [PATCH] pytest a bit more like py.test

---
 tests/ChangeLog      | 11 ++++++
 tests/Makefile.am    |  2 +-
 tests/Makefile.in    |  2 +-
 tests/common.py      | 12 ++++--
 tests/pytest         | 91 +++++++++++++++++++++++++++++---------------
 tests/t90_bigauto.py | 17 +++++----
 6 files changed, 91 insertions(+), 44 deletions(-)

diff --git a/tests/ChangeLog b/tests/ChangeLog
index f695108..029f43e 100644
--- a/tests/ChangeLog
+++ b/tests/ChangeLog
@@ -1,3 +1,14 @@
+2008-02-24  François Pinard  <pinard@iro.umontreal.ca>
+
+	* pytest (py.test): New, including the old raises, and the new skip,
+	fail and exit functions, as suggested by Holger Krekel.
+	Interpret a skip() call in a generator routine as a request to skip
+	the whole set of generated tests.
+
+	* Makefile.am (check): Do not set PATH, set RECODE instead.
+	* common.py: Look for RECODE in os.environ to find the program.
+	* t90_bigauto.py: Use the above.
+
 2008-02-23  François Pinard  <pinard@iro.umontreal.ca>
 
 	* t90_bigauto.py: New, replacing bigauto.py.
diff --git a/tests/Makefile.am b/tests/Makefile.am
index 5555828..765f01a 100644
--- a/tests/Makefile.am
+++ b/tests/Makefile.am
@@ -28,5 +28,5 @@ EXTRA_DIST = NOTES pytest common.py $(SUITE)
 PYTHON = python
 
 check-local:
-	PATH=../src:$$PATH PYTHONPATH=$(srcdir) \
+	RECODE=../src/recode PYTHONPATH=$(srcdir) \
 	  $(PYTHON) $(srcdir)/pytest $(srcdir)/t*.py
diff --git a/tests/Makefile.in b/tests/Makefile.in
index cf83f37..a5e9433 100644
--- a/tests/Makefile.in
+++ b/tests/Makefile.in
@@ -373,7 +373,7 @@ uninstall-am:
 
 
 check-local:
-	PATH=../src:$$PATH PYTHONPATH=$(srcdir) \
+	RECODE=../src/recode PYTHONPATH=$(srcdir) \
 	  $(PYTHON) $(srcdir)/pytest $(srcdir)/t*.py
 # Tell versions [3.59,3.63) of GNU make to not export all variables.
 # Otherwise a system limit (for SysV at least) may be exceeded.
diff --git a/tests/common.py b/tests/common.py
index cb44a9f..ae8bab2 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -2,7 +2,9 @@
 
 __metaclass__ = type
 import os
-from __main__ import SkipTest, raises
+from __main__ import py
+
+recode_program = os.environ.get('RECODE')
 
 try:
     import Recode
@@ -54,14 +56,16 @@ def assert_or_diff(output, expected):
         assert False, (len(output), len(expected))
 
 def external_output(command):
-    return os.popen(command.replace('$R', 'recode'), 'rb').read()
+    if not recode_program:
+        py.test.skip()
+    return os.popen(command.replace('$R', recode_program), 'rb').read()
 
 def recode_output(input):
     if run.external:
         file(run.work, 'wb').write(input)
         return external_output('$R %s < %s' % (run.request, run.work))
     if Recode is None:
-        raise SkipTest
+        py.test.skip()
     return Recode.recode(run.request, input)
 
 def recode_back_output(input):
@@ -71,7 +75,7 @@ def recode_back_output(input):
         external_output('$R %s %s' % (run.request, run.work))
         return external_output('$R %s..%s < %s' % (after, before, run.work))
     if Recode is None:
-        raise SkipTest
+        py.test.skip()
     temp = Recode.recode(run.request, input)
     return Recode.recode('%s..%s' % (after, before), temp)
 
diff --git a/tests/pytest b/tests/pytest
index 49075fd..f056c51 100755
--- a/tests/pytest
+++ b/tests/pytest
@@ -46,11 +46,7 @@ from StringIO import StringIO
 # How many displayable characters in an output line.
 WIDTH = 79
 
-class Limit_Reached(Exception):
-    pass
-
-class SkipTest(Exception):
-    pass
+class Limit_Reached(Exception): pass
 
 class Main:
     prefix = 'test_'
@@ -159,18 +155,18 @@ class Main:
                             else:
                                 text = u' ' + text
                             write(text + u'\n')
-            except KeyboardInterrupt:
+            except Exit, exception:
                 if not self.verbose:
                     write(u'\n')
-                write(u'\n*** INTERRUPTION! ***\n')
+                write(u'\n* %s *\n' % str(exception))
             except Limit_Reached:
                 if not self.verbose:
                     write(u'\n')
                 if not self.save:
                     if len(self.failures) == 1:
-                        write(u'\n*** ONE ERROR ALREADY! ***\n')
+                        write(u'\n* One error already! *\n')
                     else:
-                        write(u'\n*** %d ERRORS ALREADY! ***\n' % self.limit)
+                        write(u'\n* %d errors already! *\n' % self.limit)
         finally:
             if self.profiler is not None:
                 stats = lsprof.Stats(self.profiler.getstats())
@@ -291,7 +287,7 @@ class Main:
                     self.handle_class(prefix + u'/' + name, objet)
             else:
                 self.handle_function(prefix + u'/' + name, objet,
-                                      generator, None)
+                                     generator, None)
         if self.did_tests_in_module and hasattr(module, u'teardown_module'):
             module.teardown_module(module)
 
@@ -304,22 +300,30 @@ class Main:
                                    bool(code.co_flags & 32)))
         if not collection:
             return
+        # FIXME: Should likely do module setup here!
         instance = classe()
         if hasattr(instance, u'setup_class'):
             self.delayed_setup_module = instance.setup_class, classe
         for _, name, method, generator in sorted(collection):
             self.handle_function(prefix + u'/' + name, getattr(instance, name),
-                                  generator, instance)
+                                 generator, instance)
         if self.did_tests_in_class and hasattr(instance, u'teardown_class'):
             instance.teardown_class(classe)
 
     def handle_function(self, prefix, function, generator, instance):
+        collection = []
         if generator:
-            for counter, arguments in enumerate(function()):
-                self.launch_test(prefix + u'/' + unicode(counter + 1),
-                                 arguments[0], arguments[1:], instance)
+            # FIXME: Should likely do class setup here.
+            try:
+                for counter, arguments in enumerate(function()):
+                    collection.append((prefix + u'/' + unicode(counter + 1),
+                                       arguments[0], arguments[1:]))
+            except Skipped:
+                return
         else:
-            self.launch_test(prefix, function, (), instance)
+            collection.append((prefix, function, ()))
+        for prefix, function, arguments in collection:
+            self.launch_test(prefix, function, arguments, instance)
 
     def launch_test(self, prefix, function, arguments, instance):
         # Check if this test should be retained.
@@ -346,11 +350,16 @@ class Main:
         try:
             try:
                 function(*arguments)
-            except KeyboardInterrupt:
-                success = None
+            except Exit:
+                success = False
                 raise
-            except SkipTest:
+            except Failed:
+                success = False
+            except Skipped:
                 success = None
+            except KeyboardInterrupt:
+                success = None
+                raise Exit("Interruption!")
             except:
                 success = False
             else:
@@ -435,19 +444,41 @@ class Friendly_StreamWriter:
 run = Main()
 main = run.main
 
-class ExceptionExpected(Exception):
-    pass
+class Exit(Exception): pass
 
-def raises(expected, *args, **kws):
-    try:
-        if isinstance(args[0], unicode) and not kws:
-            eval(args[0])
-        else:
-            args[0](*args[1:], **kws)
-    except expected:
-        return
-    else:
-        raise ExceptionExpected(u"Exception did not happen.")
+class Failed(Exception): pass
+
+class NotRaised(Exception): pass
+
+class Skipped(Exception): pass
+
+class py:
+
+    class test:
+
+        @staticmethod
+        def exit(message="Unknown reason"):
+            raise Exit(message)
+
+        @staticmethod
+        def fail(message="Unknown reason"):
+            raise Failed(message)
+
+        @staticmethod
+        def skip(message="Unknown reason"):
+            raise Skipped(message)
+
+        @staticmethod
+        def raises(expected, *args, **kws):
+            try:
+                if isinstance(args[0], unicode) and not kws:
+                    eval(args[0])
+                else:
+                    args[0](*args[1:], **kws)
+            except expected:
+                return
+            else:
+                raise NotRaised(u"Exception did not happen.")
 
 if __name__ == u'__main__':
     main(*sys.argv[1:])
diff --git a/tests/t90_bigauto.py b/tests/t90_bigauto.py
index 36afebd..861cce6 100644
--- a/tests/t90_bigauto.py
+++ b/tests/t90_bigauto.py
@@ -28,24 +28,25 @@ argument, all possible possible recodings are considered.
 """
 
 import os, sys
-import common
+from common import py, Recode
+from common import setup_module, teardown_module
 
 class Test:
     avoid_as_before = 'count-characters', 'dump-with-names', 'flat'
 
     def test_1(self):
-        if common.Recode is None:
-            raise common.SkipTest
-        self.outer = common.Recode.Outer(strict=False)
+        if Recode is None:
+            py.test.skip()
+        self.outer = Recode.Outer(strict=False)
         self.charsets = sorted(self.outer.all_charsets())
         for before in self.charsets:
             if before not in self.avoid_as_before:
                 yield self.validate, before
 
     def test_2(self):
-        if common.Recode is None:
-            raise common.SkipTest
-        self.outer = common.Recode.Outer(strict=True)
+        if Recode is None:
+            py.test.skip()
+        self.outer = Recode.Outer(strict=True)
         self.charsets = sorted(self.outer.all_charsets())
         for before in self.charsets:
             if before not in self.avoid_as_before:
@@ -59,7 +60,7 @@ class Test:
         print before
         for after in self.charsets:
             if after is not before:
-                request = common.Recode.Request(self.outer)
+                request = Recode.Request(self.outer)
                 request.scan('%s..%s' % (before, after))
 
 def main(*arguments):
-- 
2.49.0