]> granicus.if.org Git - python/commitdiff
asyncio: Fix SSLProtocol.eof_received()
authorVictor Stinner <victor.stinner@gmail.com>
Wed, 28 Jan 2015 23:35:56 +0000 (00:35 +0100)
committerVictor Stinner <victor.stinner@gmail.com>
Wed, 28 Jan 2015 23:35:56 +0000 (00:35 +0100)
Wake-up the waiter if it is not done yet.

Lib/asyncio/sslproto.py
Lib/test/test_asyncio/test_sslproto.py

index f2b856c40cb9f6afedd3ef3888a78c3b03edea1b..26937c84265362a6bb397ae1bb56700aa4dcf11a 100644 (file)
@@ -489,6 +489,10 @@ class SSLProtocol(protocols.Protocol):
         try:
             if self._loop.get_debug():
                 logger.debug("%r received EOF", self)
+
+            if self._waiter is not None and not self._waiter.done():
+                self._waiter.set_exception(ConnectionResetError())
+
             if not self._in_handshake:
                 keep_open = self._app_protocol.eof_received()
                 if keep_open:
index b1a61c483d2ae250915b6162c20e8db2eedaadaa..148e30dffeb4dab851caecc510cb018f2f663341 100644 (file)
@@ -12,21 +12,36 @@ from asyncio import sslproto
 from asyncio import test_utils
 
 
+@unittest.skipIf(ssl is None, 'No ssl module')
 class SslProtoHandshakeTests(test_utils.TestCase):
 
     def setUp(self):
         self.loop = asyncio.new_event_loop()
         self.set_event_loop(self.loop)
 
-    @unittest.skipIf(ssl is None, 'No ssl module')
+    def ssl_protocol(self, waiter=None):
+        sslcontext = test_utils.dummy_ssl_context()
+        app_proto = asyncio.Protocol()
+        return sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter)
+
+    def connection_made(self, ssl_proto, do_handshake=None):
+        transport = mock.Mock()
+        sslpipe = mock.Mock()
+        sslpipe.shutdown.return_value = b''
+        if do_handshake:
+            sslpipe.do_handshake.side_effect = do_handshake
+        else:
+            def mock_handshake(callback):
+                return []
+            sslpipe.do_handshake.side_effect = mock_handshake
+        with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
+            ssl_proto.connection_made(transport)
+
     def test_cancel_handshake(self):
         # Python issue #23197: cancelling an handshake must not raise an
         # exception or log an error, even if the handshake failed
-        sslcontext = test_utils.dummy_ssl_context()
-        app_proto = asyncio.Protocol()
         waiter = asyncio.Future(loop=self.loop)
-        ssl_proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext,
-                                         waiter)
+        ssl_proto = self.ssl_protocol(waiter)
         handshake_fut = asyncio.Future(loop=self.loop)
 
         def do_handshake(callback):
@@ -36,12 +51,7 @@ class SslProtoHandshakeTests(test_utils.TestCase):
             return []
 
         waiter.cancel()
-        transport = mock.Mock()
-        sslpipe = mock.Mock()
-        sslpipe.shutdown.return_value = b''
-        sslpipe.do_handshake.side_effect = do_handshake
-        with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
-            ssl_proto.connection_made(transport)
+        self.connection_made(ssl_proto, do_handshake)
 
         with test_utils.disable_logger():
             self.loop.run_until_complete(handshake_fut)
@@ -49,6 +59,14 @@ class SslProtoHandshakeTests(test_utils.TestCase):
         # Close the transport
         ssl_proto._app_transport.close()
 
+    def test_eof_received_waiter(self):
+        waiter = asyncio.Future(loop=self.loop)
+        ssl_proto = self.ssl_protocol(waiter)
+        self.connection_made(ssl_proto)
+        ssl_proto.eof_received()
+        test_utils.run_briefly(self.loop)
+        self.assertIsInstance(waiter.exception(), ConnectionResetError)
+
 
 if __name__ == '__main__':
     unittest.main()