timeout = 300
active_children = None
max_children = 40
+ # If true, server_close() waits until all child processes complete.
+ _block_on_close = False
- def collect_children(self):
+ def collect_children(self, *, blocking=False):
"""Internal routine to wait for children that have exited."""
if self.active_children is None:
return
# Now reap all defunct children.
for pid in self.active_children.copy():
try:
- pid, _ = os.waitpid(pid, os.WNOHANG)
+ flags = 0 if blocking else os.WNOHANG
+ pid, _ = os.waitpid(pid, flags)
# if the child hasn't exited yet, pid will be 0 and ignored by
# discard() below
self.active_children.discard(pid)
finally:
os._exit(status)
+ def server_close(self):
+ super().server_close()
+ self.collect_children(blocking=self._block_on_close)
+
class ThreadingMixIn:
"""Mix-in class to handle each request in a new thread."""
# Decides how threads will act upon termination of the
# main process
daemon_threads = False
+ # If true, server_close() waits until all non-daemonic threads terminate.
+ _block_on_close = False
+ # For non-daemonic threads, list of threading.Threading objects
+ # used by server_close() to wait for all threads completion.
+ _threads = None
def process_request_thread(self, request, client_address):
"""Same as in BaseServer but as a thread.
t = threading.Thread(target = self.process_request_thread,
args = (request, client_address))
t.daemon = self.daemon_threads
+ if not t.daemon and self._block_on_close:
+ if self._threads is None:
+ self._threads = []
+ self._threads.append(t)
t.start()
+ def server_close(self):
+ super().server_close()
+ if self._block_on_close:
+ threads = self._threads
+ self._threads = None
+ if threads:
+ for thread in threads:
+ thread.join()
+
if hasattr(os, "fork"):
class ForkingUDPServer(ForkingMixIn, UDPServer): pass
if HAVE_UNIX_SOCKETS and HAVE_FORKING:
class ForkingUnixStreamServer(socketserver.ForkingMixIn,
socketserver.UnixStreamServer):
- pass
+ _block_on_close = True
class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
socketserver.UnixDatagramServer):
- pass
+ _block_on_close = True
@contextlib.contextmanager
if pid == 0:
# Don't raise an exception; it would be caught by the test harness.
os._exit(72)
- yield None
- pid2, status = os.waitpid(pid, 0)
- testcase.assertEqual(pid2, pid)
- testcase.assertEqual(72 << 8, status)
-
-
-def close_server(server):
- server.server_close()
-
- if hasattr(server, 'active_children'):
- # ForkingMixIn: Manually reap all child processes, since server_close()
- # calls waitpid() in non-blocking mode using the WNOHANG flag.
- for pid in server.active_children.copy():
- try:
- os.waitpid(pid, 0)
- except ChildProcessError:
- pass
- server.active_children.clear()
+ try:
+ yield None
+ except:
+ raise
+ finally:
+ pid2, status = os.waitpid(pid, 0)
+ testcase.assertEqual(pid2, pid)
+ testcase.assertEqual(72 << 8, status)
@unittest.skipUnless(threading, 'Threading required for this test.')
def make_server(self, addr, svrcls, hdlrbase):
class MyServer(svrcls):
+ _block_on_close = True
+
def handle_error(self, request, client_address):
self.close_request(request)
raise
if verbose: print("waiting for server")
server.shutdown()
t.join()
- close_server(server)
+ server.server_close()
self.assertEqual(-1, server.socket.fileno())
+ if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
+ # bpo-31151: Check that ForkingMixIn.server_close() waits until
+ # all children completed
+ self.assertFalse(server.active_children)
if verbose: print("done")
def stream_examine(self, proto, addr):
s.shutdown()
for t, s in threads:
t.join()
- close_server(s)
+ s.server_close()
def test_tcpserver_bind_leak(self):
# Issue #22435: the server socket wouldn't be closed if bind()/listen()
class BaseErrorTestServer(socketserver.TCPServer):
+ _block_on_close = True
+
def __init__(self, exception):
self.exception = exception
super().__init__((HOST, 0), BadHandler)
try:
self.handle_request()
finally:
- close_server(self)
+ self.server_close()
self.wait_done()
def handle_error(self, request, client_address):
if HAVE_FORKING:
class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
- pass
+ _block_on_close = True
class SocketWriterTest(unittest.TestCase):
self.server.request_fileno = self.request.fileno()
server = socketserver.TCPServer((HOST, 0), Handler)
- self.addCleanup(close_server, server)
+ self.addCleanup(server.server_close)
s = socket.socket(
server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
with s:
self.server.sent2 = self.wfile.write(big_chunk)
server = socketserver.TCPServer((HOST, 0), Handler)
- self.addCleanup(close_server, server)
+ self.addCleanup(server.server_close)
interrupted = threading.Event()
def signal_handler(signum, frame):
s.close()
server.handle_request()
self.assertEqual(server.shutdown_called, 1)
- close_server(server)
+ server.server_close()
if __name__ == "__main__":