]> granicus.if.org Git - python/commitdiff
Check for a common user error with defaultdict().
authorRaymond Hettinger <python@rcn.com>
Wed, 7 Feb 2007 21:42:17 +0000 (21:42 +0000)
committerRaymond Hettinger <python@rcn.com>
Wed, 7 Feb 2007 21:42:17 +0000 (21:42 +0000)
Lib/test/test_defaultdict.py
Modules/collectionsmodule.c

index 134b5a8cb2fafd905bd8e61eee04aad41da6c120..1834f9071ae3f1d6abf791674d7c8d3c88fd63c3 100644 (file)
@@ -47,6 +47,7 @@ class TestDefaultDict(unittest.TestCase):
             self.assertEqual(err.args, (15,))
         else:
             self.fail("d2[15] didn't raise KeyError")
+        self.assertRaises(TypeError, defaultdict, 1)
 
     def test_missing(self):
         d1 = defaultdict()
@@ -60,10 +61,10 @@ class TestDefaultDict(unittest.TestCase):
         self.assertEqual(repr(d1), "defaultdict(None, {})")
         d1[11] = 41
         self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
-        d2 = defaultdict(0)
-        self.assertEqual(d2.default_factory, 0)
+        d2 = defaultdict(int)
+        self.assertEqual(d2.default_factory, int)
         d2[12] = 42
-        self.assertEqual(repr(d2), "defaultdict(0, {12: 42})")
+        self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
         def foo(): return 43
         d3 = defaultdict(foo)
         self.assert_(d3.default_factory is foo)
index a4cdcfaf52dfb456ad6b4cfa2f174d8479e2965a..f98bd490323dcd2db35f44ee1f1cb51ae052ffe9 100644 (file)
@@ -1252,8 +1252,14 @@ defdict_init(PyObject *self, PyObject *args, PyObject *kwds)
                newargs = PyTuple_New(0);
        else {
                Py_ssize_t n = PyTuple_GET_SIZE(args);
-               if (n > 0)
+               if (n > 0) {
                        newdefault = PyTuple_GET_ITEM(args, 0);
+                       if (!PyCallable_Check(newdefault)) {
+                               PyErr_SetString(PyExc_TypeError,
+                                       "first argument must be callable");                           
+                               return -1;
+                       }
+               }
                newargs = PySequence_GetSlice(args, 1, n);
        }
        if (newargs == NULL)