]> granicus.if.org Git - python/commitdiff
bpo-30346: An iterator produced by the itertools.groupby() iterator (#1569)
authorSerhiy Storchaka <storchaka@gmail.com>
Sun, 24 Sep 2017 10:36:11 +0000 (13:36 +0300)
committerGitHub <noreply@github.com>
Sun, 24 Sep 2017 10:36:11 +0000 (13:36 +0300)
now becames exhausted after advancing the groupby iterator.

Doc/library/itertools.rst
Lib/test/test_itertools.py
Misc/NEWS.d/next/Library/2017-09-24-13-08-46.bpo-30346.Csse77.rst [new file with mode: 0644]
Modules/itertoolsmodule.c

index c989e464200d8ba791c4f983c22c7fd1a15ba64e..530c29dec4a8773571bdcdd8f97762c5af8989fe 100644 (file)
@@ -401,13 +401,14 @@ loops that truncate the stream.
           def __iter__(self):
               return self
           def __next__(self):
+              self.id = object()
               while self.currkey == self.tgtkey:
                   self.currvalue = next(self.it)    # Exit on StopIteration
                   self.currkey = self.keyfunc(self.currvalue)
               self.tgtkey = self.currkey
-              return (self.currkey, self._grouper(self.tgtkey))
-          def _grouper(self, tgtkey):
-              while self.currkey == tgtkey:
+              return (self.currkey, self._grouper(self.tgtkey, self.id))
+          def _grouper(self, tgtkey, id):
+              while self.id is id and self.currkey == tgtkey:
                   yield self.currvalue
                   try:
                       self.currvalue = next(self.it)
index 50cf1488ec6184b24ff368a4c329ddf4ded6fc85..8353e68977d222b4713df539ab1cfe7f9589b5a6 100644 (file)
@@ -751,6 +751,26 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(set(keys), expectedkeys)
         self.assertEqual(len(keys), len(expectedkeys))
 
+        # Check case where inner iterator is used after advancing the groupby
+        # iterator
+        s = list(zip('AABBBAAAA', range(9)))
+        it = groupby(s, testR)
+        _, g1 = next(it)
+        _, g2 = next(it)
+        _, g3 = next(it)
+        self.assertEqual(list(g1), [])
+        self.assertEqual(list(g2), [])
+        self.assertEqual(next(g3), ('A', 5))
+        list(it)  # exhaust the groupby iterator
+        self.assertEqual(list(g3), [])
+
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            it = groupby(s, testR)
+            _, g = next(it)
+            next(it)
+            next(it)
+            self.assertEqual(list(pickle.loads(pickle.dumps(g, proto))), [])
+
         # Exercise pipes and filters style
         s = 'abracadabra'
         # sort s | uniq
diff --git a/Misc/NEWS.d/next/Library/2017-09-24-13-08-46.bpo-30346.Csse77.rst b/Misc/NEWS.d/next/Library/2017-09-24-13-08-46.bpo-30346.Csse77.rst
new file mode 100644 (file)
index 0000000..81ad053
--- /dev/null
@@ -0,0 +1,2 @@
+An iterator produced by itertools.groupby() iterator now becames exhausted
+after advancing the groupby iterator.
index 48e6c35db4fe0434eb75e88f4541c4d3c86ba6fd..2ac5ab24ec8abf2e8b829a45de536217e8cb9516 100644 (file)
@@ -17,6 +17,7 @@ typedef struct {
     PyObject *tgtkey;
     PyObject *currkey;
     PyObject *currvalue;
+    const void *currgrouper;  /* borrowed reference */
 } groupbyobject;
 
 static PyTypeObject groupby_type;
@@ -77,6 +78,7 @@ groupby_next(groupbyobject *gbo)
 {
     PyObject *newvalue, *newkey, *r, *grouper;
 
+    gbo->currgrouper = NULL;
     /* skip to next iteration group */
     for (;;) {
         if (gbo->currkey == NULL)
@@ -255,6 +257,7 @@ _grouper_create(groupbyobject *parent, PyObject *tgtkey)
     Py_INCREF(parent);
     igo->tgtkey = tgtkey;
     Py_INCREF(tgtkey);
+    parent->currgrouper = igo;  /* borrowed reference */
 
     PyObject_GC_Track(igo);
     return (PyObject *)igo;
@@ -284,6 +287,8 @@ _grouper_next(_grouperobject *igo)
     PyObject *newvalue, *newkey, *r;
     int rcmp;
 
+    if (gbo->currgrouper != igo)
+        return NULL;
     if (gbo->currvalue == NULL) {
         newvalue = PyIter_Next(gbo->it);
         if (newvalue == NULL)
@@ -321,6 +326,9 @@ _grouper_next(_grouperobject *igo)
 static PyObject *
 _grouper_reduce(_grouperobject *lz)
 {
+    if (((groupbyobject *)lz->parent)->currgrouper != lz) {
+        return Py_BuildValue("N(())", _PyObject_GetBuiltin("iter"));
+    }
     return Py_BuildValue("O(OO)", Py_TYPE(lz), lz->parent, lz->tgtkey);
 }