return self._ssl_protocol._get_extra_info(name, default)
def set_protocol(self, protocol):
- self._ssl_protocol._app_protocol = protocol
+ self._ssl_protocol._set_app_protocol(protocol)
def get_protocol(self):
return self._ssl_protocol._app_protocol
self._waiter = waiter
self._loop = loop
- self._app_protocol = app_protocol
- self._app_protocol_is_buffer = \
- isinstance(app_protocol, protocols.BufferedProtocol)
+ self._set_app_protocol(app_protocol)
self._app_transport = _SSLProtocolTransport(self._loop, self)
# _SSLPipe instance (None until the connection is made)
self._sslpipe = None
self._call_connection_made = call_connection_made
self._ssl_handshake_timeout = ssl_handshake_timeout
+ def _set_app_protocol(self, app_protocol):
+ self._app_protocol = app_protocol
+ self._app_protocol_is_buffer = \
+ isinstance(app_protocol, protocols.BufferedProtocol)
+
def _wakeup_waiter(self, exc=None):
if self._waiter is None:
return
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
+ client_con_made_calls = 0
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
+ sock.sendall(b'2')
+ data = sock.recv_all(len(HELLO_MSG))
+ self.assertEqual(len(data), len(HELLO_MSG))
+
sock.shutdown(socket.SHUT_RDWR)
sock.close()
- class ClientProto(asyncio.BufferedProtocol):
- def __init__(self, on_data, on_eof):
+ class ClientProtoFirst(asyncio.BufferedProtocol):
+ def __init__(self, on_data):
self.on_data = on_data
- self.on_eof = on_eof
- self.con_made_cnt = 0
self.buf = bytearray(1)
- def connection_made(proto, tr):
- proto.con_made_cnt += 1
- # Ensure connection_made gets called only once.
- self.assertEqual(proto.con_made_cnt, 1)
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
def get_buffer(self, sizehint):
return self.buf
assert nsize == 1
self.on_data.set_result(bytes(self.buf[:nsize]))
+ class ClientProtoSecond(asyncio.Protocol):
+ def __init__(self, on_data, on_eof):
+ self.on_data = on_data
+ self.on_eof = on_eof
+ self.con_made_cnt = 0
+
+ def connection_made(self, tr):
+ nonlocal client_con_made_calls
+ client_con_made_calls += 1
+
+ def data_received(self, data):
+ self.on_data.set_result(data)
+
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5, loop=self.loop)
- on_data = self.loop.create_future()
+ on_data1 = self.loop.create_future()
+ on_data2 = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
- lambda: ClientProto(on_data, on_eof), *addr)
+ lambda: ClientProtoFirst(on_data1), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
- self.assertEqual(await on_data, b'O')
+ self.assertEqual(await on_data1, b'O')
+ new_tr.write(HELLO_MSG)
+
+ new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
+ self.assertEqual(await on_data2, b'2')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
+ # connection_made() should be called only once -- when
+ # we establish connection for the first time. Start TLS
+ # doesn't call connection_made() on application protocols.
+ self.assertEqual(client_con_made_calls, 1)
+
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),