]> granicus.if.org Git - python/commitdiff
Add a BufferedIncrementalEncoder class that can be used for implementing
authorWalter Dörwald <walter@livinglogic.de>
Fri, 14 Apr 2006 18:25:39 +0000 (18:25 +0000)
committerWalter Dörwald <walter@livinglogic.de>
Fri, 14 Apr 2006 18:25:39 +0000 (18:25 +0000)
an incremental encoder that must retain part of the data between calls
to the encode() method.

Fix the incremental encoder and decoder for the IDNA encoding.

This closes SF patch #1453235.

Lib/codecs.py
Lib/encodings/idna.py
Lib/test/test_codecs.py

index c138187dd2f5cbc7da2a269069e48c28ceae39ab..1518d75f9d20fd926558077dea05cba49a38662f 100644 (file)
@@ -181,6 +181,33 @@ class IncrementalEncoder(object):
         Resets the encoder to the initial state.
         """
 
+class BufferedIncrementalEncoder(IncrementalEncoder):
+    """
+    This subclass of IncrementalEncoder can be used as the baseclass for an
+    incremental encoder if the encoder must keep some of the output in a
+    buffer between calls to encode().
+    """
+    def __init__(self, errors='strict'):
+        IncrementalEncoder.__init__(self, errors)
+        self.buffer = "" # unencoded input that is kept between calls to encode()
+
+    def _buffer_encode(self, input, errors, final):
+        # Overwrite this method in subclasses: It must encode input
+        # and return an (output, length consumed) tuple
+        raise NotImplementedError
+
+    def encode(self, input, final=False):
+        # encode input (taking the buffer into account)
+        data = self.buffer + input
+        (result, consumed) = self._buffer_encode(data, self.errors, final)
+        # keep unencoded input until the next call
+        self.buffer = data[consumed:]
+        return result
+
+    def reset(self):
+        IncrementalEncoder.reset(self)
+        self.buffer = ""
+
 class IncrementalDecoder(object):
     """
     An IncrementalDecoder decodes an input in multiple steps. The input can be
index 1aa4e965ef54f323f010e79ad32175b94d5f2a76..ea90d67142f6acf68b32c5a59cf431aba04d0798 100644 (file)
@@ -194,13 +194,79 @@ class Codec(codecs.Codec):
 
         return u".".join(result)+trailing_dot, len(input)
 
-class IncrementalEncoder(codecs.IncrementalEncoder):
-    def encode(self, input, final=False):
-        return Codec().encode(input, self.errors)[0]
+class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
+    def _buffer_encode(self, input, errors, final):
+        if errors != 'strict':
+            # IDNA is quite clear that implementations must be strict
+            raise UnicodeError("unsupported error handling "+errors)
+
+        if not input:
+            return ("", 0)
+
+        labels = dots.split(input)
+        trailing_dot = u''
+        if labels:
+            if not labels[-1]:
+                trailing_dot = '.'
+                del labels[-1]
+            elif not final:
+                # Keep potentially unfinished label until the next call
+                del labels[-1]
+                if labels:
+                    trailing_dot = '.'
+
+        result = []
+        size = 0
+        for label in labels:
+            result.append(ToASCII(label))
+            if size:
+                size += 1
+            size += len(label)
+
+        # Join with U+002E
+        result = ".".join(result) + trailing_dot
+        size += len(trailing_dot)
+        return (result, size)
+
+class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
+    def _buffer_decode(self, input, errors, final):
+        if errors != 'strict':
+            raise UnicodeError("Unsupported error handling "+errors)
+
+        if not input:
+            return (u"", 0)
+
+        # IDNA allows decoding to operate on Unicode strings, too.
+        if isinstance(input, unicode):
+            labels = dots.split(input)
+        else:
+            # Must be ASCII string
+            input = str(input)
+            unicode(input, "ascii")
+            labels = input.split(".")
+
+        trailing_dot = u''
+        if labels:
+            if not labels[-1]:
+                trailing_dot = u'.'
+                del labels[-1]
+            elif not final:
+                # Keep potentially unfinished label until the next call
+                del labels[-1]
+                if labels:
+                    trailing_dot = u'.'
+
+        result = []
+        size = 0
+        for label in labels:
+            result.append(ToUnicode(label))
+            if size:
+                size += 1
+            size += len(label)
 
-class IncrementalDecoder(codecs.IncrementalDecoder):
-    def decode(self, input, final=False):
-        return Codec().decode(input, self.errors)[0]
+        result = u".".join(result) + trailing_dot
+        size += len(trailing_dot)
+        return (result, size)
 
 class StreamWriter(Codec,codecs.StreamWriter):
     pass
index 22d9060565025beb8b975659b6141a91bb571968..6ea49cc5f9253345ec3c694bee35de7d5d736e5b 100644 (file)
@@ -781,9 +781,18 @@ class NameprepTest(unittest.TestCase):
                 except Exception,e:
                     raise test_support.TestFailed("Test 3.%d: %s" % (pos+1, str(e)))
 
-class CodecTest(unittest.TestCase):
-    def test_builtin(self):
+class IDNACodecTest(unittest.TestCase):
+    def test_builtin_decode(self):
         self.assertEquals(unicode("python.org", "idna"), u"python.org")
+        self.assertEquals(unicode("python.org.", "idna"), u"python.org.")
+        self.assertEquals(unicode("xn--pythn-mua.org", "idna"), u"pyth\xf6n.org")
+        self.assertEquals(unicode("xn--pythn-mua.org.", "idna"), u"pyth\xf6n.org.")
+
+    def test_builtin_encode(self):
+        self.assertEquals(u"python.org".encode("idna"), "python.org")
+        self.assertEquals("python.org.".encode("idna"), "python.org.")
+        self.assertEquals(u"pyth\xf6n.org".encode("idna"), "xn--pythn-mua.org")
+        self.assertEquals(u"pyth\xf6n.org.".encode("idna"), "xn--pythn-mua.org.")
 
     def test_stream(self):
         import StringIO
@@ -791,6 +800,64 @@ class CodecTest(unittest.TestCase):
         r.read(3)
         self.assertEquals(r.read(), u"")
 
+    def test_incremental_decode(self):
+        self.assertEquals(
+            "".join(codecs.iterdecode("python.org", "idna")),
+            u"python.org"
+        )
+        self.assertEquals(
+            "".join(codecs.iterdecode("python.org.", "idna")),
+            u"python.org."
+        )
+        self.assertEquals(
+            "".join(codecs.iterdecode("xn--pythn-mua.org.", "idna")),
+            u"pyth\xf6n.org."
+        )
+        self.assertEquals(
+            "".join(codecs.iterdecode("xn--pythn-mua.org.", "idna")),
+            u"pyth\xf6n.org."
+        )
+
+        decoder = codecs.getincrementaldecoder("idna")()
+        self.assertEquals(decoder.decode("xn--xam", ), u"")
+        self.assertEquals(decoder.decode("ple-9ta.o", ), u"\xe4xample.")
+        self.assertEquals(decoder.decode(u"rg"), u"")
+        self.assertEquals(decoder.decode(u"", True), u"org")
+
+        decoder.reset()
+        self.assertEquals(decoder.decode("xn--xam", ), u"")
+        self.assertEquals(decoder.decode("ple-9ta.o", ), u"\xe4xample.")
+        self.assertEquals(decoder.decode("rg."), u"org.")
+        self.assertEquals(decoder.decode("", True), u"")
+
+    def test_incremental_encode(self):
+        self.assertEquals(
+            "".join(codecs.iterencode(u"python.org", "idna")),
+            "python.org"
+        )
+        self.assertEquals(
+            "".join(codecs.iterencode(u"python.org.", "idna")),
+            "python.org."
+        )
+        self.assertEquals(
+            "".join(codecs.iterencode(u"pyth\xf6n.org.", "idna")),
+            "xn--pythn-mua.org."
+        )
+        self.assertEquals(
+            "".join(codecs.iterencode(u"pyth\xf6n.org.", "idna")),
+            "xn--pythn-mua.org."
+        )
+
+        encoder = codecs.getincrementalencoder("idna")()
+        self.assertEquals(encoder.encode(u"\xe4x"), "")
+        self.assertEquals(encoder.encode(u"ample.org"), "xn--xample-9ta.")
+        self.assertEquals(encoder.encode(u"", True), "org")
+
+        encoder.reset()
+        self.assertEquals(encoder.encode(u"\xe4x"), "")
+        self.assertEquals(encoder.encode(u"ample.org."), "xn--xample-9ta.org.")
+        self.assertEquals(encoder.encode(u"", True), "")
+
 class CodecsModuleTest(unittest.TestCase):
 
     def test_decode(self):
@@ -1158,7 +1225,7 @@ def test_main():
         PunycodeTest,
         UnicodeInternalTest,
         NameprepTest,
-        CodecTest,
+        IDNACodecTest,
         CodecsModuleTest,
         StreamReaderTest,
         Str2StrTest,