]> granicus.if.org Git - python/commitdiff
Make TextIOWrapper's seek/tell work properly with stateful decoders;
authorKa-Ping Yee <ping@zesty.ca>
Tue, 18 Mar 2008 04:51:32 +0000 (04:51 +0000)
committerKa-Ping Yee <ping@zesty.ca>
Tue, 18 Mar 2008 04:51:32 +0000 (04:51 +0000)
document and rename things to make seek/tell workings a little clearer.

Add a weird decoder for testing TextIOWrapper's seek/tell methods.

Document the getstate/setstate protocol conventions for IncrementalDecoders.

Lib/codecs.py
Lib/io.py
Lib/test/test_io.py

index 6d8d5544edd65001dc9dbb4ec0a249644f5dc561..f05c4f7e2db4c18c5b0db16654642586301b2316 100644 (file)
@@ -237,7 +237,7 @@ class IncrementalDecoder(object):
     """
     def __init__(self, errors='strict'):
         """
-        Creates a IncrementalDecoder instance.
+        Create a IncrementalDecoder instance.
 
         The IncrementalDecoder may use different error handling schemes by
         providing the errors keyword argument. See the module docstring
@@ -247,28 +247,35 @@ class IncrementalDecoder(object):
 
     def decode(self, input, final=False):
         """
-        Decodes input and returns the resulting object.
+        Decode input and returns the resulting object.
         """
         raise NotImplementedError
 
     def reset(self):
         """
-        Resets the decoder to the initial state.
+        Reset the decoder to the initial state.
         """
 
     def getstate(self):
         """
-        Return the current state of the decoder. This must be a
-        (buffered_input, additional_state_info) tuple.  By convention,
-        additional_state_info should represent the state of the decoder
-        WITHOUT yet having processed the contents of buffered_input.
+        Return the current state of the decoder.
+
+        This must be a (buffered_input, additional_state_info) tuple.
+        buffered_input must be a bytes object containing bytes that
+        were passed to decode() that have not yet been converted.
+        additional_state_info must be a non-negative integer
+        representing the state of the decoder WITHOUT yet having
+        processed the contents of buffered_input.  In the initial state
+        and after reset(), getstate() must return (b"", 0).
         """
         return (b"", 0)
 
     def setstate(self, state):
         """
-        Set the current state of the decoder. state must have been
-        returned by getstate().
+        Set the current state of the decoder.
+
+        state must have been returned by getstate().  The effect of
+        setstate((b"", 0)) must be equivalent to reset().
         """
 
 class BufferedIncrementalDecoder(IncrementalDecoder):
index 98843d381642468bee7ca9ca83d35d9678596af3..d3c9f853f9cb0d1188c93cf3fc4df12dbfbab38b 100644 (file)
--- a/Lib/io.py
+++ b/Lib/io.py
@@ -802,11 +802,10 @@ class BufferedReader(_BufferedIOMixin):
         return self._read_buf
 
     def read1(self, n):
-        """Reads up to n bytes.
+        """Reads up to n bytes, with at most one read() system call.
 
-        Returns up to n bytes.  If at least one byte is buffered,
-        we only return buffered bytes.  Otherwise, we do one
-        raw read.
+        Returns up to n bytes.  If at least one byte is buffered, we
+        only return buffered bytes.  Otherwise, we do one raw read.
         """
         if n <= 0:
             return b""
@@ -1180,10 +1179,24 @@ class TextIOWrapper(TextIOBase):
         self._writenl = newline or os.linesep
         self._encoder = None
         self._decoder = None
-        self._pending = ""
-        self._snapshot = None
+        self._decoded_text = ""  # buffer for text produced by decoder
+        self._snapshot = None  # info for reconstructing decoder state
         self._seekable = self._telling = self.buffer.seekable()
 
+    # A word about _snapshot.  This attribute is either None, or a tuple
+    # (decoder_state, input_chunk, decoded_chars) where decoder_state is
+    # the second (integer) item of the decoder state, input_chunk is the
+    # chunk of bytes that was read, and decoded_chars is the number of
+    # characters rendered by the decoder after feeding it those bytes.
+    # We use this to reconstruct intermediate decoder states in tell().
+
+    # Naming convention:
+    #   - integer variables ending in "_bytes" count input bytes
+    #   - integer variables ending in "_chars" count decoded characters
+
+    def __repr__(self):
+        return '<TIOW %x>' % id(self)
+
     @property
     def encoding(self):
         return self._encoding
@@ -1196,13 +1209,6 @@ class TextIOWrapper(TextIOBase):
     def line_buffering(self):
         return self._line_buffering
 
-    # A word about _snapshot.  This attribute is either None, or a
-    # tuple (decoder_state, readahead, pending) where decoder_state is
-    # the second (integer) item of the decoder state, readahead is the
-    # chunk of bytes that was read, and pending is the characters that
-    # were rendered by the decoder after feeding it those bytes.  We
-    # use this to reconstruct intermediate decoder states in tell().
-
     def seekable(self):
         return self._seekable
 
@@ -1262,126 +1268,199 @@ class TextIOWrapper(TextIOBase):
         return decoder
 
     def _read_chunk(self):
+        """
+        Read and decode the next chunk of data from the BufferedReader.
+
+        Return a tuple of two elements: all the bytes that were read, and
+        the decoded string produced by the decoder.  (The entire input
+        chunk is sent to the decoder, but some of it may remain buffered
+        in the decoder, yet to be converted.)
+        """
+
         if self._decoder is None:
             raise ValueError("no decoder")
         if not self._telling:
-            readahead = self.buffer.read1(self._CHUNK_SIZE)
-            pending = self._decoder.decode(readahead, not readahead)
-            return readahead, pending
-        decoder_buffer, decoder_state = self._decoder.getstate()
-        readahead = self.buffer.read1(self._CHUNK_SIZE)
-        pending = self._decoder.decode(readahead, not readahead)
-        self._snapshot = (decoder_state, decoder_buffer + readahead, pending)
-        return readahead, pending
-
-    def _encode_decoder_state(self, ds, pos):
-        x = 0
-        for i in bytes(ds):
-            x = x<<8 | i
-        return (x<<64) | pos
-
-    def _decode_decoder_state(self, pos):
-        x, pos = divmod(pos, 1<<64)
-        if not x:
-            return None, pos
-        b = b""
-        while x:
-            b.append(x&0xff)
-            x >>= 8
-        return str(b[::-1]), pos
+            # No one should call tell(), so don't bother taking a snapshot.
+            input_chunk = self.buffer.read1(self._CHUNK_SIZE)
+            eof = not input_chunk
+            decoded = self._decoder.decode(input_chunk, eof)
+            return (input_chunk, decoded)
+
+        # The cookie returned by tell() cannot include the contents of
+        # the decoder's buffer, so we need to snapshot a point in the
+        # input where the decoder has nothing in its input buffer.
+
+        dec_buffer, dec_flags = self._decoder.getstate()
+        # The state tuple returned by getstate() contains the decoder's
+        # input buffer and an integer representing any other state.  Thus,
+        # there is a valid snapshot point len(decoder_buffer) bytes ago in
+        # the input, with the state tuple (b'', decoder_state).
+
+        input_chunk = self.buffer.read1(self._CHUNK_SIZE)
+        eof = not input_chunk
+        decoded = self._decoder.decode(input_chunk, eof)
+
+        # At the snapshot point len(dec_buffer) bytes ago, the next input
+        # to be passed to the decoder is dec_buffer + input_chunk.  Save
+        # len(decoded) so that later, tell() can figure out how much
+        # decoded data has been used up by TextIOWrapper.read().
+        self._snapshot = (dec_flags, dec_buffer + input_chunk, len(decoded))
+        return (input_chunk, decoded)
+
+    def _encode_tell_cookie(self, position, dec_flags=0,
+                            feed_bytes=0, need_eof=0, skip_chars=0):
+        # The meaning of a tell() cookie is: seek to position, set the
+        # decoder flags to dec_flags, read feed_bytes bytes, feed them
+        # into the decoder with need_eof as the EOF flag, then skip
+        # skip_chars characters of the decoded result.  For most simple
+        # decoders, this should often just be the position.
+        return (position | (dec_flags<<64) | (feed_bytes<<128) |
+                (skip_chars<<192) | bool(need_eof)<<256)
+
+    def _decode_tell_cookie(self, bigint):
+        rest, position = divmod(bigint, 1<<64)
+        rest, dec_flags = divmod(rest, 1<<64)
+        rest, feed_bytes = divmod(rest, 1<<64)
+        need_eof, skip_chars = divmod(rest, 1<<64)
+        return position, dec_flags, feed_bytes, need_eof, skip_chars
 
     def tell(self):
         if not self._seekable:
-            raise IOError("Underlying stream is not seekable")
+            raise IOError("underlying stream is not seekable")
         if not self._telling:
-            raise IOError("Telling position disabled by next() call")
+            raise IOError("telling position disabled by next() call")
         self.flush()
         position = self.buffer.tell()
         decoder = self._decoder
         if decoder is None or self._snapshot is None:
-            if self._pending:
-                raise ValueError("pending data")
+            if self._decoded_text:
+                # This should never happen.
+                raise AssertionError("pending decoded text")
             return position
-        decoder_state, readahead, pending = self._snapshot
-        position -= len(readahead)
-        needed = len(pending) - len(self._pending)
-        if not needed:
-            return self._encode_decoder_state(decoder_state, position)
+
+        # Skip backward to the snapshot point (see _read_chunk).
+        dec_flags, next_input, decoded_chars = self._snapshot
+        position -= len(next_input)
+
+        # How many decoded characters have been consumed since the snapshot?
+        skip_chars = decoded_chars - len(self._decoded_text)
+        if skip_chars == 0:
+            # We haven't moved from the snapshot point.
+            return self._encode_tell_cookie(position, dec_flags)
+
+        # Walk the decoder forward, one byte at a time, to find the minimum
+        # input necessary to give us the decoded characters we need to skip.
+        # As we go, look for the "safe point" nearest to the current location
+        # (i.e. a point where the decoder has nothing buffered, so we can
+        # safely start from there when trying to return to this location).
         saved_state = decoder.getstate()
         try:
-            decoder.setstate((b"", decoder_state))
-            n = 0
-            bb = bytearray(1)
-            for i, bb[0] in enumerate(readahead):
-                n += len(decoder.decode(bb))
-                if n >= needed:
-                    decoder_buffer, decoder_state = decoder.getstate()
-                    return self._encode_decoder_state(
-                        decoder_state,
-                        position + (i+1) - len(decoder_buffer) - (n - needed))
-            raise IOError("Can't reconstruct logical file position")
+            decoder.setstate((b"", dec_flags))
+            fed_bytes = 0
+            decoded_chars = 0
+            need_eof = 0
+            last_safe_point = (dec_flags, 0, 0)
+
+            next_byte = bytearray(1)
+            for next_byte[0] in next_input:
+                decoded = decoder.decode(next_byte)
+                fed_bytes += 1
+                decoded_chars += len(decoded)
+                dec_buffer, dec_flags = decoder.getstate()
+                if not dec_buffer and decoded_chars <= skip_chars:
+                    # Decoder buffer is empty, so it's safe to start from here.
+                    last_safe_point = (dec_flags, fed_bytes, decoded_chars)
+                if decoded_chars >= skip_chars:
+                    break
+            else:
+                # We didn't get enough decoded data; send EOF to get more.
+                decoded = decoder.decode(b"", True)
+                decoded_chars += len(decoded)
+                need_eof = 1
+                if decoded_chars < skip_chars:
+                    raise IOError("can't reconstruct logical file position")
+
+            # Advance the starting position to the last safe point.
+            dec_flags, safe_fed_bytes, safe_decoded_chars = last_safe_point
+            position += safe_fed_bytes
+            fed_bytes -= safe_fed_bytes
+            skip_chars -= safe_decoded_chars
+            return self._encode_tell_cookie(
+                position, dec_flags, fed_bytes, need_eof, skip_chars)
         finally:
             decoder.setstate(saved_state)
 
-    def seek(self, pos, whence=0):
+    def seek(self, cookie, whence=0):
         if not self._seekable:
-            raise IOError("Underlying stream is not seekable")
-        if whence == 1:
-            if pos != 0:
-                raise IOError("Can't do nonzero cur-relative seeks")
-            pos = self.tell()
+            raise IOError("underlying stream is not seekable")
+        if whence == 1: # seek relative to current position
+            if cookie != 0:
+                raise IOError("can't do nonzero cur-relative seeks")
+            # Seeking to the current position should attempt to
+            # sync the underlying buffer with the current position.
             whence = 0
-        if whence == 2:
-            if pos != 0:
-                raise IOError("Can't do nonzero end-relative seeks")
+            cookie = self.tell()
+        if whence == 2: # seek relative to end of file
+            if cookie != 0:
+                raise IOError("can't do nonzero end-relative seeks")
             self.flush()
-            pos = self.buffer.seek(0, 2)
+            position = self.buffer.seek(0, 2)
+            self._decoded_text = ""
             self._snapshot = None
-            self._pending = ""
             if self._decoder:
                 self._decoder.reset()
-            return pos
+            return position
         if whence != 0:
-            raise ValueError("Invalid whence (%r, should be 0, 1 or 2)" %
+            raise ValueError("invalid whence (%r, should be 0, 1 or 2)" %
                              (whence,))
-        if pos < 0:
-            raise ValueError("Negative seek position %r" % (pos,))
+        if cookie < 0:
+            raise ValueError("negative seek position %r" % (cookie,))
         self.flush()
-        orig_pos = pos
-        ds, pos = self._decode_decoder_state(pos)
-        if not ds:
-            self.buffer.seek(pos)
-            self._snapshot = None
-            self._pending = ""
-            if self._decoder:
-                self._decoder.reset()
-            return pos
-        decoder = self._decoder or self._get_decoder()
-        decoder.set_state(("", ds))
-        self.buffer.seek(pos)
-        self._snapshot = (ds, b"", "")
-        self._pending = ""
-        self._decoder = decoder
-        return orig_pos
+
+        # Seek back to the snapshot point.
+        position, dec_flags, feed_bytes, need_eof, skip_chars = \
+            self._decode_tell_cookie(cookie)
+        self.buffer.seek(position)
+        self._decoded_text = ""
+        self._snapshot = None
+
+        if self._decoder or dec_flags or feed_bytes or need_eof:
+            # Restore the decoder flags to their values from the snapshot.
+            self._decoder = self._decoder or self._get_decoder()
+            self._decoder.setstate((b"", dec_flags))
+
+        if feed_bytes or need_eof:
+            # Feed feed_bytes bytes to the decoder.
+            input_chunk = self.buffer.read(feed_bytes)
+            decoded = self._decoder.decode(input_chunk, need_eof)
+            if len(decoded) < skip_chars:
+                raise IOError("can't restore logical file position")
+
+            # Skip skip_chars of the decoded characters.
+            self._decoded_text = decoded[skip_chars:]
+
+            # Restore the snapshot.
+            self._snapshot = (dec_flags, input_chunk, len(decoded))
+        return cookie
 
     def read(self, n=None):
         if n is None:
             n = -1
         decoder = self._decoder or self._get_decoder()
-        res = self._pending
+        result = self._decoded_text
         if n < 0:
-            res += decoder.decode(self.buffer.read(), True)
-            self._pending = ""
+            result += decoder.decode(self.buffer.read(), True)
+            self._decoded_text = ""
             self._snapshot = None
-            return res
+            return result
         else:
-            while len(res) < n:
-                readahead, pending = self._read_chunk()
-                res += pending
-                if not readahead:
+            while len(result) < n:
+                input_chunk, decoded = self._read_chunk()
+                result += decoded
+                if not input_chunk:
                     break
-            self._pending = res[n:]
-            return res[:n]
+            self._decoded_text = result[n:]
+            return result[:n]
 
     def __next__(self):
         self._telling = False
@@ -1400,10 +1479,11 @@ class TextIOWrapper(TextIOBase):
             line = self.readline()
             if len(line) <= limit:
                 return line
-            line, self._pending = line[:limit], line[limit:] + self._pending
+            line, self._decoded_text = \
+                line[:limit], line[limit:] + self._decoded_text
             return line
 
-        line = self._pending
+        line = self._decoded_text
         start = 0
         decoder = self._decoder or self._get_decoder()
 
@@ -1467,11 +1547,11 @@ class TextIOWrapper(TextIOBase):
                 line += more_line
             else:
                 # end of file
-                self._pending = ''
+                self._decoded_text = ''
                 self._snapshot = None
                 return line
 
-        self._pending = line[endpos:]
+        self._decoded_text = line[endpos:]
         return line[:endpos]
 
     @property
index 0bc2b48bbc46926152f70be86eec6e7387bdcd52..49404e1cbae2452f9fb65111798c830db94b2111 100644 (file)
@@ -8,6 +8,7 @@ import unittest
 from itertools import chain
 from test import test_support
 
+import codecs
 import io  # The module under test
 
 
@@ -486,6 +487,122 @@ class BufferedRandomTest(unittest.TestCase):
         self.assertEquals(b"fl", rw.read(11))
         self.assertRaises(TypeError, rw.seek, 0.0)
 
+# To fully exercise seek/tell, the StatefulIncrementalDecoder has these
+# properties:
+#   - A single output character can correspond to many bytes of input.
+#   - The number of input bytes to complete the character can be
+#     undetermined until the last input byte is received.
+#   - The number of input bytes can vary depending on previous input.
+#   - A single input byte can correspond to many characters of output.
+#   - The number of output characters can be undetermined until the
+#     last input byte is received.
+#   - The number of output characters can vary depending on previous input.
+
+class StatefulIncrementalDecoder(codecs.IncrementalDecoder):
+    """
+    For testing seek/tell behavior with a stateful, buffering decoder.
+
+    Input is a sequence of words.  Words may be fixed-length (length set
+    by input) or variable-length (period-terminated).  In variable-length
+    mode, extra periods are ignored.  Possible words are:
+      - 'i' followed by a number sets the input length, I (maximum 99).
+        When I is set to 0, words are space-terminated.
+      - 'o' followed by a number sets the output length, O (maximum 99).
+      - Any other word is converted into a word followed by a period on
+        the output.  The output word consists of the input word truncated
+        or padded out with hyphens to make its length equal to O.  If O
+        is 0, the word is output verbatim without truncating or padding.
+    I and O are initially set to 1.  When I changes, any buffered input is
+    re-scanned according to the new I.  EOF also terminates the last word.
+    """
+
+    def __init__(self, errors='strict'):
+        codecs.IncrementalEncoder.__init__(self, errors)
+        self.reset()
+
+    def __repr__(self):
+        return '<SID %x>' % id(self)
+
+    def reset(self):
+        self.i = 1
+        self.o = 1
+        self.buffer = bytearray()
+
+    def getstate(self):
+        i, o = self.i ^ 1, self.o ^ 1 # so that flags = 0 after reset()
+        return bytes(self.buffer), i*100 + o
+
+    def setstate(self, state):
+        buffer, io = state
+        self.buffer = bytearray(buffer)
+        i, o = divmod(io, 100)
+        self.i, self.o = i ^ 1, o ^ 1
+
+    def decode(self, input, final=False):
+        output = ''
+        for b in input:
+            if self.i == 0: # variable-length, terminated with period
+                if b == ord('.'):
+                    if self.buffer:
+                        output += self.process_word()
+                else:
+                    self.buffer.append(b)
+            else: # fixed-length, terminate after self.i bytes
+                self.buffer.append(b)
+                if len(self.buffer) == self.i:
+                    output += self.process_word()
+        if final and self.buffer: # EOF terminates the last word
+            output += self.process_word()
+        return output
+
+    def process_word(self):
+        output = ''
+        if self.buffer[0] == ord('i'):
+            self.i = min(99, int(self.buffer[1:] or 0)) # set input length
+        elif self.buffer[0] == ord('o'):
+            self.o = min(99, int(self.buffer[1:] or 0)) # set output length
+        else:
+            output = self.buffer.decode('ascii')
+            if len(output) < self.o:
+                output += '-'*self.o # pad out with hyphens
+            if self.o:
+                output = output[:self.o] # truncate to output length
+            output += '.'
+        self.buffer = bytearray()
+        return output
+
+class StatefulIncrementalDecoderTest(unittest.TestCase):
+    """
+    Make sure the StatefulIncrementalDecoder actually works.
+    """
+
+    test_cases = [
+        # I=1 fixed-length mode
+        (b'abcd', False, 'a.b.c.d.'),
+        # I=0, O=0, variable-length mode
+        (b'oiabcd', True, 'abcd.'),
+        # I=0, O=0, variable-length mode, should ignore extra periods
+        (b'oi...abcd...', True, 'abcd.'),
+        # I=0, O=6
+        (b'i.o6.xyz.', False, 'xyz---.'),
+        # I=2, O=6
+        (b'i.i2.o6xyz', True, 'xy----.z-----.'),
+        # I=0, O=3
+        (b'i.o3.x.xyz.toolong.', False, 'x--.xyz.too.'),
+        # I=6, O=3
+        (b'i.o3.i6.abcdefghijklmnop', True, 'abc.ghi.mno.')
+    ]
+
+    def testDecoder(self):
+        # Try a few one-shot test cases.
+        for input, eof, output in self.test_cases:
+            d = StatefulIncrementalDecoder()
+            self.assertEquals(d.decode(input, eof), output)
+
+        # Also test an unfinished decode, followed by forcing EOF.
+        d = StatefulIncrementalDecoder()
+        self.assertEquals(d.decode(b'oiabcd'), '')
+        self.assertEquals(d.decode(b'', 1), 'abcd.')
 
 class TextIOWrapperTest(unittest.TestCase):
 
@@ -765,6 +882,60 @@ class TextIOWrapperTest(unittest.TestCase):
         f.readline()
         f.tell()
 
+    def testSeekAndTell(self):
+        """Test seek/tell using the StatefulIncrementalDecoder."""
+
+        def lookupTestDecoder(name):
+            if self.codecEnabled and name == 'test_decoder':
+                return codecs.CodecInfo(
+                    name='test_decoder', encode=None, decode=None,
+                    incrementalencoder=None,
+                    streamreader=None, streamwriter=None,
+                    incrementaldecoder=StatefulIncrementalDecoder)
+
+        def testSeekAndTellWithData(data, min_pos=0):
+            """Tell/seek to various points within a data stream and ensure
+            that the decoded data returned by read() is consistent."""
+            f = io.open(test_support.TESTFN, 'wb')
+            f.write(data)
+            f.close()
+            f = io.open(test_support.TESTFN, encoding='test_decoder')
+            decoded = f.read()
+            f.close()
+
+            for i in range(min_pos, len(decoded) + 1): # seek positions
+                for j in [1, 5, len(decoded) - i]: # read lengths
+                    f = io.open(test_support.TESTFN, encoding='test_decoder')
+                    self.assertEquals(f.read(i), decoded[:i])
+                    cookie = f.tell()
+                    self.assertEquals(f.read(j), decoded[i:i + j])
+                    f.seek(cookie)
+                    self.assertEquals(f.read(), decoded[i:])
+                    f.close()
+
+        # Register a special incremental decoder for testing.
+        codecs.register(lookupTestDecoder)
+        self.codecEnabled = 1
+
+        # Run the tests.
+        try:
+            # Try each test case.
+            for input, _, _ in StatefulIncrementalDecoderTest.test_cases:
+                testSeekAndTellWithData(input)
+
+            # Position each test case so that it crosses a chunk boundary.
+            CHUNK_SIZE = io.TextIOWrapper._CHUNK_SIZE
+            for input, _, _ in StatefulIncrementalDecoderTest.test_cases:
+                offset = CHUNK_SIZE - len(input)//2
+                prefix = b'.'*offset
+                # Don't bother seeking into the prefix (takes too long).
+                min_pos = offset*2
+                testSeekAndTellWithData(prefix + input, min_pos)
+
+        # Ensure our test decoder won't interfere with subsequent tests.
+        finally:
+            self.codecEnabled = 0
+
     def testEncodedWrites(self):
         data = "1234567890"
         tests = ("utf-16",