]> granicus.if.org Git - python/commitdiff
bpo-33540, socketserver: Add _block_on_close for tests (GH-7317)
authorVictor Stinner <vstinner@redhat.com>
Fri, 1 Jun 2018 14:24:43 +0000 (16:24 +0200)
committerGitHub <noreply@github.com>
Fri, 1 Jun 2018 14:24:43 +0000 (16:24 +0200)
* Add a private _block_on_close attribute to ForkingMixIn and
  ThreadingMixIn classes of socketserver.
* Use _block_on_close=True in test_socketserver and test_logging

Lib/socketserver.py
Lib/test/test_logging.py
Lib/test/test_socketserver.py

index 41a37667721f0334fdae4ef6d47874ca0ccba891..c4d544b372da632708cffd082872d5c9850dca75 100644 (file)
@@ -547,8 +547,10 @@ if hasattr(os, "fork"):
         timeout = 300
         active_children = None
         max_children = 40
+        # If true, server_close() waits until all child processes complete.
+        _block_on_close = False
 
-        def collect_children(self):
+        def collect_children(self, *, blocking=False):
             """Internal routine to wait for children that have exited."""
             if self.active_children is None:
                 return
@@ -572,7 +574,8 @@ if hasattr(os, "fork"):
             # Now reap all defunct children.
             for pid in self.active_children.copy():
                 try:
-                    pid, _ = os.waitpid(pid, os.WNOHANG)
+                    flags = 0 if blocking else os.WNOHANG
+                    pid, _ = os.waitpid(pid, flags)
                     # if the child hasn't exited yet, pid will be 0 and ignored by
                     # discard() below
                     self.active_children.discard(pid)
@@ -621,6 +624,10 @@ if hasattr(os, "fork"):
                     finally:
                         os._exit(status)
 
+        def server_close(self):
+            super().server_close()
+            self.collect_children(blocking=self._block_on_close)
+
 
 class ThreadingMixIn:
     """Mix-in class to handle each request in a new thread."""
@@ -628,6 +635,11 @@ class ThreadingMixIn:
     # Decides how threads will act upon termination of the
     # main process
     daemon_threads = False
+    # If true, server_close() waits until all non-daemonic threads terminate.
+    _block_on_close = False
+    # For non-daemonic threads, list of threading.Threading objects
+    # used by server_close() to wait for all threads completion.
+    _threads = None
 
     def process_request_thread(self, request, client_address):
         """Same as in BaseServer but as a thread.
@@ -647,8 +659,21 @@ class ThreadingMixIn:
         t = threading.Thread(target = self.process_request_thread,
                              args = (request, client_address))
         t.daemon = self.daemon_threads
+        if not t.daemon and self._block_on_close:
+            if self._threads is None:
+                self._threads = []
+            self._threads.append(t)
         t.start()
 
+    def server_close(self):
+        super().server_close()
+        if self._block_on_close:
+            threads = self._threads
+            self._threads = None
+            if threads:
+                for thread in threads:
+                    thread.join()
+
 
 if hasattr(os, "fork"):
     class ForkingUDPServer(ForkingMixIn, UDPServer): pass
index fc067138c3b7fcbe7fcbb90837936bb9ab998611..d341ef8779bda33d48db07daa2051837069d9197 100644 (file)
@@ -883,6 +883,7 @@ if threading:
         """
 
         allow_reuse_address = True
+        _block_on_close = True
 
         def __init__(self, addr, handler, poll_interval=0.5,
                      bind_and_activate=True):
@@ -915,6 +916,8 @@ if threading:
                             before calling :meth:`start`, so that the server will
                             set up the socket and listen on it.
         """
+        _block_on_close = True
+
         def __init__(self, addr, handler, poll_interval=0.5,
                      bind_and_activate=True):
             class DelegatingUDPRequestHandler(DatagramRequestHandler):
@@ -1474,11 +1477,11 @@ class SocketHandlerTest(BaseTest):
     def tearDown(self):
         """Shutdown the TCP server."""
         try:
-            if self.server:
-                self.server.stop(2.0)
             if self.sock_hdlr:
                 self.root_logger.removeHandler(self.sock_hdlr)
                 self.sock_hdlr.close()
+            if self.server:
+                self.server.stop(2.0)
         finally:
             BaseTest.tearDown(self)
 
index 43621337e03dee4260ec695c8bc47d8f73171e22..8177c4178773da91008a9e995e810a963a146b8f 100644 (file)
@@ -48,11 +48,11 @@ def receive(sock, n, timeout=20):
 if HAVE_UNIX_SOCKETS and HAVE_FORKING:
     class ForkingUnixStreamServer(socketserver.ForkingMixIn,
                                   socketserver.UnixStreamServer):
-        pass
+        _block_on_close = True
 
     class ForkingUnixDatagramServer(socketserver.ForkingMixIn,
                                     socketserver.UnixDatagramServer):
-        pass
+        _block_on_close = True
 
 
 @contextlib.contextmanager
@@ -62,24 +62,14 @@ def simple_subprocess(testcase):
     if pid == 0:
         # Don't raise an exception; it would be caught by the test harness.
         os._exit(72)
-    yield None
-    pid2, status = os.waitpid(pid, 0)
-    testcase.assertEqual(pid2, pid)
-    testcase.assertEqual(72 << 8, status)
-
-
-def close_server(server):
-    server.server_close()
-
-    if hasattr(server, 'active_children'):
-        # ForkingMixIn: Manually reap all child processes, since server_close()
-        # calls waitpid() in non-blocking mode using the WNOHANG flag.
-        for pid in server.active_children.copy():
-            try:
-                os.waitpid(pid, 0)
-            except ChildProcessError:
-                pass
-        server.active_children.clear()
+    try:
+        yield None
+    except:
+        raise
+    finally:
+        pid2, status = os.waitpid(pid, 0)
+        testcase.assertEqual(pid2, pid)
+        testcase.assertEqual(72 << 8, status)
 
 
 @unittest.skipUnless(threading, 'Threading required for this test.')
@@ -115,6 +105,8 @@ class SocketServerTest(unittest.TestCase):
 
     def make_server(self, addr, svrcls, hdlrbase):
         class MyServer(svrcls):
+            _block_on_close = True
+
             def handle_error(self, request, client_address):
                 self.close_request(request)
                 raise
@@ -156,8 +148,12 @@ class SocketServerTest(unittest.TestCase):
         if verbose: print("waiting for server")
         server.shutdown()
         t.join()
-        close_server(server)
+        server.server_close()
         self.assertEqual(-1, server.socket.fileno())
+        if HAVE_FORKING and isinstance(server, socketserver.ForkingMixIn):
+            # bpo-31151: Check that ForkingMixIn.server_close() waits until
+            # all children completed
+            self.assertFalse(server.active_children)
         if verbose: print("done")
 
     def stream_examine(self, proto, addr):
@@ -280,7 +276,7 @@ class SocketServerTest(unittest.TestCase):
             s.shutdown()
         for t, s in threads:
             t.join()
-            close_server(s)
+            s.server_close()
 
     def test_tcpserver_bind_leak(self):
         # Issue #22435: the server socket wouldn't be closed if bind()/listen()
@@ -344,6 +340,8 @@ class ErrorHandlerTest(unittest.TestCase):
 
 
 class BaseErrorTestServer(socketserver.TCPServer):
+    _block_on_close = True
+
     def __init__(self, exception):
         self.exception = exception
         super().__init__((HOST, 0), BadHandler)
@@ -352,7 +350,7 @@ class BaseErrorTestServer(socketserver.TCPServer):
         try:
             self.handle_request()
         finally:
-            close_server(self)
+            self.server_close()
         self.wait_done()
 
     def handle_error(self, request, client_address):
@@ -386,7 +384,7 @@ class ThreadingErrorTestServer(socketserver.ThreadingMixIn,
 
 if HAVE_FORKING:
     class ForkingErrorTestServer(socketserver.ForkingMixIn, BaseErrorTestServer):
-        pass
+        _block_on_close = True
 
 
 class SocketWriterTest(unittest.TestCase):
@@ -398,7 +396,7 @@ class SocketWriterTest(unittest.TestCase):
                 self.server.request_fileno = self.request.fileno()
 
         server = socketserver.TCPServer((HOST, 0), Handler)
-        self.addCleanup(close_server, server)
+        self.addCleanup(server.server_close)
         s = socket.socket(
             server.address_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
         with s:
@@ -422,7 +420,7 @@ class SocketWriterTest(unittest.TestCase):
                 self.server.sent2 = self.wfile.write(big_chunk)
 
         server = socketserver.TCPServer((HOST, 0), Handler)
-        self.addCleanup(close_server, server)
+        self.addCleanup(server.server_close)
         interrupted = threading.Event()
 
         def signal_handler(signum, frame):
@@ -498,7 +496,7 @@ class MiscTestCase(unittest.TestCase):
         s.close()
         server.handle_request()
         self.assertEqual(server.shutdown_called, 1)
-        close_server(server)
+        server.server_close()
 
 
 if __name__ == "__main__":