From a05a6ef1ca781e2f98fb4332284aca649f24f75d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sun, 11 Sep 2016 21:11:02 -0400 Subject: [PATCH] asyncio: Add set_protocol / get_protocol methods to Transports --- Lib/asyncio/base_subprocess.py | 6 ++++++ Lib/asyncio/proactor_events.py | 6 ++++++ Lib/asyncio/selector_events.py | 6 ++++++ Lib/asyncio/sslproto.py | 6 ++++++ Lib/asyncio/transports.py | 8 ++++++++ Lib/asyncio/unix_events.py | 12 ++++++++++++ Lib/test/test_asyncio/test_sslproto.py | 1 + 7 files changed, 45 insertions(+) diff --git a/Lib/asyncio/base_subprocess.py b/Lib/asyncio/base_subprocess.py index 8fc253c18e..bcc481d20e 100644 --- a/Lib/asyncio/base_subprocess.py +++ b/Lib/asyncio/base_subprocess.py @@ -87,6 +87,12 @@ class BaseSubprocessTransport(transports.SubprocessTransport): def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): raise NotImplementedError + def set_protocol(self, protocol): + self._protocol = protocol + + def get_protocol(self): + return self._protocol + def is_closing(self): return self._closed diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 3ac314c0cc..97ab487f97 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -66,6 +66,12 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, def _set_extra(self, sock): self._extra['pipe'] = sock + def set_protocol(self, protocol): + self._protocol = protocol + + def get_protocol(self): + return self._protocol + def is_closing(self): return self._closing diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index ed2b4d756f..c57f509a12 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -560,6 +560,12 @@ class _SelectorTransport(transports._FlowControlMixin, def abort(self): self._force_close(None) + def set_protocol(self, protocol): + self._protocol = protocol + + def get_protocol(self): + return self._protocol + def is_closing(self): return self._closing diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 33d5de2db0..afe85a1438 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -305,6 +305,12 @@ class _SSLProtocolTransport(transports._FlowControlMixin, """Get optional transport information.""" return self._ssl_protocol._get_extra_info(name, default) + def set_protocol(self, protocol): + self._app_protocol = protocol + + def get_protocol(self): + return self._app_protocol + def is_closing(self): return self._closed diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py index 9a6d9197d9..0db0875715 100644 --- a/Lib/asyncio/transports.py +++ b/Lib/asyncio/transports.py @@ -33,6 +33,14 @@ class BaseTransport: """ raise NotImplementedError + def set_protocol(self, protocol): + """Set a new protocol.""" + raise NotImplementedError + + def get_protocol(self): + """Return the current protocol.""" + raise NotImplementedError + class ReadTransport(BaseTransport): """Interface for read-only transports.""" diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 18519fc120..f7f9eb2a1d 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -374,6 +374,12 @@ class _UnixReadPipeTransport(transports.ReadTransport): def resume_reading(self): self._loop.add_reader(self._fileno, self._read_ready) + def set_protocol(self, protocol): + self._protocol = protocol + + def get_protocol(self): + return self._protocol + def is_closing(self): return self._closing @@ -570,6 +576,12 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, self._loop.remove_reader(self._fileno) self._loop.call_soon(self._call_connection_lost, None) + def set_protocol(self, protocol): + self._protocol = protocol + + def get_protocol(self): + return self._protocol + def is_closing(self): return self._closing diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index 8d5233565e..7dfa6c2c63 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -25,6 +25,7 @@ class SslProtoHandshakeTests(test_utils.TestCase): sslcontext = test_utils.dummy_ssl_context() app_proto = asyncio.Protocol() proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter) + self.assertIs(proto._app_transport.get_protocol(), app_proto) self.addCleanup(proto._app_transport.close) return proto -- 2.40.0