]> granicus.if.org Git - python/commitdiff
Issue 24342: Let wrapper set by sys.set_coroutine_wrapper fail gracefully
authorYury Selivanov <yselivanov@sprymix.com>
Tue, 2 Jun 2015 22:43:51 +0000 (18:43 -0400)
committerYury Selivanov <yselivanov@sprymix.com>
Tue, 2 Jun 2015 22:43:51 +0000 (18:43 -0400)
Doc/library/sys.rst
Include/ceval.h
Include/pystate.h
Lib/test/test_coroutines.py
Python/ceval.c
Python/pystate.c

index 3e8fd82aa7272940473f932d6803c39ab609bc93..f9733b2ed63b43f99b90e3c78e85426912df4be4 100644 (file)
@@ -1085,6 +1085,20 @@ always available.
    If called twice, the new wrapper replaces the previous one.  The function
    is thread-specific.
 
+   The *wrapper* callable cannot define new coroutines directly or indirectly::
+
+        def wrapper(coro):
+            async def wrap(coro):
+                return await coro
+            return wrap(coro)
+        sys.set_coroutine_wrapper(wrapper)
+
+        async def foo(): pass
+
+        # The following line will fail with a RuntimeError, because
+        # `wrapper` creates a `wrap(coro)` coroutine:
+        foo()
+
    See also :func:`get_coroutine_wrapper`.
 
    .. versionadded:: 3.5
index e5585945ae174d3a63e2313615f5fd35ae975e94..9f4d3f1998ce076920c4118d26c88005f825483c 100644 (file)
@@ -23,8 +23,9 @@ PyAPI_FUNC(PyObject *) PyEval_CallMethod(PyObject *obj,
 #ifndef Py_LIMITED_API
 PyAPI_FUNC(void) PyEval_SetProfile(Py_tracefunc, PyObject *);
 PyAPI_FUNC(void) PyEval_SetTrace(Py_tracefunc, PyObject *);
-PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *wrapper);
+PyAPI_FUNC(void) _PyEval_SetCoroutineWrapper(PyObject *);
 PyAPI_FUNC(PyObject *) _PyEval_GetCoroutineWrapper(void);
+PyAPI_FUNC(PyObject *) _PyEval_ApplyCoroutineWrapper(PyObject *);
 #endif
 
 struct _frame; /* Avoid including frameobject.h */
index 2ee81df7b587bb9955ba1cb04757d61321ee5466..a2fd8031d04d3f3a453d285a9ff588d1b637c167 100644 (file)
@@ -135,6 +135,7 @@ typedef struct _ts {
     void *on_delete_data;
 
     PyObject *coroutine_wrapper;
+    int in_coroutine_wrapper;
 
     /* XXX signal handlers should also be here */
 
index e79896a9b8e954cd49a1f2b3a6c3b5a4f7dcee9c..670852d20c06eee1b1628b89200c23f49edd6613 100644 (file)
@@ -995,6 +995,26 @@ class SysSetCoroWrapperTest(unittest.TestCase):
             sys.set_coroutine_wrapper(1)
         self.assertIsNone(sys.get_coroutine_wrapper())
 
+    def test_set_wrapper_3(self):
+        async def foo():
+            return 'spam'
+
+        def wrapper(coro):
+            async def wrap(coro):
+                return await coro
+            return wrap(coro)
+
+        sys.set_coroutine_wrapper(wrapper)
+        try:
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "coroutine wrapper.*\.wrapper at 0x.*attempted to "
+                "recursively wrap <coroutine.*\.wrap"):
+
+                foo()
+        finally:
+            sys.set_coroutine_wrapper(None)
+
 
 class CAPITest(unittest.TestCase):
 
index bb2c0b96a517d1327e23ac797e8e1a4bfcf276ac..2a1db17b45626b202549df50192e30550a667c98 100644 (file)
@@ -3921,7 +3921,6 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
 
     if (co->co_flags & CO_GENERATOR) {
         PyObject *gen;
-        PyObject *coroutine_wrapper;
 
         /* Don't need to keep the reference to f_back, it will be set
          * when the generator is resumed. */
@@ -3935,14 +3934,9 @@ _PyEval_EvalCodeWithName(PyObject *_co, PyObject *globals, PyObject *locals,
         if (gen == NULL)
             return NULL;
 
-        if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE)) {
-            coroutine_wrapper = _PyEval_GetCoroutineWrapper();
-            if (coroutine_wrapper != NULL) {
-                PyObject *wrapped =
-                            PyObject_CallFunction(coroutine_wrapper, "N", gen);
-                gen = wrapped;
-            }
-        }
+        if (co->co_flags & (CO_COROUTINE | CO_ITERABLE_COROUTINE))
+            return _PyEval_ApplyCoroutineWrapper(gen);
+
         return gen;
     }
 
@@ -4407,6 +4401,33 @@ _PyEval_GetCoroutineWrapper(void)
     return tstate->coroutine_wrapper;
 }
 
+PyObject *
+_PyEval_ApplyCoroutineWrapper(PyObject *gen)
+{
+    PyObject *wrapped;
+    PyThreadState *tstate = PyThreadState_GET();
+    PyObject *wrapper = tstate->coroutine_wrapper;
+
+    if (tstate->in_coroutine_wrapper) {
+        assert(wrapper != NULL);
+        PyErr_Format(PyExc_RuntimeError,
+                     "coroutine wrapper %.150R attempted "
+                     "to recursively wrap %.150R",
+                     wrapper,
+                     gen);
+        return NULL;
+    }
+
+    if (wrapper == NULL) {
+        return gen;
+    }
+
+    tstate->in_coroutine_wrapper = 1;
+    wrapped = PyObject_CallFunction(wrapper, "N", gen);
+    tstate->in_coroutine_wrapper = 0;
+    return wrapped;
+}
+
 PyObject *
 PyEval_GetBuiltins(void)
 {
index 4ac05d6625686e688e87a943f722a7feb55c07d7..7e0267ae1d04743226fd6eabccad934a6b5cdc83 100644 (file)
@@ -213,6 +213,7 @@ new_threadstate(PyInterpreterState *interp, int init)
         tstate->on_delete_data = NULL;
 
         tstate->coroutine_wrapper = NULL;
+        tstate->in_coroutine_wrapper = 0;
 
         if (init)
             _PyThreadState_Init(tstate);