]> granicus.if.org Git - python/commitdiff
Issue 24017: fix for "async with" refcounting
authorNick Coghlan <ncoghlan@gmail.com>
Wed, 13 May 2015 05:54:02 +0000 (15:54 +1000)
committerNick Coghlan <ncoghlan@gmail.com>
Wed, 13 May 2015 05:54:02 +0000 (15:54 +1000)
* adds missing INCREF in WITH_CLEANUP_START
* adds missing DECREF in WITH_CLEANUP_FINISH
* adds several new tests Yury created while investigating this

Lib/test/test_coroutines.py
Python/ceval.c

index aa2a5e8ef57f88f58184eafded58c3fcb8c1d43f..6a6f868c129246755a515b6217db3ba1ef8ead0e 100644 (file)
@@ -492,6 +492,31 @@ class CoroutineTest(unittest.TestCase):
             run_async(foo())
 
     def test_with_7(self):
+        class CM:
+            async def __aenter__(self):
+                return self
+
+            def __aexit__(self, *e):
+                return 444
+
+        async def foo():
+            async with CM():
+                1/0
+
+        try:
+            run_async(foo())
+        except TypeError as exc:
+            self.assertRegex(
+                exc.args[0], "object int can't be used in 'await' expression")
+            self.assertTrue(exc.__context__ is not None)
+            self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
+        else:
+            self.fail('invalid asynchronous context manager did not fail')
+
+
+    def test_with_8(self):
+        CNT = 0
+
         class CM:
             async def __aenter__(self):
                 return self
@@ -500,14 +525,105 @@ class CoroutineTest(unittest.TestCase):
                 return 456
 
         async def foo():
+            nonlocal CNT
             async with CM():
-                pass
+                CNT += 1
+
 
         with self.assertRaisesRegex(
             TypeError, "object int can't be used in 'await' expression"):
 
             run_async(foo())
 
+        self.assertEqual(CNT, 1)
+
+
+    def test_with_9(self):
+        CNT = 0
+
+        class CM:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *e):
+                1/0
+
+        async def foo():
+            nonlocal CNT
+            async with CM():
+                CNT += 1
+
+        with self.assertRaises(ZeroDivisionError):
+            run_async(foo())
+
+        self.assertEqual(CNT, 1)
+
+    def test_with_10(self):
+        CNT = 0
+
+        class CM:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *e):
+                1/0
+
+        async def foo():
+            nonlocal CNT
+            async with CM():
+                async with CM():
+                    raise RuntimeError
+
+        try:
+            run_async(foo())
+        except ZeroDivisionError as exc:
+            self.assertTrue(exc.__context__ is not None)
+            self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
+            self.assertTrue(isinstance(exc.__context__.__context__,
+                                       RuntimeError))
+        else:
+            self.fail('exception from __aexit__ did not propagate')
+
+    def test_with_11(self):
+        CNT = 0
+
+        class CM:
+            async def __aenter__(self):
+                raise NotImplementedError
+
+            async def __aexit__(self, *e):
+                1/0
+
+        async def foo():
+            nonlocal CNT
+            async with CM():
+                raise RuntimeError
+
+        try:
+            run_async(foo())
+        except NotImplementedError as exc:
+            self.assertTrue(exc.__context__ is None)
+        else:
+            self.fail('exception from __aenter__ did not propagate')
+
+    def test_with_12(self):
+        CNT = 0
+
+        class CM:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, *e):
+                return True
+
+        async def foo():
+            nonlocal CNT
+            async with CM() as cm:
+                self.assertIs(cm.__class__, CM)
+                raise RuntimeError
+
+        run_async(foo())
+
     def test_for_1(self):
         aiter_calls = 0
 
index 77085c25a03158be9986c96941051007a7f24be7..afb0f89aa8d0485c2d37a9d0c6d8485a20ed2fb4 100644 (file)
@@ -3156,6 +3156,7 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
             if (res == NULL)
                 goto error;
 
+            Py_INCREF(exc); /* Duplicating the exception on the stack */
             PUSH(exc);
             PUSH(res);
             PREDICT(WITH_CLEANUP_FINISH);
@@ -3174,6 +3175,7 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
                 err = 0;
 
             Py_DECREF(res);
+            Py_DECREF(exc);
 
             if (err < 0)
                 goto error;