]> granicus.if.org Git - python/commitdiff
Issue #23996: Added _PyGen_SetStopIterationValue for safe raising
authorSerhiy Storchaka <storchaka@gmail.com>
Sun, 6 Nov 2016 16:44:42 +0000 (18:44 +0200)
committerSerhiy Storchaka <storchaka@gmail.com>
Sun, 6 Nov 2016 16:44:42 +0000 (18:44 +0200)
StopIteration with value. More safely handle non-normalized exceptions
in -_PyGen_FetchStopIterationValue.

Include/genobject.h
Lib/test/test_coroutines.py
Lib/test/test_generators.py
Objects/genobject.c

index 1ff32a8eafac2ab5bbd76553d0402c9af74eaf24..61e708a73ab610dd8562e618e455ab350ed7623c 100644 (file)
@@ -41,6 +41,7 @@ PyAPI_FUNC(PyObject *) PyGen_New(struct _frame *);
 PyAPI_FUNC(PyObject *) PyGen_NewWithQualName(struct _frame *,
     PyObject *name, PyObject *qualname);
 PyAPI_FUNC(int) PyGen_NeedsFinalizing(PyGenObject *);
+PyAPI_FUNC(int) _PyGen_SetStopIterationValue(PyObject *);
 PyAPI_FUNC(int) _PyGen_FetchStopIterationValue(PyObject **);
 PyObject *_PyGen_Send(PyGenObject *, PyObject *);
 PyObject *_PyGen_yf(PyGenObject *);
index e52654c14d0183704784f82d6f45bae996ace432..4a327b5ba9f16835e6a0de313ebcf461bb153b0a 100644 (file)
@@ -710,6 +710,21 @@ class CoroutineTest(unittest.TestCase):
             coro.close()
             self.assertEqual(CHK, 1)
 
+    def test_coro_wrapper_send_tuple(self):
+        async def foo():
+            return (10,)
+
+        result = run_async__await__(foo())
+        self.assertEqual(result, ([], (10,)))
+
+    def test_coro_wrapper_send_stop_iterator(self):
+        async def foo():
+            return StopIteration(10)
+
+        result = run_async__await__(foo())
+        self.assertIsInstance(result[1], StopIteration)
+        self.assertEqual(result[1].value, 10)
+
     def test_cr_await(self):
         @types.coroutine
         def a():
@@ -1537,6 +1552,52 @@ class CoroutineTest(unittest.TestCase):
                 warnings.simplefilter("error")
                 run_async(foo())
 
+    def test_for_tuple(self):
+        class Done(Exception): pass
+
+        class AIter(tuple):
+            i = 0
+            def __aiter__(self):
+                return self
+            async def __anext__(self):
+                if self.i >= len(self):
+                    raise StopAsyncIteration
+                self.i += 1
+                return self[self.i - 1]
+
+        result = []
+        async def foo():
+            async for i in AIter([42]):
+                result.append(i)
+            raise Done
+
+        with self.assertRaises(Done):
+            foo().send(None)
+        self.assertEqual(result, [42])
+
+    def test_for_stop_iteration(self):
+        class Done(Exception): pass
+
+        class AIter(StopIteration):
+            i = 0
+            def __aiter__(self):
+                return self
+            async def __anext__(self):
+                if self.i:
+                    raise StopAsyncIteration
+                self.i += 1
+                return self.value
+
+        result = []
+        async def foo():
+            async for i in AIter(42):
+                result.append(i)
+            raise Done
+
+        with self.assertRaises(Done):
+            foo().send(None)
+        self.assertEqual(result, [42])
+
     def test_copy(self):
         async def func(): pass
         coro = func()
index 3f82462478d5e1873e3904208a02c8ff45aeb3b5..cd6a43d753e6fa386943e2c18eb73addf9bf6d13 100644 (file)
@@ -277,6 +277,27 @@ class ExceptionTest(unittest.TestCase):
             # hence no warning.
             next(g)
 
+    def test_return_tuple(self):
+        def g():
+            return (yield 1)
+
+        gen = g()
+        self.assertEqual(next(gen), 1)
+        with self.assertRaises(StopIteration) as cm:
+            gen.send((2,))
+        self.assertEqual(cm.exception.value, (2,))
+
+    def test_return_stopiteration(self):
+        def g():
+            return (yield 1)
+
+        gen = g()
+        self.assertEqual(next(gen), 1)
+        with self.assertRaises(StopIteration) as cm:
+            gen.send(StopIteration(2))
+        self.assertIsInstance(cm.exception.value, StopIteration)
+        self.assertEqual(cm.exception.value.value, 2)
+
 
 class YieldFromTests(unittest.TestCase):
     def test_generator_gi_yieldfrom(self):
index 9172e6a01025948977c14626b89820eadf2ff83a..0d5d54fdbafcd4bbc87b7615122ca6906c3e0b51 100644 (file)
@@ -154,12 +154,7 @@ gen_send_ex(PyGenObject *gen, PyObject *arg, int exc, int closing)
             /* Delay exception instantiation if we can */
             PyErr_SetNone(PyExc_StopIteration);
         } else {
-            PyObject *e = PyObject_CallFunctionObjArgs(
-                               PyExc_StopIteration, result, NULL);
-            if (e != NULL) {
-                PyErr_SetObject(PyExc_StopIteration, e);
-                Py_DECREF(e);
-            }
+            _PyGen_SetStopIterationValue(result);
         }
         Py_CLEAR(result);
     }
@@ -459,6 +454,43 @@ gen_iternext(PyGenObject *gen)
     return gen_send_ex(gen, NULL, 0, 0);
 }
 
+/*
+ * Set StopIteration with specified value.  Value can be arbitrary object
+ * or NULL.
+ *
+ * Returns 0 if StopIteration is set and -1 if any other exception is set.
+ */
+int
+_PyGen_SetStopIterationValue(PyObject *value)
+{
+    PyObject *e;
+
+    if (value == NULL ||
+        (!PyTuple_Check(value) &&
+         !PyObject_TypeCheck(value, (PyTypeObject *) PyExc_StopIteration)))
+    {
+        /* Delay exception instantiation if we can */
+        PyErr_SetObject(PyExc_StopIteration, value);
+        return 0;
+    }
+    /* Construct an exception instance manually with
+     * PyObject_CallFunctionObjArgs and pass it to PyErr_SetObject.
+     *
+     * We do this to handle a situation when "value" is a tuple, in which
+     * case PyErr_SetObject would set the value of StopIteration to
+     * the first element of the tuple.
+     *
+     * (See PyErr_SetObject/_PyErr_CreateException code for details.)
+     */
+    e = PyObject_CallFunctionObjArgs(PyExc_StopIteration, value, NULL);
+    if (e == NULL) {
+        return -1;
+    }
+    PyErr_SetObject(PyExc_StopIteration, e);
+    Py_DECREF(e);
+    return 0;
+}
+
 /*
  *   If StopIteration exception is set, fetches its 'value'
  *   attribute if any, otherwise sets pvalue to None.
@@ -469,7 +501,8 @@ gen_iternext(PyGenObject *gen)
  */
 
 int
-_PyGen_FetchStopIterationValue(PyObject **pvalue) {
+_PyGen_FetchStopIterationValue(PyObject **pvalue)
+{
     PyObject *et, *ev, *tb;
     PyObject *value = NULL;
 
@@ -481,8 +514,15 @@ _PyGen_FetchStopIterationValue(PyObject **pvalue) {
                 value = ((PyStopIterationObject *)ev)->value;
                 Py_INCREF(value);
                 Py_DECREF(ev);
-            } else if (et == PyExc_StopIteration) {
-                /* avoid normalisation and take ev as value */
+            } else if (et == PyExc_StopIteration && !PyTuple_Check(ev)) {
+                /* Avoid normalisation and take ev as value.
+                 *
+                 * Normalization is required if the value is a tuple, in
+                 * that case the value of StopIteration would be set to
+                 * the first element of the tuple.
+                 *
+                 * (See _PyErr_CreateException code for details.)
+                 */
                 value = ev;
             } else {
                 /* normalisation required */
@@ -1012,7 +1052,7 @@ typedef struct {
 static PyObject *
 aiter_wrapper_iternext(PyAIterWrapper *aw)
 {
-    PyErr_SetObject(PyExc_StopIteration, aw->aw_aiter);
+    _PyGen_SetStopIterationValue(aw->aw_aiter);
     return NULL;
 }