]> granicus.if.org Git - python/commitdiff
SF #950057: itertools.chain doesn't "process" exceptions as they occur
authorRaymond Hettinger <python@rcn.com>
Sat, 8 May 2004 19:49:42 +0000 (19:49 +0000)
committerRaymond Hettinger <python@rcn.com>
Sat, 8 May 2004 19:49:42 +0000 (19:49 +0000)
Both cycle() and chain() were handling exceptions only when switching
input sources.  The patch makes the handle more immediate.

Will backport.

Lib/test/test_itertools.py
Modules/itertoolsmodule.c

index 73e880971d91bd627f13f04c59a51512a699c99e..54e46e138b04a9f414d97900b90291204975dd0e 100644 (file)
@@ -644,6 +644,36 @@ class RegressionTests(unittest.TestCase):
         self.assertEqual(first, second)
 
 
+    def test_sf_950057(self):
+        # Make sure that chain() and cycle() catch exceptions immediately
+        # rather than when shifting between input sources
+
+        def gen1():
+            hist.append(0)
+            yield 1
+            hist.append(1)
+            assert False
+            hist.append(2)
+
+        def gen2(x):
+            hist.append(3)
+            yield 2
+            hist.append(4)
+            if x:
+                raise StopIteration
+
+        hist = []
+        self.assertRaises(AssertionError, list, chain(gen1(), gen2(False)))
+        self.assertEqual(hist, [0,1])
+
+        hist = []
+        self.assertRaises(AssertionError, list, chain(gen1(), gen2(True)))
+        self.assertEqual(hist, [0,1])
+
+        hist = []
+        self.assertRaises(AssertionError, list, cycle(gen1()))
+        self.assertEqual(hist, [0,1])
+
 libreftest = """ Doctest for examples in the library reference: libitertools.tex
 
 
index 4ce46430df412323e82fca59b194166201fe0348..3515bc6058bc1448ecc2809ad6bd48306f1cad9e 100644 (file)
@@ -699,6 +699,12 @@ cycle_next(cycleobject *lz)
                                PyList_Append(lz->saved, item);
                        return item;
                }
+               if (PyErr_Occurred()) {
+                       if (PyErr_ExceptionMatches(PyExc_StopIteration))
+                               PyErr_Clear();
+                       else
+                               return NULL;
+               }
                if (PyList_Size(lz->saved) == 0) 
                        return NULL;
                it = PyObject_GetIter(lz->saved);
@@ -1658,6 +1664,12 @@ chain_next(chainobject *lz)
                item = PyIter_Next(it);
                if (item != NULL)
                        return item;
+               if (PyErr_Occurred()) {
+                       if (PyErr_ExceptionMatches(PyExc_StopIteration))
+                               PyErr_Clear();
+                       else
+                               return NULL;
+               }
                lz->iternum++;
        }
        return NULL;