]> granicus.if.org Git - python/commitdiff
Issue #14632: Updated WatchedFileHandler to deal with race condition. Thanks to John...
authorVinay Sajip <vinay_sajip@yahoo.co.uk>
Tue, 24 Apr 2012 22:33:33 +0000 (23:33 +0100)
committerVinay Sajip <vinay_sajip@yahoo.co.uk>
Tue, 24 Apr 2012 22:33:33 +0000 (23:33 +0100)
1  2 
Lib/logging/handlers.py
Lib/test/test_logging.py

index 9c637974319c703c8883486802d5aac78ec7bc1d,22f8c3446cfaa63854d4d82337e4ac747cf7d190..7a97c807136b1030e1b0d2e8b9f3535765ec4b4a
@@@ -23,8 -23,7 +23,8 @@@ Copyright (C) 2001-2012 Vinay Sajip. Al
  To use, simply 'import logging.handlers' and log away!
  """
  
- import logging, socket, os, pickle, struct, time, re
+ import errno, logging, socket, os, pickle, struct, time, re
 +from codecs import BOM_UTF8
  from stat import ST_DEV, ST_INO, ST_MTIME
  import queue
  try:
@@@ -417,13 -381,15 +417,15 @@@ class WatchedFileHandler(logging.FileHa
      This handler is based on a suggestion and patch by Chad J.
      Schroeder.
      """
 -    def __init__(self, filename, mode='a', encoding=None, delay=0):
 +    def __init__(self, filename, mode='a', encoding=None, delay=False):
          logging.FileHandler.__init__(self, filename, mode, encoding, delay)
-         if not os.path.exists(self.baseFilename):
-             self.dev, self.ino = -1, -1
-         else:
-             stat = os.stat(self.baseFilename)
-             self.dev, self.ino = stat[ST_DEV], stat[ST_INO]
+         self.dev, self.ino = -1, -1
+         self._statstream()
+     def _statstream(self):
+         if self.stream:
+             sres = os.fstat(self.stream.fileno())
+             self.dev, self.ino = sres[ST_DEV], sres[ST_INO]
  
      def emit(self, record):
          """
index 75b3e4d8142968f759e4348290a1c76c138014f5,ca4d9309f27c43de804a4d7209ae6e5b82e13c3d..ee1c2113cac68bc71563a0f53eddddec19b79563
@@@ -518,439 -498,6 +519,474 @@@ class CustomLevelsAndFiltersTest(BaseTe
              handler.removeFilter(garr)
  
  
-                     self.assertNotEqual(dev, -1)
-                     self.assertNotEqual(ino, -1)
 +class HandlerTest(BaseTest):
 +    def test_name(self):
 +        h = logging.Handler()
 +        h.name = 'generic'
 +        self.assertEqual(h.name, 'generic')
 +        h.name = 'anothergeneric'
 +        self.assertEqual(h.name, 'anothergeneric')
 +        self.assertRaises(NotImplementedError, h.emit, None)
 +
 +    def test_builtin_handlers(self):
 +        # We can't actually *use* too many handlers in the tests,
 +        # but we can try instantiating them with various options
 +        if sys.platform in ('linux', 'darwin'):
 +            for existing in (True, False):
 +                fd, fn = tempfile.mkstemp()
 +                os.close(fd)
 +                if not existing:
 +                    os.unlink(fn)
 +                h = logging.handlers.WatchedFileHandler(fn, delay=True)
 +                if existing:
 +                    dev, ino = h.dev, h.ino
++                    self.assertEqual(dev, -1)
++                    self.assertEqual(ino, -1)
 +                    r = logging.makeLogRecord({'msg': 'Test'})
 +                    h.handle(r)
 +                    # Now remove the file.
 +                    os.unlink(fn)
 +                    self.assertFalse(os.path.exists(fn))
 +                    # The next call should recreate the file.
 +                    h.handle(r)
 +                    self.assertTrue(os.path.exists(fn))
 +                else:
 +                    self.assertEqual(h.dev, -1)
 +                    self.assertEqual(h.ino, -1)
 +                h.close()
 +                if existing:
 +                    os.unlink(fn)
 +            if sys.platform == 'darwin':
 +                sockname = '/var/run/syslog'
 +            else:
 +                sockname = '/dev/log'
 +            try:
 +                h = logging.handlers.SysLogHandler(sockname)
 +                self.assertEqual(h.facility, h.LOG_USER)
 +                self.assertTrue(h.unixsocket)
 +                h.close()
 +            except socket.error: # syslogd might not be available
 +                pass
 +        for method in ('GET', 'POST', 'PUT'):
 +            if method == 'PUT':
 +                self.assertRaises(ValueError, logging.handlers.HTTPHandler,
 +                                  'localhost', '/log', method)
 +            else:
 +                h = logging.handlers.HTTPHandler('localhost', '/log', method)
 +                h.close()
 +        h = logging.handlers.BufferingHandler(0)
 +        r = logging.makeLogRecord({})
 +        self.assertTrue(h.shouldFlush(r))
 +        h.close()
 +        h = logging.handlers.BufferingHandler(1)
 +        self.assertFalse(h.shouldFlush(r))
 +        h.close()
 +
++    @unittest.skipUnless(threading, 'Threading required for this test.')
++    def test_race(self):
++        # Issue #14632 refers.
++        def remove_loop(fname, tries):
++            for _ in range(tries):
++                try:
++                    os.unlink(fname)
++                except OSError:
++                    pass
++                time.sleep(0.004 * random.randint(0, 4))
++
++        def cleanup(remover, fn, handler):
++            handler.close()
++            remover.join()
++            if os.path.exists(fn):
++                os.unlink(fn)
++
++        fd, fn = tempfile.mkstemp('.log', 'test_logging-3-')
++        os.close(fd)
++        del_count = 1000
++        log_count = 1000
++        remover = threading.Thread(target=remove_loop, args=(fn, del_count))
++        remover.daemon = True
++        remover.start()
++        for delay in (False, True):
++            h = logging.handlers.WatchedFileHandler(fn, delay=delay)
++            self.addCleanup(cleanup, remover, fn, h)
++            f = logging.Formatter('%(asctime)s: %(levelname)s: %(message)s')
++            h.setFormatter(f)
++            for _ in range(log_count):
++                time.sleep(0.005)
++                r = logging.makeLogRecord({'msg': 'testing' })
++                h.handle(r)
++
++
 +class BadStream(object):
 +    def write(self, data):
 +        raise RuntimeError('deliberate mistake')
 +
 +class TestStreamHandler(logging.StreamHandler):
 +    def handleError(self, record):
 +        self.error_record = record
 +
 +class StreamHandlerTest(BaseTest):
 +    def test_error_handling(self):
 +        h = TestStreamHandler(BadStream())
 +        r = logging.makeLogRecord({})
 +        old_raise = logging.raiseExceptions
 +        old_stderr = sys.stderr
 +        try:
 +            h.handle(r)
 +            self.assertIs(h.error_record, r)
 +            h = logging.StreamHandler(BadStream())
 +            sys.stderr = sio = io.StringIO()
 +            h.handle(r)
 +            self.assertIn('\nRuntimeError: deliberate mistake\n',
 +                          sio.getvalue())
 +            logging.raiseExceptions = False
 +            sys.stderr = sio = io.StringIO()
 +            h.handle(r)
 +            self.assertEqual('', sio.getvalue())
 +        finally:
 +            logging.raiseExceptions = old_raise
 +            sys.stderr = old_stderr
 +
 +# -- The following section could be moved into a server_helper.py module
 +# -- if it proves to be of wider utility than just test_logging
 +
 +if threading:
 +    class TestSMTPChannel(smtpd.SMTPChannel):
 +        """
 +        This derived class has had to be created because smtpd does not
 +        support use of custom channel maps, although they are allowed by
 +        asyncore's design. Issue #11959 has been raised to address this,
 +        and if resolved satisfactorily, some of this code can be removed.
 +        """
 +        def __init__(self, server, conn, addr, sockmap):
 +            asynchat.async_chat.__init__(self, conn, sockmap)
 +            self.smtp_server = server
 +            self.conn = conn
 +            self.addr = addr
 +            self.received_lines = []
 +            self.smtp_state = self.COMMAND
 +            self.seen_greeting = ''
 +            self.mailfrom = None
 +            self.rcpttos = []
 +            self.received_data = ''
 +            self.fqdn = socket.getfqdn()
 +            self.num_bytes = 0
 +            try:
 +                self.peer = conn.getpeername()
 +            except socket.error as err:
 +                # a race condition  may occur if the other end is closing
 +                # before we can get the peername
 +                self.close()
 +                if err.args[0] != errno.ENOTCONN:
 +                    raise
 +                return
 +            self.push('220 %s %s' % (self.fqdn, smtpd.__version__))
 +            self.set_terminator(b'\r\n')
 +
 +
 +    class TestSMTPServer(smtpd.SMTPServer):
 +        """
 +        This class implements a test SMTP server.
 +
 +        :param addr: A (host, port) tuple which the server listens on.
 +                     You can specify a port value of zero: the server's
 +                     *port* attribute will hold the actual port number
 +                     used, which can be used in client connections.
 +        :param handler: A callable which will be called to process
 +                        incoming messages. The handler will be passed
 +                        the client address tuple, who the message is from,
 +                        a list of recipients and the message data.
 +        :param poll_interval: The interval, in seconds, used in the underlying
 +                              :func:`select` or :func:`poll` call by
 +                              :func:`asyncore.loop`.
 +        :param sockmap: A dictionary which will be used to hold
 +                        :class:`asyncore.dispatcher` instances used by
 +                        :func:`asyncore.loop`. This avoids changing the
 +                        :mod:`asyncore` module's global state.
 +        """
 +        channel_class = TestSMTPChannel
 +
 +        def __init__(self, addr, handler, poll_interval, sockmap):
 +            self._localaddr = addr
 +            self._remoteaddr = None
 +            self.sockmap = sockmap
 +            asyncore.dispatcher.__init__(self, map=sockmap)
 +            try:
 +                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 +                sock.setblocking(0)
 +                self.set_socket(sock, map=sockmap)
 +                # try to re-use a server port if possible
 +                self.set_reuse_addr()
 +                self.bind(addr)
 +                self.port = sock.getsockname()[1]
 +                self.listen(5)
 +            except:
 +                self.close()
 +                raise
 +            self._handler = handler
 +            self._thread = None
 +            self.poll_interval = poll_interval
 +
 +        def handle_accepted(self, conn, addr):
 +            """
 +            Redefined only because the base class does not pass in a
 +            map, forcing use of a global in :mod:`asyncore`.
 +            """
 +            channel = self.channel_class(self, conn, addr, self.sockmap)
 +
 +        def process_message(self, peer, mailfrom, rcpttos, data):
 +            """
 +            Delegates to the handler passed in to the server's constructor.
 +
 +            Typically, this will be a test case method.
 +            :param peer: The client (host, port) tuple.
 +            :param mailfrom: The address of the sender.
 +            :param rcpttos: The addresses of the recipients.
 +            :param data: The message.
 +            """
 +            self._handler(peer, mailfrom, rcpttos, data)
 +
 +        def start(self):
 +            """
 +            Start the server running on a separate daemon thread.
 +            """
 +            self._thread = t = threading.Thread(target=self.serve_forever,
 +                                                args=(self.poll_interval,))
 +            t.setDaemon(True)
 +            t.start()
 +
 +        def serve_forever(self, poll_interval):
 +            """
 +            Run the :mod:`asyncore` loop until normal termination
 +            conditions arise.
 +            :param poll_interval: The interval, in seconds, used in the underlying
 +                                  :func:`select` or :func:`poll` call by
 +                                  :func:`asyncore.loop`.
 +            """
 +            try:
 +                asyncore.loop(poll_interval, map=self.sockmap)
 +            except select.error:
 +                # On FreeBSD 8, closing the server repeatably
 +                # raises this error. We swallow it if the
 +                # server has been closed.
 +                if self.connected or self.accepting:
 +                    raise
 +
 +        def stop(self, timeout=None):
 +            """
 +            Stop the thread by closing the server instance.
 +            Wait for the server thread to terminate.
 +
 +            :param timeout: How long to wait for the server thread
 +                            to terminate.
 +            """
 +            self.close()
 +            self._thread.join(timeout)
 +            self._thread = None
 +
 +    class ControlMixin(object):
 +        """
 +        This mixin is used to start a server on a separate thread, and
 +        shut it down programmatically. Request handling is simplified - instead
 +        of needing to derive a suitable RequestHandler subclass, you just
 +        provide a callable which will be passed each received request to be
 +        processed.
 +
 +        :param handler: A handler callable which will be called with a
 +                        single parameter - the request - in order to
 +                        process the request. This handler is called on the
 +                        server thread, effectively meaning that requests are
 +                        processed serially. While not quite Web scale ;-),
 +                        this should be fine for testing applications.
 +        :param poll_interval: The polling interval in seconds.
 +        """
 +        def __init__(self, handler, poll_interval):
 +            self._thread = None
 +            self.poll_interval = poll_interval
 +            self._handler = handler
 +            self.ready = threading.Event()
 +
 +        def start(self):
 +            """
 +            Create a daemon thread to run the server, and start it.
 +            """
 +            self._thread = t = threading.Thread(target=self.serve_forever,
 +                                                args=(self.poll_interval,))
 +            t.setDaemon(True)
 +            t.start()
 +
 +        def serve_forever(self, poll_interval):
 +            """
 +            Run the server. Set the ready flag before entering the
 +            service loop.
 +            """
 +            self.ready.set()
 +            super(ControlMixin, self).serve_forever(poll_interval)
 +
 +        def stop(self, timeout=None):
 +            """
 +            Tell the server thread to stop, and wait for it to do so.
 +
 +            :param timeout: How long to wait for the server thread
 +                            to terminate.
 +            """
 +            self.shutdown()
 +            if self._thread is not None:
 +                self._thread.join(timeout)
 +                self._thread = None
 +            self.server_close()
 +            self.ready.clear()
 +
 +    class TestHTTPServer(ControlMixin, HTTPServer):
 +        """
 +        An HTTP server which is controllable using :class:`ControlMixin`.
 +
 +        :param addr: A tuple with the IP address and port to listen on.
 +        :param handler: A handler callable which will be called with a
 +                        single parameter - the request - in order to
 +                        process the request.
 +        :param poll_interval: The polling interval in seconds.
 +        :param log: Pass ``True`` to enable log messages.
 +        """
 +        def __init__(self, addr, handler, poll_interval=0.5,
 +                     log=False, sslctx=None):
 +            class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler):
 +                def __getattr__(self, name, default=None):
 +                    if name.startswith('do_'):
 +                        return self.process_request
 +                    raise AttributeError(name)
 +
 +                def process_request(self):
 +                    self.server._handler(self)
 +
 +                def log_message(self, format, *args):
 +                    if log:
 +                        super(DelegatingHTTPRequestHandler,
 +                              self).log_message(format, *args)
 +            HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler)
 +            ControlMixin.__init__(self, handler, poll_interval)
 +            self.sslctx = sslctx
 +
 +        def get_request(self):
 +            try:
 +                sock, addr = self.socket.accept()
 +                if self.sslctx:
 +                    sock = self.sslctx.wrap_socket(sock, server_side=True)
 +            except socket.error as e:
 +                # socket errors are silenced by the caller, print them here
 +                sys.stderr.write("Got an error:\n%s\n" % e)
 +                raise
 +            return sock, addr
 +
 +    class TestTCPServer(ControlMixin, ThreadingTCPServer):
 +        """
 +        A TCP server which is controllable using :class:`ControlMixin`.
 +
 +        :param addr: A tuple with the IP address and port to listen on.
 +        :param handler: A handler callable which will be called with a single
 +                        parameter - the request - in order to process the request.
 +        :param poll_interval: The polling interval in seconds.
 +        :bind_and_activate: If True (the default), binds the server and starts it
 +                            listening. If False, you need to call
 +                            :meth:`server_bind` and :meth:`server_activate` at
 +                            some later time before calling :meth:`start`, so that
 +                            the server will set up the socket and listen on it.
 +        """
 +
 +        allow_reuse_address = True
 +
 +        def __init__(self, addr, handler, poll_interval=0.5,
 +                     bind_and_activate=True):
 +            class DelegatingTCPRequestHandler(StreamRequestHandler):
 +
 +                def handle(self):
 +                    self.server._handler(self)
 +            ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler,
 +                                        bind_and_activate)
 +            ControlMixin.__init__(self, handler, poll_interval)
 +
 +        def server_bind(self):
 +            super(TestTCPServer, self).server_bind()
 +            self.port = self.socket.getsockname()[1]
 +
 +    class TestUDPServer(ControlMixin, ThreadingUDPServer):
 +        """
 +        A UDP server which is controllable using :class:`ControlMixin`.
 +
 +        :param addr: A tuple with the IP address and port to listen on.
 +        :param handler: A handler callable which will be called with a
 +                        single parameter - the request - in order to
 +                        process the request.
 +        :param poll_interval: The polling interval for shutdown requests,
 +                              in seconds.
 +        :bind_and_activate: If True (the default), binds the server and
 +                            starts it listening. If False, you need to
 +                            call :meth:`server_bind` and
 +                            :meth:`server_activate` at some later time
 +                            before calling :meth:`start`, so that the server will
 +                            set up the socket and listen on it.
 +        """
 +        def __init__(self, addr, handler, poll_interval=0.5,
 +                     bind_and_activate=True):
 +            class DelegatingUDPRequestHandler(DatagramRequestHandler):
 +
 +                def handle(self):
 +                    self.server._handler(self)
 +
 +                def finish(self):
 +                    data = self.wfile.getvalue()
 +                    if data:
 +                        try:
 +                            super(DelegatingUDPRequestHandler, self).finish()
 +                        except socket.error:
 +                            if not self.server._closed:
 +                                raise
 +
 +            ThreadingUDPServer.__init__(self, addr,
 +                                        DelegatingUDPRequestHandler,
 +                                        bind_and_activate)
 +            ControlMixin.__init__(self, handler, poll_interval)
 +            self._closed = False
 +
 +        def server_bind(self):
 +            super(TestUDPServer, self).server_bind()
 +            self.port = self.socket.getsockname()[1]
 +
 +        def server_close(self):
 +            super(TestUDPServer, self).server_close()
 +            self._closed = True
 +
 +# - end of server_helper section
 +
 +@unittest.skipUnless(threading, 'Threading required for this test.')
 +class SMTPHandlerTest(BaseTest):
 +    def test_basic(self):
 +        sockmap = {}
 +        server = TestSMTPServer(('localhost', 0), self.process_message, 0.001,
 +                                sockmap)
 +        server.start()
 +        addr = ('localhost', server.port)
 +        h = logging.handlers.SMTPHandler(addr, 'me', 'you', 'Log', timeout=5.0)
 +        self.assertEqual(h.toaddrs, ['you'])
 +        self.messages = []
 +        r = logging.makeLogRecord({'msg': 'Hello'})
 +        self.handled = threading.Event()
 +        h.handle(r)
 +        self.handled.wait(5.0)  # 14314: don't wait forever
 +        server.stop()
 +        self.assertTrue(self.handled.is_set())
 +        self.assertEqual(len(self.messages), 1)
 +        peer, mailfrom, rcpttos, data = self.messages[0]
 +        self.assertEqual(mailfrom, 'me')
 +        self.assertEqual(rcpttos, ['you'])
 +        self.assertIn('\nSubject: Log\n', data)
 +        self.assertTrue(data.endswith('\n\nHello'))
 +        h.close()
 +
 +    def process_message(self, *args):
 +        self.messages.append(args)
 +        self.handled.set()
 +
  class MemoryHandlerTest(BaseTest):
  
      """Tests for the MemoryHandler."""