From: Vinay Sajip Date: Tue, 24 Apr 2012 22:33:33 +0000 (+0100) Subject: Issue #14632: Updated WatchedFileHandler to deal with race condition. Thanks to John... X-Git-Tag: v3.3.0a3~84 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=a5798ded26e3ad67946b7d8150a8e13b0075fbf8;p=python Issue #14632: Updated WatchedFileHandler to deal with race condition. Thanks to John Mulligan for the problem report and patch. --- a5798ded26e3ad67946b7d8150a8e13b0075fbf8 diff --cc Lib/logging/handlers.py index 9c63797431,22f8c3446c..7a97c80713 --- a/Lib/logging/handlers.py +++ b/Lib/logging/handlers.py @@@ -23,8 -23,7 +23,8 @@@ Copyright (C) 2001-2012 Vinay Sajip. Al To use, simply 'import logging.handlers' and log away! """ - import logging, socket, os, pickle, struct, time, re + import errno, logging, socket, os, pickle, struct, time, re +from codecs import BOM_UTF8 from stat import ST_DEV, ST_INO, ST_MTIME import queue try: @@@ -417,13 -381,15 +417,15 @@@ class WatchedFileHandler(logging.FileHa This handler is based on a suggestion and patch by Chad J. Schroeder. """ - def __init__(self, filename, mode='a', encoding=None, delay=0): + def __init__(self, filename, mode='a', encoding=None, delay=False): logging.FileHandler.__init__(self, filename, mode, encoding, delay) - if not os.path.exists(self.baseFilename): - self.dev, self.ino = -1, -1 - else: - stat = os.stat(self.baseFilename) - self.dev, self.ino = stat[ST_DEV], stat[ST_INO] + self.dev, self.ino = -1, -1 + self._statstream() + + def _statstream(self): + if self.stream: + sres = os.fstat(self.stream.fileno()) + self.dev, self.ino = sres[ST_DEV], sres[ST_INO] def emit(self, record): """ diff --cc Lib/test/test_logging.py index 75b3e4d814,ca4d9309f2..ee1c2113ca --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@@ -518,439 -498,6 +519,474 @@@ class CustomLevelsAndFiltersTest(BaseTe handler.removeFilter(garr) +class HandlerTest(BaseTest): + def test_name(self): + h = logging.Handler() + h.name = 'generic' + self.assertEqual(h.name, 'generic') + h.name = 'anothergeneric' + self.assertEqual(h.name, 'anothergeneric') + self.assertRaises(NotImplementedError, h.emit, None) + + def test_builtin_handlers(self): + # We can't actually *use* too many handlers in the tests, + # but we can try instantiating them with various options + if sys.platform in ('linux', 'darwin'): + for existing in (True, False): + fd, fn = tempfile.mkstemp() + os.close(fd) + if not existing: + os.unlink(fn) + h = logging.handlers.WatchedFileHandler(fn, delay=True) + if existing: + dev, ino = h.dev, h.ino - self.assertNotEqual(dev, -1) - self.assertNotEqual(ino, -1) ++ self.assertEqual(dev, -1) ++ self.assertEqual(ino, -1) + r = logging.makeLogRecord({'msg': 'Test'}) + h.handle(r) + # Now remove the file. + os.unlink(fn) + self.assertFalse(os.path.exists(fn)) + # The next call should recreate the file. + h.handle(r) + self.assertTrue(os.path.exists(fn)) + else: + self.assertEqual(h.dev, -1) + self.assertEqual(h.ino, -1) + h.close() + if existing: + os.unlink(fn) + if sys.platform == 'darwin': + sockname = '/var/run/syslog' + else: + sockname = '/dev/log' + try: + h = logging.handlers.SysLogHandler(sockname) + self.assertEqual(h.facility, h.LOG_USER) + self.assertTrue(h.unixsocket) + h.close() + except socket.error: # syslogd might not be available + pass + for method in ('GET', 'POST', 'PUT'): + if method == 'PUT': + self.assertRaises(ValueError, logging.handlers.HTTPHandler, + 'localhost', '/log', method) + else: + h = logging.handlers.HTTPHandler('localhost', '/log', method) + h.close() + h = logging.handlers.BufferingHandler(0) + r = logging.makeLogRecord({}) + self.assertTrue(h.shouldFlush(r)) + h.close() + h = logging.handlers.BufferingHandler(1) + self.assertFalse(h.shouldFlush(r)) + h.close() + ++ @unittest.skipUnless(threading, 'Threading required for this test.') ++ def test_race(self): ++ # Issue #14632 refers. ++ def remove_loop(fname, tries): ++ for _ in range(tries): ++ try: ++ os.unlink(fname) ++ except OSError: ++ pass ++ time.sleep(0.004 * random.randint(0, 4)) ++ ++ def cleanup(remover, fn, handler): ++ handler.close() ++ remover.join() ++ if os.path.exists(fn): ++ os.unlink(fn) ++ ++ fd, fn = tempfile.mkstemp('.log', 'test_logging-3-') ++ os.close(fd) ++ del_count = 1000 ++ log_count = 1000 ++ remover = threading.Thread(target=remove_loop, args=(fn, del_count)) ++ remover.daemon = True ++ remover.start() ++ for delay in (False, True): ++ h = logging.handlers.WatchedFileHandler(fn, delay=delay) ++ self.addCleanup(cleanup, remover, fn, h) ++ f = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s') ++ h.setFormatter(f) ++ for _ in range(log_count): ++ time.sleep(0.005) ++ r = logging.makeLogRecord({'msg': 'testing' }) ++ h.handle(r) ++ ++ +class BadStream(object): + def write(self, data): + raise RuntimeError('deliberate mistake') + +class TestStreamHandler(logging.StreamHandler): + def handleError(self, record): + self.error_record = record + +class StreamHandlerTest(BaseTest): + def test_error_handling(self): + h = TestStreamHandler(BadStream()) + r = logging.makeLogRecord({}) + old_raise = logging.raiseExceptions + old_stderr = sys.stderr + try: + h.handle(r) + self.assertIs(h.error_record, r) + h = logging.StreamHandler(BadStream()) + sys.stderr = sio = io.StringIO() + h.handle(r) + self.assertIn('\nRuntimeError: deliberate mistake\n', + sio.getvalue()) + logging.raiseExceptions = False + sys.stderr = sio = io.StringIO() + h.handle(r) + self.assertEqual('', sio.getvalue()) + finally: + logging.raiseExceptions = old_raise + sys.stderr = old_stderr + +# -- The following section could be moved into a server_helper.py module +# -- if it proves to be of wider utility than just test_logging + +if threading: + class TestSMTPChannel(smtpd.SMTPChannel): + """ + This derived class has had to be created because smtpd does not + support use of custom channel maps, although they are allowed by + asyncore's design. Issue #11959 has been raised to address this, + and if resolved satisfactorily, some of this code can be removed. + """ + def __init__(self, server, conn, addr, sockmap): + asynchat.async_chat.__init__(self, conn, sockmap) + self.smtp_server = server + self.conn = conn + self.addr = addr + self.received_lines = [] + self.smtp_state = self.COMMAND + self.seen_greeting = '' + self.mailfrom = None + self.rcpttos = [] + self.received_data = '' + self.fqdn = socket.getfqdn() + self.num_bytes = 0 + try: + self.peer = conn.getpeername() + except socket.error as err: + # a race condition may occur if the other end is closing + # before we can get the peername + self.close() + if err.args[0] != errno.ENOTCONN: + raise + return + self.push('220 %s %s' % (self.fqdn, smtpd.__version__)) + self.set_terminator(b'\r\n') + + + class TestSMTPServer(smtpd.SMTPServer): + """ + This class implements a test SMTP server. + + :param addr: A (host, port) tuple which the server listens on. + You can specify a port value of zero: the server's + *port* attribute will hold the actual port number + used, which can be used in client connections. + :param handler: A callable which will be called to process + incoming messages. The handler will be passed + the client address tuple, who the message is from, + a list of recipients and the message data. + :param poll_interval: The interval, in seconds, used in the underlying + :func:`select` or :func:`poll` call by + :func:`asyncore.loop`. + :param sockmap: A dictionary which will be used to hold + :class:`asyncore.dispatcher` instances used by + :func:`asyncore.loop`. This avoids changing the + :mod:`asyncore` module's global state. + """ + channel_class = TestSMTPChannel + + def __init__(self, addr, handler, poll_interval, sockmap): + self._localaddr = addr + self._remoteaddr = None + self.sockmap = sockmap + asyncore.dispatcher.__init__(self, map=sockmap) + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(0) + self.set_socket(sock, map=sockmap) + # try to re-use a server port if possible + self.set_reuse_addr() + self.bind(addr) + self.port = sock.getsockname()[1] + self.listen(5) + except: + self.close() + raise + self._handler = handler + self._thread = None + self.poll_interval = poll_interval + + def handle_accepted(self, conn, addr): + """ + Redefined only because the base class does not pass in a + map, forcing use of a global in :mod:`asyncore`. + """ + channel = self.channel_class(self, conn, addr, self.sockmap) + + def process_message(self, peer, mailfrom, rcpttos, data): + """ + Delegates to the handler passed in to the server's constructor. + + Typically, this will be a test case method. + :param peer: The client (host, port) tuple. + :param mailfrom: The address of the sender. + :param rcpttos: The addresses of the recipients. + :param data: The message. + """ + self._handler(peer, mailfrom, rcpttos, data) + + def start(self): + """ + Start the server running on a separate daemon thread. + """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.setDaemon(True) + t.start() + + def serve_forever(self, poll_interval): + """ + Run the :mod:`asyncore` loop until normal termination + conditions arise. + :param poll_interval: The interval, in seconds, used in the underlying + :func:`select` or :func:`poll` call by + :func:`asyncore.loop`. + """ + try: + asyncore.loop(poll_interval, map=self.sockmap) + except select.error: + # On FreeBSD 8, closing the server repeatably + # raises this error. We swallow it if the + # server has been closed. + if self.connected or self.accepting: + raise + + def stop(self, timeout=None): + """ + Stop the thread by closing the server instance. + Wait for the server thread to terminate. + + :param timeout: How long to wait for the server thread + to terminate. + """ + self.close() + self._thread.join(timeout) + self._thread = None + + class ControlMixin(object): + """ + This mixin is used to start a server on a separate thread, and + shut it down programmatically. Request handling is simplified - instead + of needing to derive a suitable RequestHandler subclass, you just + provide a callable which will be passed each received request to be + processed. + + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. This handler is called on the + server thread, effectively meaning that requests are + processed serially. While not quite Web scale ;-), + this should be fine for testing applications. + :param poll_interval: The polling interval in seconds. + """ + def __init__(self, handler, poll_interval): + self._thread = None + self.poll_interval = poll_interval + self._handler = handler + self.ready = threading.Event() + + def start(self): + """ + Create a daemon thread to run the server, and start it. + """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.setDaemon(True) + t.start() + + def serve_forever(self, poll_interval): + """ + Run the server. Set the ready flag before entering the + service loop. + """ + self.ready.set() + super(ControlMixin, self).serve_forever(poll_interval) + + def stop(self, timeout=None): + """ + Tell the server thread to stop, and wait for it to do so. + + :param timeout: How long to wait for the server thread + to terminate. + """ + self.shutdown() + if self._thread is not None: + self._thread.join(timeout) + self._thread = None + self.server_close() + self.ready.clear() + + class TestHTTPServer(ControlMixin, HTTPServer): + """ + An HTTP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval in seconds. + :param log: Pass ``True`` to enable log messages. + """ + def __init__(self, addr, handler, poll_interval=0.5, + log=False, sslctx=None): + class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler): + def __getattr__(self, name, default=None): + if name.startswith('do_'): + return self.process_request + raise AttributeError(name) + + def process_request(self): + self.server._handler(self) + + def log_message(self, format, *args): + if log: + super(DelegatingHTTPRequestHandler, + self).log_message(format, *args) + HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler) + ControlMixin.__init__(self, handler, poll_interval) + self.sslctx = sslctx + + def get_request(self): + try: + sock, addr = self.socket.accept() + if self.sslctx: + sock = self.sslctx.wrap_socket(sock, server_side=True) + except socket.error as e: + # socket errors are silenced by the caller, print them here + sys.stderr.write("Got an error:\n%s\n" % e) + raise + return sock, addr + + class TestTCPServer(ControlMixin, ThreadingTCPServer): + """ + A TCP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a single + parameter - the request - in order to process the request. + :param poll_interval: The polling interval in seconds. + :bind_and_activate: If True (the default), binds the server and starts it + listening. If False, you need to call + :meth:`server_bind` and :meth:`server_activate` at + some later time before calling :meth:`start`, so that + the server will set up the socket and listen on it. + """ + + allow_reuse_address = True + + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingTCPRequestHandler(StreamRequestHandler): + + def handle(self): + self.server._handler(self) + ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) + + def server_bind(self): + super(TestTCPServer, self).server_bind() + self.port = self.socket.getsockname()[1] + + class TestUDPServer(ControlMixin, ThreadingUDPServer): + """ + A UDP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval for shutdown requests, + in seconds. + :bind_and_activate: If True (the default), binds the server and + starts it listening. If False, you need to + call :meth:`server_bind` and + :meth:`server_activate` at some later time + before calling :meth:`start`, so that the server will + set up the socket and listen on it. + """ + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingUDPRequestHandler(DatagramRequestHandler): + + def handle(self): + self.server._handler(self) + + def finish(self): + data = self.wfile.getvalue() + if data: + try: + super(DelegatingUDPRequestHandler, self).finish() + except socket.error: + if not self.server._closed: + raise + + ThreadingUDPServer.__init__(self, addr, + DelegatingUDPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) + self._closed = False + + def server_bind(self): + super(TestUDPServer, self).server_bind() + self.port = self.socket.getsockname()[1] + + def server_close(self): + super(TestUDPServer, self).server_close() + self._closed = True + +# - end of server_helper section + +@unittest.skipUnless(threading, 'Threading required for this test.') +class SMTPHandlerTest(BaseTest): + def test_basic(self): + sockmap = {} + server = TestSMTPServer(('localhost', 0), self.process_message, 0.001, + sockmap) + server.start() + addr = ('localhost', server.port) + h = logging.handlers.SMTPHandler(addr, 'me', 'you', 'Log', timeout=5.0) + self.assertEqual(h.toaddrs, ['you']) + self.messages = [] + r = logging.makeLogRecord({'msg': 'Hello'}) + self.handled = threading.Event() + h.handle(r) + self.handled.wait(5.0) # 14314: don't wait forever + server.stop() + self.assertTrue(self.handled.is_set()) + self.assertEqual(len(self.messages), 1) + peer, mailfrom, rcpttos, data = self.messages[0] + self.assertEqual(mailfrom, 'me') + self.assertEqual(rcpttos, ['you']) + self.assertIn('\nSubject: Log\n', data) + self.assertTrue(data.endswith('\n\nHello')) + h.close() + + def process_message(self, *args): + self.messages.append(args) + self.handled.set() + class MemoryHandlerTest(BaseTest): """Tests for the MemoryHandler."""