]> granicus.if.org Git - python/commitdiff
Fix OrderedDic.pop() to work for subclasses that define __missing__().
authorRaymond Hettinger <python@rcn.com>
Sat, 1 Jan 2011 23:51:55 +0000 (23:51 +0000)
committerRaymond Hettinger <python@rcn.com>
Sat, 1 Jan 2011 23:51:55 +0000 (23:51 +0000)
Lib/collections.py
Lib/test/test_collections.py

index d0a44c2828338b86d0396e4649b72639b29be1b9..36ee18a2b61759fd3e63e6fe8168f0b02ad679f4 100644 (file)
@@ -22,7 +22,7 @@ from reprlib import recursive_repr as _recursive_repr
 class _Link(object):
     __slots__ = 'prev', 'next', 'key', '__weakref__'
 
-class OrderedDict(dict, MutableMapping):
+class OrderedDict(dict):
     'Dictionary that remembers insertion order'
     # An inherited dict maps keys to values.
     # The inherited dict provides __getitem__, __len__, __contains__, and get.
@@ -172,12 +172,22 @@ class OrderedDict(dict, MutableMapping):
         return size
 
     update = __update = MutableMapping.update
-    pop = MutableMapping.pop
     keys = MutableMapping.keys
     values = MutableMapping.values
     items = MutableMapping.items
     __ne__ = MutableMapping.__ne__
 
+    __marker = object()
+
+    def pop(self, key, default=__marker):
+        if key in self:
+            result = self[key]
+            del self[key]
+            return result
+        if default is self.__marker:
+            raise KeyError(key)
+        return default
+
     def setdefault(self, key, default=None):
         'OD.setdefault(k[,d]) -> OD.get(k,d), also set OD[k]=d if k not in OD'
         if key in self:
index 8c959792de8a80da642a4c57addd43de02e5202d..deda1cda3287df2951ca0c803d9b2ff6a41c76d7 100644 (file)
@@ -834,6 +834,10 @@ class TestOrderedDict(unittest.TestCase):
         self.assertEqual(list(d.items()),
             [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)])
 
+    def test_abc(self):
+        self.assertIsInstance(OrderedDict(), MutableMapping)
+        self.assertTrue(issubclass(OrderedDict, MutableMapping))
+
     def test_clear(self):
         pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
         shuffle(pairs)
@@ -892,6 +896,17 @@ class TestOrderedDict(unittest.TestCase):
         self.assertEqual(len(od), 0)
         self.assertEqual(od.pop(k, 12345), 12345)
 
+        # make sure pop still works when __missing__ is defined
+        class Missing(OrderedDict):
+            def __missing__(self, key):
+                return 0
+        m = Missing(a=1)
+        self.assertEqual(m.pop('b', 5), 5)
+        self.assertEqual(m.pop('a', 6), 1)
+        self.assertEqual(m.pop('a', 6), 6)
+        with self.assertRaises(KeyError):
+            m.pop('a')
+
     def test_equality(self):
         pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]
         shuffle(pairs)