]> granicus.if.org Git - python/commitdiff
* Increase test coverage.
authorRaymond Hettinger <python@rcn.com>
Tue, 28 Sep 2004 04:45:28 +0000 (04:45 +0000)
committerRaymond Hettinger <python@rcn.com>
Tue, 28 Sep 2004 04:45:28 +0000 (04:45 +0000)
* Have groupby() be careful about decreffing structure members.

Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index 6ae6785449786fbb6e39ef556795cda99ac013ee..02f84b7e39820e57b7b37da7652fb0b77fc6dc62 100644 (file)
@@ -53,6 +53,10 @@ class TestBasicOps(unittest.TestCase):
         self.assertRaises(TypeError, count, 'a')
         c = count(sys.maxint-2)   # verify that rollover doesn't crash
         c.next(); c.next(); c.next(); c.next(); c.next()
+        c = count(3)
+        self.assertEqual(repr(c), 'count(3)')
+        c.next()
+        self.assertEqual(repr(c), 'count(4)')
 
     def test_cycle(self):
         self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
@@ -67,6 +71,7 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual([], list(groupby([], key=id)))
         self.assertRaises(TypeError, list, groupby('abc', []))
         self.assertRaises(TypeError, groupby, None)
+        self.assertRaises(TypeError, groupby, 'abc', lambda x:x, 10)
 
         # Check normal input
         s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22),
@@ -199,6 +204,12 @@ class TestBasicOps(unittest.TestCase):
         self.assertRaises(TypeError, repeat)
         self.assertRaises(TypeError, repeat, None, 3, 4)
         self.assertRaises(TypeError, repeat, None, 'a')
+        r = repeat(1+0j)
+        self.assertEqual(repr(r), 'repeat((1+0j))')
+        r = repeat(1+0j, 5)
+        self.assertEqual(repr(r), 'repeat((1+0j), 5)')
+        list(r)
+        self.assertEqual(repr(r), 'repeat((1+0j), 0)')
 
     def test_imap(self):
         self.assertEqual(list(imap(operator.pow, range(3), range(1,7))),
@@ -275,6 +286,9 @@ class TestBasicOps(unittest.TestCase):
         self.assertRaises(TypeError, takewhile, operator.pow, [(4,5)], 'extra')
         self.assertRaises(TypeError, takewhile(10, [(4,5)]).next)
         self.assertRaises(ValueError, takewhile(errfunc, [(4,5)]).next)
+        t = takewhile(bool, [1, 1, 1, 0, 0, 0])
+        self.assertEqual(list(t), [1, 1, 1])
+        self.assertRaises(StopIteration, t.next)
 
     def test_dropwhile(self):
         data = [1, 3, 5, 20, 2, 4, 6, 8]
@@ -347,11 +361,26 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(list(a), range(100,2000))
         self.assertEqual(list(c), range(2,2000))
 
+        # test values of n
+        self.assertRaises(TypeError, tee, 'abc', 'invalid')
+        for n in xrange(5):
+            result = tee('abc', n)
+            self.assertEqual(type(result), tuple)
+            self.assertEqual(len(result), n)
+            self.assertEqual(map(list, result), [list('abc')]*n)
+
         # tee pass-through to copyable iterator
         a, b = tee('abc')
         c, d = tee(a)
         self.assert_(a is c)
 
+        # test tee_new
+        t1, t2 = tee('abc')
+        tnew = type(t1)
+        self.assertRaises(TypeError, tnew)
+        self.assertRaises(TypeError, tnew, 10)
+        t3 = tnew(t1)
+        self.assert_(list(t1) == list(t2) == list(t3) == list('abc'))
 
     def test_StopIteration(self):
         self.assertRaises(StopIteration, izip().next)
index bf2c49311ffb36154f59e0ab40e020f1342f53ef..3da0258eefb21d4edb4207bf0efb95822feb26bf 100644 (file)
@@ -75,7 +75,7 @@ groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg)
 static PyObject *
 groupby_next(groupbyobject *gbo)
 {
-       PyObject *newvalue, *newkey, *r, *grouper;
+       PyObject *newvalue, *newkey, *r, *grouper, *tmp;
 
        /* skip to next iteration group */
        for (;;) {
@@ -110,15 +110,19 @@ groupby_next(groupbyobject *gbo)
                        }
                }
 
-               Py_XDECREF(gbo->currkey);
+               tmp = gbo->currkey;
                gbo->currkey = newkey;
-               Py_XDECREF(gbo->currvalue);
+               Py_XDECREF(tmp);
+
+               tmp = gbo->currvalue;
                gbo->currvalue = newvalue;
+               Py_XDECREF(tmp);
        }
 
-       Py_XDECREF(gbo->tgtkey);
-       gbo->tgtkey = gbo->currkey;
        Py_INCREF(gbo->currkey);
+       tmp = gbo->tgtkey;
+       gbo->tgtkey = gbo->currkey;
+       Py_XDECREF(tmp);
 
        grouper = _grouper_create(gbo, gbo->tgtkey);
        if (grouper == NULL)