]> granicus.if.org Git - python/commitdiff
Make StreamServer.close() tests more robust (GH-13790)
authorAndrew Svetlov <andrew.svetlov@gmail.com>
Tue, 4 Jun 2019 11:37:10 +0000 (14:37 +0300)
committerYury Selivanov <yury@magic.io>
Tue, 4 Jun 2019 11:37:10 +0000 (13:37 +0200)
Lib/test/test_asyncio/test_streams.py

index c1b9dc95ee621f09ab43432475e2de752bda351b..e484746432af508d67a74c557d29c82f35385725 100644 (file)
@@ -1507,10 +1507,14 @@ os.close(fd)
 
     def test_stream_server_abort(self):
         server_stream_aborted = False
-        fut = self.loop.create_future()
+        fut1 = self.loop.create_future()
+        fut2 = self.loop.create_future()
 
         async def handle_client(stream):
-            await fut
+            data = await stream.readexactly(4)
+            self.assertEqual(b'data', data)
+            fut1.set_result(None)
+            await fut2
             self.assertEqual(b'', await stream.readline())
             nonlocal server_stream_aborted
             server_stream_aborted = True
@@ -1518,7 +1522,8 @@ os.close(fd)
         async def client(srv):
             addr = srv.sockets[0].getsockname()
             stream = await asyncio.connect(*addr)
-            fut.set_result(None)
+            await stream.write(b'data')
+            await fut2
             self.assertEqual(b'', await stream.readline())
             await stream.close()
 
@@ -1526,7 +1531,8 @@ os.close(fd)
             async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server:
                 await server.start_serving()
                 task = asyncio.create_task(client(server))
-                await fut
+                await fut1
+                fut2.set_result(None)
                 await server.abort()
                 await task
 
@@ -1534,21 +1540,31 @@ os.close(fd)
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
         self.loop.run_until_complete(test())
         self.assertEqual(messages, [])
-        self.assertTrue(fut.done())
+        self.assertTrue(fut1.done())
+        self.assertTrue(fut2.done())
         self.assertTrue(server_stream_aborted)
 
     def test_stream_shutdown_hung_task(self):
         fut1 = self.loop.create_future()
         fut2 = self.loop.create_future()
+        cancelled = self.loop.create_future()
 
         async def handle_client(stream):
-            while True:
-                await asyncio.sleep(0.01)
+            data = await stream.readexactly(4)
+            self.assertEqual(b'data', data)
+            fut1.set_result(None)
+            await fut2
+            try:
+                while True:
+                    await asyncio.sleep(0.01)
+            except asyncio.CancelledError:
+                cancelled.set_result(None)
+                raise
 
         async def client(srv):
             addr = srv.sockets[0].getsockname()
             stream = await asyncio.connect(*addr)
-            fut1.set_result(None)
+            await stream.write(b'data')
             await fut2
             self.assertEqual(b'', await stream.readline())
             await stream.close()
@@ -1561,9 +1577,10 @@ os.close(fd)
                 await server.start_serving()
                 task = asyncio.create_task(client(server))
                 await fut1
-                await server.close()
                 fut2.set_result(None)
+                await server.close()
                 await task
+                await cancelled
 
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
@@ -1571,21 +1588,28 @@ os.close(fd)
         self.assertEqual(messages, [])
         self.assertTrue(fut1.done())
         self.assertTrue(fut2.done())
+        self.assertTrue(cancelled.done())
 
     def test_stream_shutdown_hung_task_prevents_cancellation(self):
         fut1 = self.loop.create_future()
         fut2 = self.loop.create_future()
+        cancelled = self.loop.create_future()
         do_handle_client = True
 
         async def handle_client(stream):
+            data = await stream.readexactly(4)
+            self.assertEqual(b'data', data)
+            fut1.set_result(None)
+            await fut2
             while do_handle_client:
                 with contextlib.suppress(asyncio.CancelledError):
                     await asyncio.sleep(0.01)
+            cancelled.set_result(None)
 
         async def client(srv):
             addr = srv.sockets[0].getsockname()
             stream = await asyncio.connect(*addr)
-            fut1.set_result(None)
+            await stream.write(b'data')
             await fut2
             self.assertEqual(b'', await stream.readline())
             await stream.close()
@@ -1598,11 +1622,12 @@ os.close(fd)
                 await server.start_serving()
                 task = asyncio.create_task(client(server))
                 await fut1
+                fut2.set_result(None)
                 await server.close()
                 nonlocal do_handle_client
                 do_handle_client = False
-                fut2.set_result(None)
                 await task
+                await cancelled
 
         messages = []
         self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
@@ -1612,6 +1637,7 @@ os.close(fd)
                          "<Task pending .+ ignored cancellation request")
         self.assertTrue(fut1.done())
         self.assertTrue(fut2.done())
+        self.assertTrue(cancelled.done())
 
     def test_sendfile(self):
         messages = []