]> granicus.if.org Git - python/commitdiff
- Patch 1433928:
authorGuido van Rossum <guido@python.org>
Sat, 25 Feb 2006 22:38:04 +0000 (22:38 +0000)
committerGuido van Rossum <guido@python.org>
Sat, 25 Feb 2006 22:38:04 +0000 (22:38 +0000)
  - The copy module now "copies" function objects (as atomic objects).
  - dict.__getitem__ now looks for a __missing__ hook before raising
    KeyError.
  - Added a new type, defaultdict, to the collections module.
    This uses the new __missing__ hook behavior added to dict (see above).

12 files changed:
Doc/lib/libcollections.tex
Doc/lib/libcopy.tex
Doc/lib/libstdtypes.tex
Lib/UserDict.py
Lib/copy.py
Lib/test/test_copy.py
Lib/test/test_defaultdict.py [new file with mode: 0644]
Lib/test/test_dict.py
Lib/test/test_userdict.py
Misc/NEWS
Modules/collectionsmodule.c
Objects/dictobject.c

index 51603aa23373cf622ce4e5aaee89d033a2acb125..542ef6b0e96c904f4c634f80083af90892c059ae 100644 (file)
@@ -8,9 +8,10 @@
 \versionadded{2.4}
 
 
-This module implements high-performance container datatypes.  Currently, the
-only datatype is a deque.  Future additions may include B-trees
-and Fibonacci heaps.
+This module implements high-performance container datatypes.  Currently,
+there are two datatypes, deque and defaultdict.
+Future additions may include B-trees and Fibonacci heaps.
+\versionchanged[Added defaultdict]{2.5}
 
 \begin{funcdesc}{deque}{\optional{iterable}}
   Returns a new deque objected initialized left-to-right (using
@@ -211,3 +212,46 @@ def maketree(iterable):
 [[[['a', 'b'], ['c', 'd']], [['e', 'f'], ['g', 'h']]]]
 
 \end{verbatim}
+
+
+
+\begin{funcdesc}{defaultdict}{\optional{default_factory\optional{, ...}}}
+  Returns a new dictionary-like object.  \class{defaultdict} is a subclass
+  of the builtin \class{dict} class.  It overrides one method and adds one
+  writable instance variable.  The remaining functionality is the same as
+  for the \class{dict} class and is not documented here.
+
+  The first argument provides the initial value for the
+  \member{default_factory} attribute; it defaults to \code{None}.
+  All remaining arguments are treated the same as if they were
+  passed to the \class{dict} constructor, including keyword arguments.
+
+ \versionadded{2.5}
+\end{funcdesc}
+
+\class{defaultdict} objects support the following method in addition to
+the standard \class{dict} operations:
+
+\begin{methoddesc}{__missing__}{key}
+  If the \member{default_factory} attribute is \code{None}, this raises
+  an \exception{KeyError} exception with the \var{key} as argument.
+
+  If \member{default_factory} is not \code{None}, it is called without
+  arguments to provide a default value for the given \var{key}, this
+  value is inserted in the dictionary for the \var{key}, and returned.
+
+  If calling \member{default_factory} raises an exception this exception
+  is propagated unchanged.
+
+  This method is called by the \method{__getitem__} method of the
+  \class{dict} class when the requested key is not found; whatever it
+  returns or raises is then returned or raised by \method{__getitem__}.
+\end{methoddesc}
+
+\class{defaultdict} objects support the following instance variable:
+
+\begin{datadesc}{default_factory}
+  This attribute is used by the \method{__missing__} method; it is initialized
+  from the first argument to the constructor, if present, or to \code{None}, 
+  if absent.
+\end{datadesc}
index d73d6fdecf9c6473175c2cc58b6a3e17bd8f7f3d..59641879d99b20640faf58e046c474d8accdaac0 100644 (file)
@@ -67,9 +67,12 @@ set of components copied.
 
 \end{itemize}
 
-This version does not copy types like module, class, function, method,
+This module does not copy types like module, method,
 stack trace, stack frame, file, socket, window, array, or any similar
-types.
+types.  It does ``copy'' functions and classes (shallow and deeply),
+by returning the original object unchanged; this is compatible with
+the way these are treated by the \module{pickle} module.
+\versionchanged[Added copying functions]{2.5}
 
 Classes can use the same interfaces to control copying that they use
 to control pickling.  See the description of module
index 5df39db2911032f9ec6bc77d4a0c012b4c098f48..5d153758b92d45607d2f9be780b4d46dd99d978f 100644 (file)
@@ -1350,7 +1350,7 @@ arbitrary objects):
 
 \begin{tableiii}{c|l|c}{code}{Operation}{Result}{Notes}
   \lineiii{len(\var{a})}{the number of items in \var{a}}{}
-  \lineiii{\var{a}[\var{k}]}{the item of \var{a} with key \var{k}}{(1)}
+  \lineiii{\var{a}[\var{k}]}{the item of \var{a} with key \var{k}}{(1), (10)}
   \lineiii{\var{a}[\var{k}] = \var{v}}
           {set \code{\var{a}[\var{k}]} to \var{v}}
           {}
@@ -1454,6 +1454,17 @@ then is updated with those key/value pairs:
 \versionchanged[Allowed the argument to be an iterable of key/value
                 pairs and allowed keyword arguments]{2.4}
 
+\item[(10)] If a subclass of dict defines a method \method{__missing__},
+if the key \var{k} is not present, the \var{a}[\var{k}] operation calls
+that method with the key \var{k} as argument.  The \var{a}[\var{k}]
+operation then returns or raises whatever is returned or raised by the
+\function{__missing__}(\var{k}) call if the key is not present.
+No other operations or methods invoke \method{__missing__}().
+If \method{__missing__} is not defined, \exception{KeyError} is raised.
+\method{__missing__} must be a method; it cannot be an instance variable.
+For an example, see \module{collections}.\class{defaultdict}.
+\versionadded{2.5}
+
 \end{description}
 
 \subsection{File Objects
index 71687038743d303acfc6dbf925045b2893c7a9dc..5e97817f061ec1bdef90a3709b0f06e182d681a8 100644 (file)
@@ -14,7 +14,12 @@ class UserDict:
         else:
             return cmp(self.data, dict)
     def __len__(self): return len(self.data)
-    def __getitem__(self, key): return self.data[key]
+    def __getitem__(self, key):
+        if key in self.data:
+            return self.data[key]
+        if hasattr(self.__class__, "__missing__"):
+            return self.__class__.__missing__(self, key)
+        raise KeyError(key)
     def __setitem__(self, key, item): self.data[key] = item
     def __delitem__(self, key): del self.data[key]
     def clear(self): self.data.clear()
index b3419ca9ed45088060d4bb07fc50f38426afa92e..9e60144cc153bf1fe0ef88c798aed955ef4e0419 100644 (file)
@@ -101,7 +101,8 @@ def _copy_immutable(x):
     return x
 for t in (type(None), int, long, float, bool, str, tuple,
           frozenset, type, xrange, types.ClassType,
-          types.BuiltinFunctionType):
+          types.BuiltinFunctionType,
+         types.FunctionType):
     d[t] = _copy_immutable
 for name in ("ComplexType", "UnicodeType", "CodeType"):
     t = getattr(types, name, None)
@@ -217,6 +218,7 @@ d[type] = _deepcopy_atomic
 d[xrange] = _deepcopy_atomic
 d[types.ClassType] = _deepcopy_atomic
 d[types.BuiltinFunctionType] = _deepcopy_atomic
+d[types.FunctionType] = _deepcopy_atomic
 
 def _deepcopy_list(x, memo):
     y = []
index bd5a3e10b99b51f351c4b36e0d77397ec4276adc..ff4c987eba21b65c5659a3adcf1819740ced0cea 100644 (file)
@@ -568,6 +568,22 @@ class TestCopy(unittest.TestCase):
                 raise ValueError, "ain't got no stickin' state"
         self.assertRaises(ValueError, copy.copy, EvilState())
 
+    def test_copy_function(self):
+        self.assertEqual(copy.copy(global_foo), global_foo)
+        def foo(x, y): return x+y
+        self.assertEqual(copy.copy(foo), foo)
+        bar = lambda: None
+        self.assertEqual(copy.copy(bar), bar)
+
+    def test_deepcopy_function(self):
+        self.assertEqual(copy.deepcopy(global_foo), global_foo)
+        def foo(x, y): return x+y
+        self.assertEqual(copy.deepcopy(foo), foo)
+        bar = lambda: None
+        self.assertEqual(copy.deepcopy(bar), bar)
+
+def global_foo(x, y): return x+y
+
 def test_main():
     test_support.run_unittest(TestCopy)
 
diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py
new file mode 100644 (file)
index 0000000..b5a6628
--- /dev/null
@@ -0,0 +1,135 @@
+"""Unit tests for collections.defaultdict."""
+
+import os
+import copy
+import tempfile
+import unittest
+
+from collections import defaultdict
+
+def foobar():
+    return list
+
+class TestDefaultDict(unittest.TestCase):
+
+    def test_basic(self):
+        d1 = defaultdict()
+        self.assertEqual(d1.default_factory, None)
+        d1.default_factory = list
+        d1[12].append(42)
+        self.assertEqual(d1, {12: [42]})
+        d1[12].append(24)
+        self.assertEqual(d1, {12: [42, 24]})
+        d1[13]
+        d1[14]
+        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
+        self.assert_(d1[12] is not d1[13] is not d1[14])
+        d2 = defaultdict(list, foo=1, bar=2)
+        self.assertEqual(d2.default_factory, list)
+        self.assertEqual(d2, {"foo": 1, "bar": 2})
+        self.assertEqual(d2["foo"], 1)
+        self.assertEqual(d2["bar"], 2)
+        self.assertEqual(d2[42], [])
+        self.assert_("foo" in d2)
+        self.assert_("foo" in d2.keys())
+        self.assert_("bar" in d2)
+        self.assert_("bar" in d2.keys())
+        self.assert_(42 in d2)
+        self.assert_(42 in d2.keys())
+        self.assert_(12 not in d2)
+        self.assert_(12 not in d2.keys())
+        d2.default_factory = None
+        self.assertEqual(d2.default_factory, None)
+        try:
+            d2[15]
+        except KeyError, err:
+            self.assertEqual(err.args, (15,))
+        else:
+            self.fail("d2[15] didn't raise KeyError")
+
+    def test_missing(self):
+        d1 = defaultdict()
+        self.assertRaises(KeyError, d1.__missing__, 42)
+        d1.default_factory = list
+        self.assertEqual(d1.__missing__(42), [])
+
+    def test_repr(self):
+        d1 = defaultdict()
+        self.assertEqual(d1.default_factory, None)
+        self.assertEqual(repr(d1), "defaultdict(None, {})")
+        d1[11] = 41
+        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
+        d2 = defaultdict(0)
+        self.assertEqual(d2.default_factory, 0)
+        d2[12] = 42
+        self.assertEqual(repr(d2), "defaultdict(0, {12: 42})")
+        def foo(): return 43
+        d3 = defaultdict(foo)
+        self.assert_(d3.default_factory is foo)
+        d3[13]
+        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
+
+    def test_print(self):
+        d1 = defaultdict()
+        def foo(): return 42
+        d2 = defaultdict(foo, {1: 2})
+        # NOTE: We can't use tempfile.[Named]TemporaryFile since this
+        # code must exercise the tp_print C code, which only gets
+        # invoked for *real* files.
+        tfn = tempfile.mktemp()
+        try:
+            f = open(tfn, "w+")
+            try:
+                print >>f, d1
+                print >>f, d2
+                f.seek(0)
+                self.assertEqual(f.readline(), repr(d1) + "\n")
+                self.assertEqual(f.readline(), repr(d2) + "\n")
+            finally:
+                f.close()
+        finally:
+            os.remove(tfn)
+
+    def test_copy(self):
+        d1 = defaultdict()
+        d2 = d1.copy()
+        self.assertEqual(type(d2), defaultdict)
+        self.assertEqual(d2.default_factory, None)
+        self.assertEqual(d2, {})
+        d1.default_factory = list
+        d3 = d1.copy()
+        self.assertEqual(type(d3), defaultdict)
+        self.assertEqual(d3.default_factory, list)
+        self.assertEqual(d3, {})
+        d1[42]
+        d4 = d1.copy()
+        self.assertEqual(type(d4), defaultdict)
+        self.assertEqual(d4.default_factory, list)
+        self.assertEqual(d4, {42: []})
+        d4[12]
+        self.assertEqual(d4, {42: [], 12: []})
+
+    def test_shallow_copy(self):
+        d1 = defaultdict(foobar, {1: 1})
+        d2 = copy.copy(d1)
+        self.assertEqual(d2.default_factory, foobar)
+        self.assertEqual(d2, d1)
+        d1.default_factory = list
+        d2 = copy.copy(d1)
+        self.assertEqual(d2.default_factory, list)
+        self.assertEqual(d2, d1)
+
+    def test_deep_copy(self):
+        d1 = defaultdict(foobar, {1: [1]})
+        d2 = copy.deepcopy(d1)
+        self.assertEqual(d2.default_factory, foobar)
+        self.assertEqual(d2, d1)
+        self.assert_(d1[1] is not d2[1])
+        d1.default_factory = list
+        d2 = copy.deepcopy(d1)
+        self.assertEqual(d2.default_factory, list)
+        self.assertEqual(d2, d1)
+
+
+if __name__ == "__main__":
+    unittest.main()
index e13829caec9dfdbe5db63293665bb36dcade294c..f3f78e79dc3705f54efa24bcb55d3c6c5d2b961f 100644 (file)
@@ -395,6 +395,56 @@ class DictTest(unittest.TestCase):
         else:
             self.fail("< didn't raise Exc")
 
+    def test_missing(self):
+        # Make sure dict doesn't have a __missing__ method
+        self.assertEqual(hasattr(dict, "__missing__"), False)
+        self.assertEqual(hasattr({}, "__missing__"), False)
+        # Test several cases:
+        # (D) subclass defines __missing__ method returning a value
+        # (E) subclass defines __missing__ method raising RuntimeError
+        # (F) subclass sets __missing__ instance variable (no effect)
+        # (G) subclass doesn't define __missing__ at a all
+        class D(dict):
+            def __missing__(self, key):
+                return 42
+        d = D({1: 2, 3: 4})
+        self.assertEqual(d[1], 2)
+        self.assertEqual(d[3], 4)
+        self.assert_(2 not in d)
+        self.assert_(2 not in d.keys())
+        self.assertEqual(d[2], 42)
+        class E(dict):
+            def __missing__(self, key):
+                raise RuntimeError(key)
+        e = E()
+        try:
+            e[42]
+        except RuntimeError, err:
+            self.assertEqual(err.args, (42,))
+        else:
+            self.fail_("e[42] didn't raise RuntimeError")
+        class F(dict):
+            def __init__(self):
+                # An instance variable __missing__ should have no effect
+                self.__missing__ = lambda key: None
+        f = F()
+        try:
+            f[42]
+        except KeyError, err:
+            self.assertEqual(err.args, (42,))
+        else:
+            self.fail_("f[42] didn't raise KeyError")
+        class G(dict):
+            pass
+        g = G()
+        try:
+            g[42]
+        except KeyError, err:
+            self.assertEqual(err.args, (42,))
+        else:
+            self.fail_("g[42] didn't raise KeyError")
+
+
 import mapping_tests
 
 class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
index 2d5fa0304f1c2969e9fcc49dca6bc02062e0a649..a4b7de406bff2b09d2ecfd4a4860dafd9a4c3a18 100644 (file)
@@ -148,6 +148,55 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
         self.assertEqual(t.popitem(), ("x", 42))
         self.assertRaises(KeyError, t.popitem)
 
+    def test_missing(self):
+        # Make sure UserDict doesn't have a __missing__ method
+        self.assertEqual(hasattr(UserDict, "__missing__"), False)
+        # Test several cases:
+        # (D) subclass defines __missing__ method returning a value
+        # (E) subclass defines __missing__ method raising RuntimeError
+        # (F) subclass sets __missing__ instance variable (no effect)
+        # (G) subclass doesn't define __missing__ at a all
+        class D(UserDict.UserDict):
+            def __missing__(self, key):
+                return 42
+        d = D({1: 2, 3: 4})
+        self.assertEqual(d[1], 2)
+        self.assertEqual(d[3], 4)
+        self.assert_(2 not in d)
+        self.assert_(2 not in d.keys())
+        self.assertEqual(d[2], 42)
+        class E(UserDict.UserDict):
+            def __missing__(self, key):
+                raise RuntimeError(key)
+        e = E()
+        try:
+            e[42]
+        except RuntimeError, err:
+            self.assertEqual(err.args, (42,))
+        else:
+            self.fail_("e[42] didn't raise RuntimeError")
+        class F(UserDict.UserDict):
+            def __init__(self):
+                # An instance variable __missing__ should have no effect
+                self.__missing__ = lambda key: None
+                UserDict.UserDict.__init__(self)
+        f = F()
+        try:
+            f[42]
+        except KeyError, err:
+            self.assertEqual(err.args, (42,))
+        else:
+            self.fail_("f[42] didn't raise KeyError")
+        class G(UserDict.UserDict):
+            pass
+        g = G()
+        try:
+            g[42]
+        except KeyError, err:
+            self.assertEqual(err.args, (42,))
+        else:
+            self.fail_("g[42] didn't raise KeyError")
+
 ##########################
 # Test Dict Mixin
 
index 8429b7be742df85d48fc34d4cd67e69dc56e9678..e5c5e40e55075601f8233119b7d0dc762e33397f 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,11 @@ What's New in Python 2.5 alpha 1?
 Core and builtins
 -----------------
 
+- Patch 1433928:
+  - The copy module now "copies" function objects (as atomic objects).
+  - dict.__getitem__ now looks for a __missing__ hook before raising
+    KeyError.
+
 - Fix the encodings package codec search function to only search
   inside its own package. Fixes problem reported in patch #1433198.
 
@@ -224,6 +229,9 @@ Core and builtins
 Extension Modules
 -----------------
 
+- Patch 1433928: Added a new type, defaultdict, to the collections module.
+  This uses the new __missing__ hook behavior added to dict (see above).
+
 - Bug #854823: socketmodule now builds on Sun platforms even when
   INET_ADDRSTRLEN is not defined.
 
index 7368d80d853716f76c4b316eabbba346a64ed02d..b80ab07327ff4838771db238266356cfbb00b817 100644 (file)
@@ -1065,10 +1065,269 @@ PyTypeObject dequereviter_type = {
        0,
 };
 
+/* defaultdict type *********************************************************/
+
+typedef struct {
+       PyDictObject dict;
+       PyObject *default_factory;
+} defdictobject;
+
+static PyTypeObject defdict_type; /* Forward */
+
+PyDoc_STRVAR(defdict_missing_doc,
+"__missing__(key) # Called by __getitem__ for missing key; pseudo-code:\n\
+  if self.default_factory is None: raise KeyError(key)\n\
+  self[key] = value = self.default_factory()\n\
+  return value\n\
+");
+
+static PyObject *
+defdict_missing(defdictobject *dd, PyObject *key)
+{
+       PyObject *factory = dd->default_factory;
+       PyObject *value;
+       if (factory == NULL || factory == Py_None) {
+               /* XXX Call dict.__missing__(key) */
+               PyErr_SetObject(PyExc_KeyError, key);
+               return NULL;
+       }
+       value = PyEval_CallObject(factory, NULL);
+       if (value == NULL)
+               return value;
+       if (PyObject_SetItem((PyObject *)dd, key, value) < 0) {
+               Py_DECREF(value);
+               return NULL;
+       }
+       return value;
+}
+
+PyDoc_STRVAR(defdict_copy_doc, "D.copy() -> a shallow copy of D.");
+
+static PyObject *
+defdict_copy(defdictobject *dd)
+{
+       /* This calls the object's class.  That only works for subclasses
+          whose class constructor has the same signature.  Subclasses that
+          define a different constructor signature must override copy().
+       */
+       return PyObject_CallFunctionObjArgs((PyObject *)dd->dict.ob_type,
+                                           dd->default_factory, dd, NULL);
+}
+
+static PyObject *
+defdict_reduce(defdictobject *dd)
+{
+       /* __reduce__ must returns a 5-tuple as follows:
+
+          - factory function
+          - tuple of args for the factory function
+          - additional state (here None)
+          - sequence iterator (here None)
+          - dictionary iterator (yielding successive (key, value) pairs
+
+          This API is used by pickle.py and copy.py.
+
+          For this to be useful with pickle.py, the default_factory
+          must be picklable; e.g., None, a built-in, or a global
+          function in a module or package.
+
+          Both shallow and deep copying are supported, but for deep
+          copying, the default_factory must be deep-copyable; e.g. None,
+          or a built-in (functions are not copyable at this time).
+
+          This only works for subclasses as long as their constructor
+          signature is compatible; the first argument must be the
+          optional default_factory, defaulting to None.
+       */
+       PyObject *args;
+       PyObject *items;
+       PyObject *result;
+       if (dd->default_factory == NULL || dd->default_factory == Py_None)
+               args = PyTuple_New(0);
+       else
+               args = PyTuple_Pack(1, dd->default_factory);
+       if (args == NULL)
+               return NULL;
+       items = PyObject_CallMethod((PyObject *)dd, "iteritems", "()");
+       if (items == NULL) {
+               Py_DECREF(args);
+               return NULL;
+       }
+       result = PyTuple_Pack(5, dd->dict.ob_type, args,
+                             Py_None, Py_None, items);
+       Py_DECREF(args);
+       return result;
+}
+
+static PyMethodDef defdict_methods[] = {
+       {"__missing__", (PyCFunction)defdict_missing, METH_O,
+        defdict_missing_doc},
+       {"copy", (PyCFunction)defdict_copy, METH_NOARGS,
+        defdict_copy_doc},
+       {"__copy__", (PyCFunction)defdict_copy, METH_NOARGS,
+        defdict_copy_doc},
+       {"__reduce__", (PyCFunction)defdict_reduce, METH_NOARGS,
+        reduce_doc},
+       {NULL}
+};
+
+static PyMemberDef defdict_members[] = {
+       {"default_factory", T_OBJECT,
+        offsetof(defdictobject, default_factory), 0,
+        PyDoc_STR("Factory for default value called by __missing__().")},
+       {NULL}
+};
+
+static void
+defdict_dealloc(defdictobject *dd)
+{
+       Py_CLEAR(dd->default_factory);
+       PyDict_Type.tp_dealloc((PyObject *)dd);
+}
+
+static int
+defdict_print(defdictobject *dd, FILE *fp, int flags)
+{
+       int sts;
+       fprintf(fp, "defaultdict(");
+       if (dd->default_factory == NULL)
+               fprintf(fp, "None");
+       else {
+               PyObject_Print(dd->default_factory, fp, 0);
+       }
+       fprintf(fp, ", ");
+       sts = PyDict_Type.tp_print((PyObject *)dd, fp, 0);
+       fprintf(fp, ")");
+       return sts;
+}
+
+static PyObject *
+defdict_repr(defdictobject *dd)
+{
+       PyObject *defrepr;
+       PyObject *baserepr;
+       PyObject *result;
+       baserepr = PyDict_Type.tp_repr((PyObject *)dd);
+       if (baserepr == NULL)
+               return NULL;
+       if (dd->default_factory == NULL)
+               defrepr = PyString_FromString("None");
+       else
+               defrepr = PyObject_Repr(dd->default_factory);
+       if (defrepr == NULL) {
+               Py_DECREF(baserepr);
+               return NULL;
+       }
+       result = PyString_FromFormat("defaultdict(%s, %s)",
+                                    PyString_AS_STRING(defrepr),
+                                    PyString_AS_STRING(baserepr));
+       Py_DECREF(defrepr);
+       Py_DECREF(baserepr);
+       return result;
+}
+
+static int
+defdict_traverse(PyObject *self, visitproc visit, void *arg)
+{
+       Py_VISIT(((defdictobject *)self)->default_factory);
+       return PyDict_Type.tp_traverse(self, visit, arg);
+}
+
+static int
+defdict_tp_clear(defdictobject *dd)
+{
+       if (dd->default_factory != NULL) {
+               Py_DECREF(dd->default_factory);
+               dd->default_factory = NULL;
+       }
+       return PyDict_Type.tp_clear((PyObject *)dd);
+}
+
+static int
+defdict_init(PyObject *self, PyObject *args, PyObject *kwds)
+{
+       defdictobject *dd = (defdictobject *)self;
+       PyObject *olddefault = dd->default_factory;
+       PyObject *newdefault = NULL;
+       PyObject *newargs;
+       int result;
+       if (args == NULL || !PyTuple_Check(args))
+               newargs = PyTuple_New(0);
+       else {
+               Py_ssize_t n = PyTuple_GET_SIZE(args);
+               if (n > 0)
+                       newdefault = PyTuple_GET_ITEM(args, 0);
+               newargs = PySequence_GetSlice(args, 1, n);
+       }
+       if (newargs == NULL)
+               return -1;
+       Py_XINCREF(newdefault);
+       dd->default_factory = newdefault;
+       result = PyDict_Type.tp_init(self, newargs, kwds);
+       Py_DECREF(newargs);
+       Py_XDECREF(olddefault);
+       return result;
+}
+
+PyDoc_STRVAR(defdict_doc,
+"defaultdict(default_factory) --> dict with default factory\n\
+\n\
+The default factory is called without arguments to produce\n\
+a new value when a key is not present, in __getitem__ only.\n\
+A defaultdict compares equal to a dict with the same items.\n\
+");
+
+static PyTypeObject defdict_type = {
+       PyObject_HEAD_INIT(NULL)
+       0,                              /* ob_size */
+       "collections.defaultdict",      /* tp_name */
+       sizeof(defdictobject),          /* tp_basicsize */
+       0,                              /* tp_itemsize */
+       /* methods */
+       (destructor)defdict_dealloc,    /* tp_dealloc */
+       (printfunc)defdict_print,       /* tp_print */
+       0,                              /* tp_getattr */
+       0,                              /* tp_setattr */
+       0,                              /* tp_compare */
+       (reprfunc)defdict_repr,         /* tp_repr */
+       0,                              /* tp_as_number */
+       0,                              /* tp_as_sequence */
+       0,                              /* tp_as_mapping */
+       0,                              /* tp_hash */
+       0,                              /* tp_call */
+       0,                              /* tp_str */
+       PyObject_GenericGetAttr,        /* tp_getattro */
+       0,                              /* tp_setattro */
+       0,                              /* tp_as_buffer */
+       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC |
+               Py_TPFLAGS_HAVE_WEAKREFS,       /* tp_flags */
+       defdict_doc,                    /* tp_doc */
+       (traverseproc)defdict_traverse, /* tp_traverse */
+       (inquiry)defdict_tp_clear,      /* tp_clear */
+       0,                              /* tp_richcompare */
+       0,                              /* tp_weaklistoffset*/
+       0,                              /* tp_iter */
+       0,                              /* tp_iternext */
+       defdict_methods,                /* tp_methods */
+       defdict_members,                /* tp_members */
+       0,                              /* tp_getset */
+       &PyDict_Type,                   /* tp_base */
+       0,                              /* tp_dict */
+       0,                              /* tp_descr_get */
+       0,                              /* tp_descr_set */
+       0,                              /* tp_dictoffset */
+       (initproc)defdict_init,         /* tp_init */
+       PyType_GenericAlloc,            /* tp_alloc */
+       0,                              /* tp_new */
+       PyObject_GC_Del,                /* tp_free */
+};
+
 /* module level code ********************************************************/
 
 PyDoc_STRVAR(module_doc,
-"High performance data structures\n\
+"High performance data structures.\n\
+- deque:        ordered collection accessible from endpoints only\n\
+- defaultdict:  dict subclass with a default value factory\n\
 ");
 
 PyMODINIT_FUNC
@@ -1085,6 +1344,11 @@ initcollections(void)
        Py_INCREF(&deque_type);
        PyModule_AddObject(m, "deque", (PyObject *)&deque_type);
 
+       if (PyType_Ready(&defdict_type) < 0)
+               return;
+       Py_INCREF(&defdict_type);
+       PyModule_AddObject(m, "defaultdict", (PyObject *)&defdict_type);
+
        if (PyType_Ready(&dequeiter_type) < 0)
                return;
 
index f5e532028629bbd191cc3d130be9a57490fc3e08..2254762a396c79767756a1e6b81bc98ad9458d13 100644 (file)
@@ -882,8 +882,22 @@ dict_subscript(dictobject *mp, register PyObject *key)
                        return NULL;
        }
        v = (mp->ma_lookup)(mp, key, hash) -> me_value;
-       if (v == NULL)
+       if (v == NULL) {
+               if (!PyDict_CheckExact(mp)) {
+                       /* Look up __missing__ method if we're a subclass. */
+                       static PyObject *missing_str = NULL;
+                       if (missing_str == NULL)
+                               missing_str = 
+                                 PyString_InternFromString("__missing__");
+                       PyObject *missing = _PyType_Lookup(mp->ob_type,
+                                                          missing_str);
+                       if (missing != NULL)
+                               return PyObject_CallFunctionObjArgs(missing,
+                                       (PyObject *)mp, key, NULL);
+               }
                PyErr_SetObject(PyExc_KeyError, key);
+               return NULL;
+       }
        else
                Py_INCREF(v);
        return v;