]> granicus.if.org Git - python/commitdiff
asyncio: SSL transports now clear their reference to the waiter
authorVictor Stinner <victor.stinner@gmail.com>
Wed, 28 Jan 2015 23:36:35 +0000 (00:36 +0100)
committerVictor Stinner <victor.stinner@gmail.com>
Wed, 28 Jan 2015 23:36:35 +0000 (00:36 +0100)
* Rephrase also the comment explaining why the waiter is not awaken immediatly.
* SSLProtocol.eof_received() doesn't instanciate ConnectionResetError exception
  directly, it will be done by Future.set_exception(). The exception is not
  used if the waiter was cancelled or if there is no waiter.

Lib/asyncio/proactor_events.py
Lib/asyncio/selector_events.py
Lib/asyncio/sslproto.py
Lib/asyncio/unix_events.py

index ed170622144e2c0735187680cfc9f4022d083a8b..0f533a5e59061c0e8cc74345f509e2b785a76d17 100644 (file)
@@ -38,7 +38,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
             self._server._attach()
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            # wait until protocol.connection_made() has been called
+            # only wake up the waiter when connection_made() has been called
             self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def __repr__(self):
index 24f8461509afed065b9cc1d13154c894b20475a6..42d88f5de7949c33c6c5f26ac5ed6a3bc846bd5f 100644 (file)
@@ -581,7 +581,7 @@ class _SelectorSocketTransport(_SelectorTransport):
         self._loop.add_reader(self._sock_fd, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            # wait until protocol.connection_made() has been called
+            # only wake up the waiter when connection_made() has been called
             self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def pause_reading(self):
@@ -732,6 +732,16 @@ class _SelectorSslTransport(_SelectorTransport):
             start_time = None
         self._on_handshake(start_time)
 
+    def _wakeup_waiter(self, exc=None):
+        if self._waiter is None:
+            return
+        if not self._waiter.cancelled():
+            if exc is not None:
+                self._waiter.set_exception(exc)
+            else:
+                self._waiter.set_result(None)
+        self._waiter = None
+
     def _on_handshake(self, start_time):
         try:
             self._sock.do_handshake()
@@ -750,8 +760,7 @@ class _SelectorSslTransport(_SelectorTransport):
             self._loop.remove_reader(self._sock_fd)
             self._loop.remove_writer(self._sock_fd)
             self._sock.close()
-            if self._waiter is not None and not self._waiter.cancelled():
-                self._waiter.set_exception(exc)
+            self._wakeup_waiter(exc)
             if isinstance(exc, Exception):
                 return
             else:
@@ -774,9 +783,7 @@ class _SelectorSslTransport(_SelectorTransport):
                                        "on matching the hostname",
                                        self, exc_info=True)
                     self._sock.close()
-                    if (self._waiter is not None
-                    and not self._waiter.cancelled()):
-                        self._waiter.set_exception(exc)
+                    self._wakeup_waiter(exc)
                     return
 
         # Add extra info that becomes available after handshake.
@@ -789,10 +796,8 @@ class _SelectorSslTransport(_SelectorTransport):
         self._write_wants_read = False
         self._loop.add_reader(self._sock_fd, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
-        if self._waiter is not None:
-            # wait until protocol.connection_made() has been called
-            self._loop.call_soon(self._waiter._set_result_unless_cancelled,
-                                 None)
+        # only wake up the waiter when connection_made() has been called
+        self._loop.call_soon(self._wakeup_waiter)
 
         if self._loop.get_debug():
             dt = self._loop.time() - start_time
@@ -924,7 +929,7 @@ class _SelectorDatagramTransport(_SelectorTransport):
         self._loop.add_reader(self._sock_fd, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            # wait until protocol.connection_made() has been called
+            # only wake up the waiter when connection_made() has been called
             self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def get_write_buffer_size(self):
index 26937c84265362a6bb397ae1bb56700aa4dcf11a..fc809b9831de5abeceeb56ebb36083a37b16a5ce 100644 (file)
@@ -418,6 +418,16 @@ class SSLProtocol(protocols.Protocol):
         self._in_shutdown = False
         self._transport = None
 
+    def _wakeup_waiter(self, exc=None):
+        if self._waiter is None:
+            return
+        if not self._waiter.cancelled():
+            if exc is not None:
+                self._waiter.set_exception(exc)
+            else:
+                self._waiter.set_result(None)
+        self._waiter = None
+
     def connection_made(self, transport):
         """Called when the low-level connection is made.
 
@@ -490,8 +500,7 @@ class SSLProtocol(protocols.Protocol):
             if self._loop.get_debug():
                 logger.debug("%r received EOF", self)
 
-            if self._waiter is not None and not self._waiter.done():
-                self._waiter.set_exception(ConnectionResetError())
+            self._wakeup_waiter(ConnectionResetError)
 
             if not self._in_handshake:
                 keep_open = self._app_protocol.eof_received()
@@ -556,8 +565,7 @@ class SSLProtocol(protocols.Protocol):
                                    self, exc_info=True)
             self._transport.close()
             if isinstance(exc, Exception):
-                if self._waiter is not None and not self._waiter.cancelled():
-                    self._waiter.set_exception(exc)
+                self._wakeup_waiter(exc)
                 return
             else:
                 raise
@@ -572,9 +580,7 @@ class SSLProtocol(protocols.Protocol):
                            compression=sslobj.compression(),
                            )
         self._app_protocol.connection_made(self._app_transport)
-        if self._waiter is not None:
-            # wait until protocol.connection_made() has been called
-            self._waiter._set_result_unless_cancelled(None)
+        self._wakeup_waiter()
         self._session_established = True
         # In case transport.write() was already called. Don't call
         # immediatly _process_write_backlog(), but schedule it:
index 97f9addde88a0ada13eba3e26907918bbe2bcb69..67973f14f3ff5c5f0362bb0a2bc7c05db6941dd5 100644 (file)
@@ -301,7 +301,7 @@ class _UnixReadPipeTransport(transports.ReadTransport):
         self._loop.add_reader(self._fileno, self._read_ready)
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            # wait until protocol.connection_made() has been called
+            # only wake up the waiter when connection_made() has been called
             self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def __repr__(self):
@@ -409,7 +409,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
 
         self._loop.call_soon(self._protocol.connection_made, self)
         if waiter is not None:
-            # wait until protocol.connection_made() has been called
+            # only wake up the waiter when connection_made() has been called
             self._loop.call_soon(waiter._set_result_unless_cancelled, None)
 
     def __repr__(self):