]> granicus.if.org Git - python/commitdiff
Addition of delta keyword argument to unittest.TestCase.assertAlmostEquals and assert...
authorMichael Foord <fuzzyman@voidspace.org.uk>
Sat, 27 Mar 2010 19:10:11 +0000 (19:10 +0000)
committerMichael Foord <fuzzyman@voidspace.org.uk>
Sat, 27 Mar 2010 19:10:11 +0000 (19:10 +0000)
This allows the comparison of objects by specifying a maximum difference; this includes the comparing of non-numeric objects that don't support rounding.

Lib/unittest/case.py
Lib/unittest/test/test_assertions.py

index 94956896f3e8e73f3aab15daf6d79cfa77a9f600..98918038187b508bd517478c3ac9458cdd714d5c 100644 (file)
@@ -490,10 +490,12 @@ class TestCase(object):
                                                           safe_repr(second)))
             raise self.failureException(msg)
 
-    def assertAlmostEqual(self, first, second, places=7, msg=None):
+
+    def assertAlmostEqual(self, first, second, places=None, msg=None, delta=None):
         """Fail if the two objects are unequal as determined by their
            difference rounded to the given number of decimal places
-           (default 7) and comparing to zero.
+           (default 7) and comparing to zero, or by comparing that the
+           between the two objects is more than the given delta.
 
            Note that decimal places (from zero) are usually not the same
            as significant digits (measured from the most signficant digit).
@@ -502,31 +504,61 @@ class TestCase(object):
            compare almost equal.
         """
         if first == second:
-            # shortcut for inf
+            # shortcut
             return
-        if round(abs(second-first), places) != 0:
+        if delta is not None and places is not None:
+            raise TypeError("specify delta or places not both")
+
+        if delta is not None:
+            if abs(first - second) <= delta:
+                return
+
+            standardMsg = '%s != %s within %s delta' % (safe_repr(first),
+                                                        safe_repr(second),
+                                                        safe_repr(delta))
+        else:
+            if places is None:
+                places = 7
+
+            if round(abs(second-first), places) == 0:
+                return
+
             standardMsg = '%s != %s within %r places' % (safe_repr(first),
                                                           safe_repr(second),
                                                           places)
-            msg = self._formatMessage(msg, standardMsg)
-            raise self.failureException(msg)
+        msg = self._formatMessage(msg, standardMsg)
+        raise self.failureException(msg)
 
-    def assertNotAlmostEqual(self, first, second, places=7, msg=None):
+    def assertNotAlmostEqual(self, first, second, places=None, msg=None, delta=None):
         """Fail if the two objects are equal as determined by their
            difference rounded to the given number of decimal places
-           (default 7) and comparing to zero.
+           (default 7) and comparing to zero, or by comparing that the
+           between the two objects is less than the given delta.
 
            Note that decimal places (from zero) are usually not the same
            as significant digits (measured from the most signficant digit).
 
            Objects that are equal automatically fail.
         """
-        if (first == second) or round(abs(second-first), places) == 0:
+        if delta is not None and places is not None:
+            raise TypeError("specify delta or places not both")
+        if delta is not None:
+            if not (first == second) and abs(first - second) > delta:
+                return
+            standardMsg = '%s == %s within %s delta' % (safe_repr(first),
+                                                        safe_repr(second),
+                                                        safe_repr(delta))
+        else:
+            if places is None:
+                places = 7
+            if not (first == second) and round(abs(second-first), places) != 0:
+                return
             standardMsg = '%s == %s within %r places' % (safe_repr(first),
-                                                          safe_repr(second),
-                                                          places)
-            msg = self._formatMessage(msg, standardMsg)
-            raise self.failureException(msg)
+                                                         safe_repr(second),
+                                                         places)
+
+        msg = self._formatMessage(msg, standardMsg)
+        raise self.failureException(msg)
 
     # Synonyms for assertion methods
 
index 4b4e03ae84db82fdcd760b026f4e73632b443e57..5d77ce8fa459b1d6a5116a2c8ec46308da2b2459 100644 (file)
@@ -1,3 +1,5 @@
+import datetime
+
 import unittest
 
 
@@ -25,6 +27,28 @@ class Test_Assertions(unittest.TestCase):
         self.assertRaises(self.failureException, self.assertNotAlmostEqual,
                           float('inf'), float('inf'))
 
+    def test_AmostEqualWithDelta(self):
+        self.assertAlmostEqual(1.1, 1.0, delta=0.5)
+        self.assertAlmostEqual(1.0, 1.1, delta=0.5)
+        self.assertNotAlmostEqual(1.1, 1.0, delta=0.05)
+        self.assertNotAlmostEqual(1.0, 1.1, delta=0.05)
+
+        self.assertRaises(self.failureException, self.assertAlmostEqual,
+                          1.1, 1.0, delta=0.05)
+        self.assertRaises(self.failureException, self.assertNotAlmostEqual,
+                          1.1, 1.0, delta=0.5)
+
+        self.assertRaises(TypeError, self.assertAlmostEqual,
+                          1.1, 1.0, places=2, delta=2)
+        self.assertRaises(TypeError, self.assertNotAlmostEqual,
+                          1.1, 1.0, places=2, delta=2)
+
+        first = datetime.datetime.now()
+        second = first + datetime.timedelta(seconds=10)
+        self.assertAlmostEqual(first, second,
+                               delta=datetime.timedelta(seconds=20))
+        self.assertNotAlmostEqual(first, second,
+                                  delta=datetime.timedelta(seconds=5))
 
     def test_assertRaises(self):
         def _raise(e):