# True if os.dup() can duplicate socket descriptors.
# (On Windows at least, os.dup only works on files)
-_can_dup_socket = hasattr(_socket, "dup")
+_can_dup_socket = hasattr(_socket.socket, "dup")
if _can_dup_socket:
def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
nfd = os.dup(fd)
return socket(family, type, proto, fileno=nfd)
+class SocketCloser:
+
+ """Helper to manage socket close() logic for makefile().
+
+ The OS socket should not be closed until the socket and all
+ of its makefile-children are closed. If the refcount is zero
+ when socket.close() is called, this is easy: Just close the
+ socket. If the refcount is non-zero when socket.close() is
+ called, then the real close should not occur until the last
+ makefile-child is closed.
+ """
+
+ def __init__(self, sock):
+ self._sock = sock
+ self._makefile_refs = 0
+ # Test whether the socket is open.
+ try:
+ sock.fileno()
+ self._socket_open = True
+ except error:
+ self._socket_open = False
+
+ def socket_close(self):
+ self._socket_open = False
+ self.close()
+
+ def makefile_open(self):
+ self._makefile_refs += 1
+
+ def makefile_close(self):
+ self._makefile_refs -= 1
+ self.close()
+
+ def close(self):
+ if not (self._socket_open or self._makefile_refs):
+ self._sock._real_close()
+
class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method."""
- __slots__ = ["__weakref__"]
+ __slots__ = ["__weakref__", "_closer"]
if not _can_dup_socket:
__slots__.append("_base")
+ def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
+ if fileno is None:
+ _socket.socket.__init__(self, family, type, proto)
+ else:
+ _socket.socket.__init__(self, family, type, proto, fileno)
+ # Defer creating a SocketCloser until makefile() is actually called.
+ self._closer = None
+
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
conn.close()
return wrapper, addr
- if not _can_dup_socket:
- def close(self):
- """Wrap close() to close the _base as well."""
- _socket.socket.close(self)
- base = getattr(self, "_base", None)
- if base is not None:
- base.close()
-
def makefile(self, mode="r", buffering=None, *,
encoding=None, newline=None):
"""Return an I/O stream connected to the socket.
rawmode += "r"
if writing:
rawmode += "w"
- raw = io.SocketIO(self, rawmode)
+ if self._closer is None:
+ self._closer = SocketCloser(self)
+ raw = SocketIO(self, rawmode, self._closer)
if buffering is None:
buffering = -1
if buffering < 0:
text.mode = mode
return text
+ def close(self):
+ if self._closer is None:
+ self._real_close()
+ else:
+ self._closer.socket_close()
+
+ # _real_close calls close on the _socket.socket base class.
+
+ if not _can_dup_socket:
+ def _real_close(self):
+ _socket.socket.close(self)
+ base = getattr(self, "_base", None)
+ if base is not None:
+ self._base = None
+ base.close()
+ else:
+ def _real_close(self):
+ _socket.socket.close(self)
+
+
+class SocketIO(io.RawIOBase):
+
+ """Raw I/O implementation for stream sockets.
+
+ This class supports the makefile() method on sockets. It provides
+ the raw I/O interface on top of a socket object.
+ """
+
+ # XXX More docs
+
+ def __init__(self, sock, mode, closer):
+ assert mode in ("r", "w", "rw")
+ io.RawIOBase.__init__(self)
+ self._sock = sock
+ self._mode = mode
+ self._closer = closer
+ closer.makefile_open()
+
+ def readinto(self, b):
+ return self._sock.recv_into(b)
+
+ def write(self, b):
+ return self._sock.send(b)
+
+ def readable(self):
+ return "r" in self._mode
+
+ def writable(self):
+ return "w" in self._mode
+
+ def fileno(self):
+ return self._sock.fileno()
+
+ def close(self):
+ if self.closed:
+ return
+ self._closer.makefile_close()
+ io.RawIOBase.close(self)
+
def getfqdn(name=''):
"""Get fully qualified domain name from name.
self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
class SocketConnectedTest(ThreadedTCPSocketTest):
+ """Socket tests for client-server connection.
+
+ self.cli_conn is a client socket connected to the server. The
+ setUp() method guarantees that it is connected to the server.
+ """
def __init__(self, methodName='runTest'):
ThreadedTCPSocketTest.__init__(self, methodName=methodName)
self.assertEqual(read, [sd])
self.assertEqual(sd.recv(1), b'')
+ # Calling close() many times should be safe.
+ conn.close()
+ conn.close()
+
def _testClose(self):
self.cli.connect((HOST, PORT))
time.sleep(1.0)
self.cli.send(MSG)
class FileObjectClassTestCase(SocketConnectedTest):
+ """Unit tests for the object returned by socket.makefile()
+
+ self.serv_file is the io object returned by makefile() on
+ the client connection. You can read from this file to
+ get output from the server.
+
+ self.cli_file is the io object returned by makefile() on the
+ server connection. You can write to this file to send output
+ to the client.
+ """
bufsize = -1 # Use default buffer size
self.cli_file.write(MSG)
self.cli_file.flush()
+ def testCloseAfterMakefile(self):
+ # The file returned by makefile should keep the socket open.
+ self.cli_conn.close()
+ # read until EOF
+ msg = self.serv_file.read()
+ self.assertEqual(msg, MSG)
+
+ def _testCloseAfterMakefile(self):
+ self.cli_file.write(MSG)
+ self.cli_file.flush()
+
+ def testMakefileAfterMakefileClose(self):
+ self.serv_file.close()
+ msg = self.cli_conn.recv(len(MSG))
+ self.assertEqual(msg, MSG)
+
+ def _testMakefileAfterMakefileClose(self):
+ self.cli_file.write(MSG)
+ self.cli_file.flush()
+
def testClosedAttr(self):
self.assert_(not self.serv_file.closed)