]> granicus.if.org Git - python/commitdiff
itertools.count() no longer limited to sys.maxint.
authorRaymond Hettinger <python@rcn.com>
Thu, 4 Oct 2007 00:20:27 +0000 (00:20 +0000)
committerRaymond Hettinger <python@rcn.com>
Thu, 4 Oct 2007 00:20:27 +0000 (00:20 +0000)
Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS
Modules/itertoolsmodule.c

index cb168a8323958ec0f6c9e55d1fc00d46e61cbda8..93e62f657a694c786fb42ebc41387561bb35f34c 100644 (file)
@@ -79,19 +79,15 @@ loops that truncate the stream.
 .. function:: count([n])
 
    Make an iterator that returns consecutive integers starting with *n*. If not
-   specified *n* defaults to zero.   Does not currently support python long
-   integers.  Often used as an argument to :func:`imap` to generate consecutive
-   data points. Also, used with :func:`izip` to add sequence numbers.  Equivalent
-   to::
+   specified *n* defaults to zero.   Often used as an argument to :func:`imap` to
+   generate consecutive data points. Also, used with :func:`izip` to add sequence
+   numbers.  Equivalent to::
 
       def count(n=0):
           while True:
               yield n
               n += 1
 
-   Note, :func:`count` does not check for overflow and will return negative numbers
-   after exceeding ``sys.maxint``.  This behavior may change in the future.
-
 
 .. function:: cycle(iterable)
 
index 6c362ad0e0c375119742fb17bf3da7c17896d1e1..3a370d965f5233359182ef754a4200df7bba276a 100644 (file)
@@ -52,9 +52,12 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(zip('abc',count()), [('a', 0), ('b', 1), ('c', 2)])
         self.assertEqual(zip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)])
         self.assertEqual(take(2, zip('abc',count(3))), [('a', 3), ('b', 4)])
+        self.assertEqual(take(2, zip('abc',count(-1))), [('a', -1), ('b', 0)])
+        self.assertEqual(take(2, zip('abc',count(-3))), [('a', -3), ('b', -2)])
         self.assertRaises(TypeError, count, 2, 3)
         self.assertRaises(TypeError, count, 'a')
-        self.assertRaises(OverflowError, list, islice(count(maxsize-5), 10))
+        self.assertEqual(list(islice(count(maxsize-5), 10)), range(maxsize-5, maxsize+5))
+        self.assertEqual(list(islice(count(-maxsize-5), 10)), range(-maxsize-5, -maxsize+5))
         c = count(3)
         self.assertEqual(repr(c), 'count(3)')
         c.next()
@@ -63,6 +66,8 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(repr(c), 'count(-9)')
         c.next()
         self.assertEqual(c.next(), -8)
+        for i in (-sys.maxint-5, -sys.maxint+5 ,-10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
+            self.assertEqual(repr(count(i)), 'count(%r)' % i)
 
     def test_cycle(self):
         self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
index 01f860bc3ffe71302611fa0b1d6122bc55fec7bb..bd31af479dc6ad8bb1ab241654e65f3f3ec99b8b 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -270,6 +270,9 @@ Core and builtins
 Library
 -------
 
+- itertools.count() is no longer bounded to LONG_MAX.  Formerly, it raised
+  an OverflowError.  Now, automatically shifts from ints to longs.
+
 - Patch #1541463: optimize performance of cgi.FieldStorage operations.
 
 - Decimal is fully updated to the latest Decimal Specification (v1.66).
index 35b77e784373a0f22ce46f087dbfceeeac2acf7c..fb54b1413e8b0784e28d7d678d1ab49797c7be08 100644 (file)
@@ -2032,6 +2032,7 @@ static PyTypeObject ifilterfalse_type = {
 typedef struct {
        PyObject_HEAD
        Py_ssize_t cnt;
+       PyObject *long_cnt;     /* Arbitrarily large count when cnt >= PY_SSIZE_T_MAX */
 } countobject;
 
 static PyTypeObject count_type;
@@ -2041,37 +2042,97 @@ count_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
 {
        countobject *lz;
        Py_ssize_t cnt = 0;
+       PyObject *cnt_arg = NULL;
+       PyObject *long_cnt = NULL;
 
        if (type == &count_type && !_PyArg_NoKeywords("count()", kwds))
                return NULL;
 
-       if (!PyArg_ParseTuple(args, "|n:count", &cnt))
+       if (!PyArg_UnpackTuple(args, "count", 0, 1, &cnt_arg))
                return NULL;
 
+       if (cnt_arg != NULL) {
+               cnt = PyInt_AsSsize_t(cnt_arg);
+               if (cnt == -1 && PyErr_Occurred()) {
+                       PyErr_Clear();
+                       if (!PyLong_Check(cnt_arg)) {
+                               PyErr_SetString(PyExc_TypeError, "an integer is required");
+                               return NULL;
+                       }
+                       long_cnt = cnt_arg;
+                       Py_INCREF(long_cnt);
+                       cnt = PY_SSIZE_T_MAX;
+               }
+       }
+
        /* create countobject structure */
        lz = (countobject *)PyObject_New(countobject, &count_type);
-       if (lz == NULL)
+       if (lz == NULL) {
+               Py_XDECREF(long_cnt);
                return NULL;
+       }
        lz->cnt = cnt;
+       lz->long_cnt = long_cnt;
 
        return (PyObject *)lz;
 }
 
+static void
+count_dealloc(countobject *lz)
+{
+       Py_XDECREF(lz->long_cnt); 
+       PyObject_Del(lz);
+}
+
+static PyObject *
+count_nextlong(countobject *lz)
+{
+       static PyObject *one = NULL;
+       PyObject *cnt;
+       PyObject *stepped_up;
+
+       if (lz->long_cnt == NULL) {
+               lz->long_cnt = PyInt_FromSsize_t(PY_SSIZE_T_MAX);
+               if (lz->long_cnt == NULL)
+                       return NULL;
+       }
+       if (one == NULL) {
+               one = PyInt_FromLong(1);
+               if (one == NULL)
+                       return NULL;
+       }
+       cnt = lz->long_cnt;
+       assert(cnt != NULL);
+       stepped_up = PyNumber_Add(cnt, one);
+       if (stepped_up == NULL)
+               return NULL;
+       lz->long_cnt = stepped_up;
+       return cnt;
+}
+
 static PyObject *
 count_next(countobject *lz)
 {
-        if (lz->cnt == PY_SSIZE_T_MAX) {
-                PyErr_SetString(PyExc_OverflowError,
-                        "cannot count beyond PY_SSIZE_T_MAX");                
-                return NULL;         
-        }
+        if (lz->cnt == PY_SSIZE_T_MAX)
+               return count_nextlong(lz);
        return PyInt_FromSsize_t(lz->cnt++);
 }
 
 static PyObject *
 count_repr(countobject *lz)
 {
-       return PyString_FromFormat("count(%zd)", lz->cnt);
+       PyObject *cnt_repr;
+       PyObject *result;
+
+        if (lz->cnt != PY_SSIZE_T_MAX)
+               return PyString_FromFormat("count(%zd)", lz->cnt);
+
+       cnt_repr = PyObject_Repr(lz->long_cnt);
+       if (cnt_repr == NULL)
+               return NULL;
+       result = PyString_FromFormat("count(%s)", PyString_AS_STRING(cnt_repr));
+       Py_DECREF(cnt_repr);
+       return result;
 }
 
 PyDoc_STRVAR(count_doc,
@@ -2086,7 +2147,7 @@ static PyTypeObject count_type = {
        sizeof(countobject),            /* tp_basicsize */
        0,                              /* tp_itemsize */
        /* methods */
-       (destructor)PyObject_Del,       /* tp_dealloc */
+       (destructor)count_dealloc,      /* tp_dealloc */
        0,                              /* tp_print */
        0,                              /* tp_getattr */
        0,                              /* tp_setattr */