From: Michael Foord Date: Thu, 17 Mar 2011 00:34:53 +0000 (-0400) Subject: Issue #10242: backport of more fixes to unittest.TestCase.assertItemsEqual X-Git-Tag: v2.7.2rc1~244 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=4c9e91a092f55a669b83255ce18bad0c7daa0b31;p=python Issue #10242: backport of more fixes to unittest.TestCase.assertItemsEqual --- diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index cd8f4fa009..ecb6a3e419 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -10,9 +10,11 @@ import warnings from . import result from .util import ( - strclass, safe_repr, sorted_list_difference, unorderable_list_difference + strclass, safe_repr, unorderable_list_difference, + _count_diff_all_purpose, _count_diff_hashable ) + __unittest = True @@ -863,6 +865,7 @@ class TestCase(object): - [0, 1, 1] and [1, 0, 1] compare equal. - [0, 0, 1] and [0, 1] compare unequal. """ + first_seq, second_seq = list(actual_seq), list(expected_seq) with warnings.catch_warnings(): if sys.py3kwarning: # Silence Py3k warning raised during the sorting @@ -871,29 +874,23 @@ class TestCase(object): "comparing unequal types"]: warnings.filterwarnings("ignore", _msg, DeprecationWarning) try: - actual = collections.Counter(iter(actual_seq)) - expected = collections.Counter(iter(expected_seq)) + first = collections.Counter(first_seq) + second = collections.Counter(second_seq) except TypeError: - # Unsortable items (example: set(), complex(), ...) - actual = list(actual_seq) - expected = list(expected_seq) - missing, unexpected = unorderable_list_difference(expected, actual) + # Handle case with unhashable elements + differences = _count_diff_all_purpose(first_seq, second_seq) else: - if actual == expected: + if first == second: return - missing = list(expected - actual) - unexpected = list(actual - expected) + differences = _count_diff_hashable(first_seq, second_seq) - errors = [] - if missing: - errors.append('Expected, but missing:\n %s' % - safe_repr(missing)) - if unexpected: - errors.append('Unexpected, but present:\n %s' % - safe_repr(unexpected)) - if errors: - standardMsg = '\n'.join(errors) - self.fail(self._formatMessage(msg, standardMsg)) + if differences: + standardMsg = 'Element counts were not equal:\n' + lines = ['First has %d, Second has %d: %r' % diff for diff in differences] + diffMsg = '\n'.join(lines) + standardMsg = self._truncateMessage(standardMsg, diffMsg) + msg = self._formatMessage(msg, standardMsg) + self.fail(msg) def assertMultiLineEqual(self, first, second, msg=None): """Assert that two multi-line strings are equal.""" diff --git a/Lib/unittest/test/test_assertions.py b/Lib/unittest/test/test_assertions.py index e85ca915c3..e1ba614706 100644 --- a/Lib/unittest/test/test_assertions.py +++ b/Lib/unittest/test/test_assertions.py @@ -228,12 +228,6 @@ class TestLongMessage(unittest.TestCase): "^Missing: 'key'$", "^Missing: 'key' : oops$"]) - def testAssertItemsEqual(self): - self.assertMessages('assertItemsEqual', ([], [None]), - [r"\[None\]$", "^oops$", - r"\[None\]$", - r"\[None\] : oops$"]) - def testAssertMultiLineEqual(self): self.assertMessages('assertMultiLineEqual', ("", "foo"), [r"\+ foo$", "^oops$", diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py index 250e905352..06eeda1503 100644 --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -686,20 +686,19 @@ class Test_TestCase(unittest.TestCase, TestEquality, TestHashing): # Test that sequences of unhashable objects can be tested for sameness: self.assertItemsEqual([[1, 2], [3, 4], 0], [False, [3, 4], [1, 2]]) - with test_support.check_warnings(quiet=True) as w: - # hashable types, but not orderable - self.assertRaises(self.failureException, self.assertItemsEqual, - [], [divmod, 'x', 1, 5j, 2j, frozenset()]) - # comparing dicts raises a py3k warning - self.assertItemsEqual([{'a': 1}, {'b': 2}], [{'b': 2}, {'a': 1}]) - # comparing heterogenous non-hashable sequences raises a py3k warning - self.assertItemsEqual([1, 'x', divmod, []], [divmod, [], 'x', 1]) - self.assertRaises(self.failureException, self.assertItemsEqual, - [], [divmod, [], 'x', 1, 5j, 2j, set()]) - # fail the test if warnings are not silenced - if w.warnings: - self.fail('assertItemsEqual raised a warning: ' + - str(w.warnings[0])) + # Test that iterator of unhashable objects can be tested for sameness: + self.assertItemsEqual(iter([1, 2, [], 3, 4]), + iter([1, 2, [], 3, 4])) + + # hashable types, but not orderable + self.assertRaises(self.failureException, self.assertItemsEqual, + [], [divmod, 'x', 1, 5j, 2j, frozenset()]) + # comparing dicts + self.assertItemsEqual([{'a': 1}, {'b': 2}], [{'b': 2}, {'a': 1}]) + # comparing heterogenous non-hashable sequences + self.assertItemsEqual([1, 'x', divmod, []], [divmod, [], 'x', 1]) + self.assertRaises(self.failureException, self.assertItemsEqual, + [], [divmod, [], 'x', 1, 5j, 2j, set()]) self.assertRaises(self.failureException, self.assertItemsEqual, [[1]], [[2]]) @@ -717,6 +716,19 @@ class Test_TestCase(unittest.TestCase, TestEquality, TestHashing): b = a[::-1] self.assertItemsEqual(a, b) + # test utility functions supporting assertItemsEqual() + + diffs = set(unittest.util._count_diff_all_purpose('aaabccd', 'abbbcce')) + expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')} + self.assertEqual(diffs, expected) + + diffs = unittest.util._count_diff_all_purpose([[]], []) + self.assertEqual(diffs, [(1, 0, [])]) + + diffs = set(unittest.util._count_diff_hashable('aaabccd', 'abbbcce')) + expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')} + self.assertEqual(diffs, expected) + def testAssertSetEqual(self): set1 = set() set2 = set() diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py index d201657c6e..220a024e90 100644 --- a/Lib/unittest/util.py +++ b/Lib/unittest/util.py @@ -1,4 +1,6 @@ """Various utility functions.""" +from collections import namedtuple, OrderedDict + __unittest = True @@ -92,3 +94,63 @@ def unorderable_list_difference(expected, actual, ignore_duplicate=False): # anything left in actual is unexpected return missing, actual + +_Mismatch = namedtuple('Mismatch', 'actual expected value') + +def _count_diff_all_purpose(actual, expected): + 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ' + # elements need not be hashable + s, t = list(actual), list(expected) + m, n = len(s), len(t) + NULL = object() + result = [] + for i, elem in enumerate(s): + if elem is NULL: + continue + cnt_s = cnt_t = 0 + for j in range(i, m): + if s[j] == elem: + cnt_s += 1 + s[j] = NULL + for j, other_elem in enumerate(t): + if other_elem == elem: + cnt_t += 1 + t[j] = NULL + if cnt_s != cnt_t: + diff = _Mismatch(cnt_s, cnt_t, elem) + result.append(diff) + + for i, elem in enumerate(t): + if elem is NULL: + continue + cnt_t = 0 + for j in range(i, n): + if t[j] == elem: + cnt_t += 1 + t[j] = NULL + diff = _Mismatch(0, cnt_t, elem) + result.append(diff) + return result + +def _ordered_count(iterable): + 'Return dict of element counts, in the order they were first seen' + c = OrderedDict() + for elem in iterable: + c[elem] = c.get(elem, 0) + 1 + return c + +def _count_diff_hashable(actual, expected): + 'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ' + # elements must be hashable + s, t = _ordered_count(actual), _ordered_count(expected) + result = [] + for elem, cnt_s in s.items(): + cnt_t = t.get(elem, 0) + if cnt_s != cnt_t: + diff = _Mismatch(cnt_s, cnt_t, elem) + result.append(diff) + for elem, cnt_t in t.items(): + if elem not in s: + diff = _Mismatch(0, cnt_t, elem) + result.append(diff) + return result