From 1d879f68525858fcfa7aacdd7cd16721a89c62c4 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Tue, 4 Jan 2011 20:57:19 +0000 Subject: [PATCH] Backport r87613 to make OrderedDict subclassing match dict subclassing. --- Lib/collections.py | 26 +++++++++++++++++++++----- Lib/test/test_collections.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/Lib/collections.py b/Lib/collections.py index 3329f08d62..27bb5e1268 100644 --- a/Lib/collections.py +++ b/Lib/collections.py @@ -21,7 +21,7 @@ from itertools import repeat as _repeat, chain as _chain, starmap as _starmap class _Link(object): __slots__ = 'prev', 'next', 'key', '__weakref__' -class OrderedDict(dict, MutableMapping): +class OrderedDict(dict): 'Dictionary that remembers insertion order' # An inherited dict maps keys to values. # The inherited dict provides __getitem__, __len__, __contains__, and get. @@ -50,7 +50,7 @@ class OrderedDict(dict, MutableMapping): self.__root = root = _Link() # sentinel node for the doubly linked list root.prev = root.next = root self.__map = {} - self.update(*args, **kwds) + self.__update(*args, **kwds) def clear(self): 'od.clear() -> None. Remove all items from od.' @@ -109,13 +109,29 @@ class OrderedDict(dict, MutableMapping): return (self.__class__, (items,), inst_dict) return self.__class__, (items,) - setdefault = MutableMapping.setdefault - update = MutableMapping.update - pop = MutableMapping.pop + update = __update = MutableMapping.update keys = MutableMapping.keys values = MutableMapping.values items = MutableMapping.items + __marker = object() + + def pop(self, key, default=__marker): + if key in self: + result = self[key] + del self[key] + return result + if default is self.__marker: + raise KeyError(key) + return default + + def setdefault(self, key, default=None): + 'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od' + if key in self: + return self[key] + self[key] = default + return default + def popitem(self, last=True): '''od.popitem() -> (k, v), return and remove a (key, value) pair. Pairs are returned in LIFO order if last is true or FIFO order if false. diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 988db92e16..8989ac35c8 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -792,6 +792,10 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(list(d.items()), [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) + def test_abc(self): + self.assertTrue(isinstance(OrderedDict(), MutableMapping)) + self.assertTrue(issubclass(OrderedDict, MutableMapping)) + def test_clear(self): pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) @@ -850,6 +854,17 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(len(od), 0) self.assertEqual(od.pop(k, 12345), 12345) + # make sure pop still works when __missing__ is defined + class Missing(OrderedDict): + def __missing__(self, key): + return 0 + m = Missing(a=1) + self.assertEqual(m.pop('b', 5), 5) + self.assertEqual(m.pop('a', 6), 1) + self.assertEqual(m.pop('a', 6), 6) + with self.assertRaises(KeyError): + m.pop('a') + def test_equality(self): pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) @@ -934,6 +949,12 @@ class TestOrderedDict(unittest.TestCase): # make sure 'x' is added to the end self.assertEqual(list(od.items())[-1], ('x', 10)) + # make sure setdefault still works when __missing__ is defined + class Missing(OrderedDict): + def __missing__(self, key): + return 0 + self.assertEqual(Missing().setdefault(5, 9), 9) + def test_reinsert(self): # Given insert a, insert b, delete a, re-insert a, # verify that a is now later than b. @@ -945,6 +966,13 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(list(od.items()), [('b', 2), ('a', 1)]) + def test_override_update(self): + # Verify that subclasses can override update() without breaking __init__() + class MyOD(OrderedDict): + def update(self, *args, **kwds): + raise Exception() + items = [('a', 1), ('c', 3), ('b', 2)] + self.assertEqual(list(MyOD(items).items()), items) class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): type2test = OrderedDict -- 2.40.0