From 426e052a4f60f94f3d97c95ce8477b967a975b3b Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Mon, 3 Jan 2011 02:12:02 +0000 Subject: [PATCH] Make C helper function more closely match the pure python version, and add tests. --- Lib/test/test_collections.py | 15 +++++++- Modules/_collectionsmodule.c | 71 +++++++++++++++++++++++++----------- 2 files changed, 63 insertions(+), 23 deletions(-) diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index deda1cda32..d785fcbf85 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -3,7 +3,7 @@ import unittest, doctest, operator import inspect from test import support -from collections import namedtuple, Counter, OrderedDict +from collections import namedtuple, Counter, OrderedDict, _count_elements from test import mapping_tests import pickle, copy from random import randrange, shuffle @@ -775,6 +775,19 @@ class TestCounter(unittest.TestCase): c.subtract('aaaabbcce') self.assertEqual(c, Counter(a=-1, b=0, c=-1, d=1, e=-1)) + def test_helper_function(self): + # two paths, one for real dicts and one for other mappings + elems = list('abracadabra') + + d = dict() + _count_elements(d, elems) + self.assertEqual(d, {'a': 5, 'r': 2, 'b': 2, 'c': 1, 'd': 1}) + + m = OrderedDict() + _count_elements(m, elems) + self.assertEqual(m, + OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)])) + class TestOrderedDict(unittest.TestCase): def test_init(self): diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 684b8738c5..f4a2c8bd6e 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1536,41 +1536,68 @@ _count_elements(PyObject *self, PyObject *args) if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable)) return NULL; - if (!PyDict_Check(mapping)) { - PyErr_SetString(PyExc_TypeError, - "Expected mapping argument to be a dictionary"); - return NULL; - } - it = PyObject_GetIter(iterable); if (it == NULL) return NULL; + one = PyLong_FromLong(1); if (one == NULL) { Py_DECREF(it); return NULL; } - while (1) { - key = PyIter_Next(it); - if (key == NULL) { - if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) - PyErr_Clear(); - break; + + if (PyDict_CheckExact(mapping)) { + while (1) { + key = PyIter_Next(it); + if (key == NULL) { + if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) + PyErr_Clear(); + else + break; + } + oldval = PyDict_GetItem(mapping, key); + if (oldval == NULL) { + if (PyDict_SetItem(mapping, key, one) == -1) + break; + } else { + newval = PyNumber_Add(oldval, one); + if (newval == NULL) + break; + if (PyDict_SetItem(mapping, key, newval) == -1) + break; + Py_CLEAR(newval); + } + Py_DECREF(key); } - oldval = PyDict_GetItem(mapping, key); - if (oldval == NULL) { - if (PyDict_SetItem(mapping, key, one) == -1) - break; - } else { - newval = PyNumber_Add(oldval, one); - if (newval == NULL) - break; - if (PyDict_SetItem(mapping, key, newval) == -1) + } else { + while (1) { + key = PyIter_Next(it); + if (key == NULL) { + if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration)) + PyErr_Clear(); + else + break; + } + oldval = PyObject_GetItem(mapping, key); + if (oldval == NULL) { + if (!PyErr_Occurred() || !PyErr_ExceptionMatches(PyExc_KeyError)) + break; + PyErr_Clear(); + Py_INCREF(one); + newval = one; + } else { + newval = PyNumber_Add(oldval, one); + Py_DECREF(oldval); + if (newval == NULL) + break; + } + if (PyObject_SetItem(mapping, key, newval) == -1) break; Py_CLEAR(newval); + Py_DECREF(key); } - Py_DECREF(key); } + Py_DECREF(it); Py_XDECREF(key); Py_XDECREF(newval); -- 2.40.0