]> granicus.if.org Git - python/commitdiff
Close #1767933: Badly formed XML using etree and utf-16. Patch by Serhiy Storchaka...
authorEli Bendersky <eliben@gmail.com>
Sun, 15 Jul 2012 03:02:22 +0000 (06:02 +0300)
committerEli Bendersky <eliben@gmail.com>
Sun, 15 Jul 2012 03:02:22 +0000 (06:02 +0300)
Doc/library/xml.etree.elementtree.rst
Lib/test/test_xml_etree.py
Lib/xml/etree/ElementTree.py

index 335a6e2cef63a80be83ef764fe89626e631e2d9f..3c2ddd3cd0503aeedb45ad8869bd35809fbd82f2 100644 (file)
@@ -659,7 +659,6 @@ ElementTree Objects
       should be added to the file.  Use False for never, True for always, None
       for only if not US-ASCII or UTF-8 or Unicode (default is None).  *method* is
       either ``"xml"``, ``"html"`` or ``"text"`` (default is ``"xml"``).
-      Returns an (optionally) encoded string.
 
 This is the XML file that is going to be manipulated::
 
index c1fc955da6a7262c25858413ff2514f8e938a9db..d90f9780c529b97524ffbd7731f0331a6c7603a0 100644 (file)
@@ -21,7 +21,7 @@ import unittest
 import weakref
 
 from test import support
-from test.support import findfile, import_fresh_module, gc_collect
+from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
 
 pyET = None
 ET = None
@@ -888,65 +888,6 @@ def check_encoding(encoding):
     """
     ET.XML("<?xml version='1.0' encoding='%s'?><xml />" % encoding)
 
-def encoding():
-    r"""
-    Test encoding issues.
-
-    >>> elem = ET.Element("tag")
-    >>> elem.text = "abc"
-    >>> serialize(elem)
-    '<tag>abc</tag>'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag>abc</tag>'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag>abc</tag>'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>abc</tag>"
-
-    >>> elem.text = "<&\"\'>"
-    >>> serialize(elem)
-    '<tag>&lt;&amp;"\'&gt;</tag>'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag>&lt;&amp;"\'&gt;</tag>'
-    >>> serialize(elem, encoding="us-ascii") # cdata characters
-    b'<tag>&lt;&amp;"\'&gt;</tag>'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag>&lt;&amp;"\'&gt;</tag>'
-
-    >>> elem.attrib["key"] = "<&\"\'>"
-    >>> elem.text = None
-    >>> serialize(elem)
-    '<tag key="&lt;&amp;&quot;\'&gt;" />'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag key="&lt;&amp;&quot;\'&gt;" />'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag key="&lt;&amp;&quot;\'&gt;" />'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="&lt;&amp;&quot;\'&gt;" />'
-
-    >>> elem.text = '\xe5\xf6\xf6<>'
-    >>> elem.attrib.clear()
-    >>> serialize(elem)
-    '<tag>\xe5\xf6\xf6&lt;&gt;</tag>'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b"<?xml version='1.0' encoding='iso-8859-1'?>\n<tag>\xe5\xf6\xf6&lt;&gt;</tag>"
-
-    >>> elem.attrib["key"] = '\xe5\xf6\xf6<>'
-    >>> elem.text = None
-    >>> serialize(elem)
-    '<tag key="\xe5\xf6\xf6&lt;&gt;" />'
-    >>> serialize(elem, encoding="utf-8")
-    b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />'
-    >>> serialize(elem, encoding="us-ascii")
-    b'<tag key="&#229;&#246;&#246;&lt;&gt;" />'
-    >>> serialize(elem, encoding="iso-8859-1")
-    b'<?xml version=\'1.0\' encoding=\'iso-8859-1\'?>\n<tag key="\xe5\xf6\xf6&lt;&gt;" />'
-    """
-
 def methods():
     r"""
     Test serialization methods.
@@ -2166,16 +2107,185 @@ class ElementSlicingTest(unittest.TestCase):
         self.assertEqual(self._subelem_tags(e), ['a1'])
 
 
-class StringIOTest(unittest.TestCase):
+class IOTest(unittest.TestCase):
+    def tearDown(self):
+        unlink(TESTFN)
+
+    def test_encoding(self):
+        # Test encoding issues.
+        elem = ET.Element("tag")
+        elem.text = "abc"
+        self.assertEqual(serialize(elem), '<tag>abc</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag>abc</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag>abc</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag>abc</tag>" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.text = "<&\"\'>"
+        self.assertEqual(serialize(elem), '<tag>&lt;&amp;"\'&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag>&lt;&amp;"\'&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag>&lt;&amp;"\'&gt;</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag>&lt;&amp;\"'&gt;</tag>" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.attrib["key"] = "<&\"\'>"
+        self.assertEqual(serialize(elem), '<tag key="&lt;&amp;&quot;\'&gt;" />')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag key=\"&lt;&amp;&quot;'&gt;\" />" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.text = '\xe5\xf6\xf6<>'
+        self.assertEqual(serialize(elem), '<tag>\xe5\xf6\xf6&lt;&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag>åöö&lt;&gt;</tag>" % enc).encode(enc))
+
+        elem = ET.Element("tag")
+        elem.attrib["key"] = '\xe5\xf6\xf6<>'
+        self.assertEqual(serialize(elem), '<tag key="\xe5\xf6\xf6&lt;&gt;" />')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                b'<tag key="&#229;&#246;&#246;&lt;&gt;" />')
+        for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                    ("<?xml version='1.0' encoding='%s'?>\n"
+                     "<tag key=\"åöö&lt;&gt;\" />" % enc).encode(enc))
+
+    def test_write_to_filename(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        tree.write(TESTFN)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_text_file(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        with open(TESTFN, 'w', encoding='utf-8') as f:
+            tree.write(f, encoding='unicode')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_binary_file(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        with open(TESTFN, 'wb') as f:
+            tree.write(f)
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_binary_file_with_bom(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        # test BOM writing to buffered file
+        with open(TESTFN, 'wb') as f:
+            tree.write(f, encoding='utf-16')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(),
+                    '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                    '''<site />'''.encode("utf-16"))
+        # test BOM writing to non-buffered file
+        with open(TESTFN, 'wb', buffering=0) as f:
+            tree.write(f, encoding='utf-16')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(),
+                    '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                    '''<site />'''.encode("utf-16"))
+
     def test_read_from_stringio(self):
         tree = ET.ElementTree()
-        stream = io.StringIO()
-        stream.write('''<?xml version="1.0"?><site></site>''')
-        stream.seek(0)
+        stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
         tree.parse(stream)
+        self.assertEqual(tree.getroot().tag, 'site')
 
+    def test_write_to_stringio(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        stream = io.StringIO()
+        tree.write(stream, encoding='unicode')
+        self.assertEqual(stream.getvalue(), '''<site />''')
+
+    def test_read_from_bytesio(self):
+        tree = ET.ElementTree()
+        raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+        tree.parse(raw)
+        self.assertEqual(tree.getroot().tag, 'site')
+
+    def test_write_to_bytesio(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        tree.write(raw)
+        self.assertEqual(raw.getvalue(), b'''<site />''')
+
+    class dummy:
+        pass
+
+    def test_read_from_user_text_reader(self):
+        stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+        reader = self.dummy()
+        reader.read = stream.read
+        tree = ET.ElementTree()
+        tree.parse(reader)
         self.assertEqual(tree.getroot().tag, 'site')
 
+    def test_write_to_user_text_writer(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        stream = io.StringIO()
+        writer = self.dummy()
+        writer.write = stream.write
+        tree.write(writer, encoding='unicode')
+        self.assertEqual(stream.getvalue(), '''<site />''')
+
+    def test_read_from_user_binary_reader(self):
+        raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+        reader = self.dummy()
+        reader.read = raw.read
+        tree = ET.ElementTree()
+        tree.parse(reader)
+        self.assertEqual(tree.getroot().tag, 'site')
+        tree = ET.ElementTree()
+
+    def test_write_to_user_binary_writer(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        writer = self.dummy()
+        writer.write = raw.write
+        tree.write(writer)
+        self.assertEqual(raw.getvalue(), b'''<site />''')
+
+    def test_write_to_user_binary_writer_with_bom(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        writer = self.dummy()
+        writer.write = raw.write
+        writer.seekable = lambda: True
+        writer.tell = raw.tell
+        tree.write(writer, encoding="utf-16")
+        self.assertEqual(raw.getvalue(),
+                '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                '''<site />'''.encode("utf-16"))
+
 
 class ParseErrorTest(unittest.TestCase):
     def test_subclass(self):
@@ -2299,7 +2409,7 @@ def test_main(module=None):
     test_classes = [
         ElementSlicingTest,
         BasicElementTest,
-        StringIOTest,
+        IOTest,
         ParseErrorTest,
         XincludeTest,
         ElementTreeTest,
index 61fe1550355dafc14e24620546056398ec546559..10bf84992ed656d630600dc4a6a22c8f42a166be 100644 (file)
@@ -100,6 +100,8 @@ VERSION = "1.3.0"
 import sys
 import re
 import warnings
+import io
+import contextlib
 
 from . import ElementPath
 
@@ -792,59 +794,38 @@ class ElementTree:
     #     None for only if not US-ASCII or UTF-8 or Unicode.  None is default.
 
     def write(self, file_or_filename,
-              # keyword arguments
               encoding=None,
               xml_declaration=None,
               default_namespace=None,
               method=None):
-        # assert self._root is not None
         if not method:
             method = "xml"
         elif method not in _serialize:
-            # FIXME: raise an ImportError for c14n if ElementC14N is missing?
             raise ValueError("unknown method %r" % method)
         if not encoding:
             if method == "c14n":
                 encoding = "utf-8"
             else:
                 encoding = "us-ascii"
-        elif encoding == str:  # lxml.etree compatibility.
-            encoding = "unicode"
         else:
             encoding = encoding.lower()
-        if hasattr(file_or_filename, "write"):
-            file = file_or_filename
-        else:
-            if encoding != "unicode":
-                file = open(file_or_filename, "wb")
+        with _get_writer(file_or_filename, encoding) as write:
+            if method == "xml" and (xml_declaration or
+                    (xml_declaration is None and
+                     encoding not in ("utf-8", "us-ascii", "unicode"))):
+                declared_encoding = encoding
+                if encoding == "unicode":
+                    # Retrieve the default encoding for the xml declaration
+                    import locale
+                    declared_encoding = locale.getpreferredencoding()
+                write("<?xml version='1.0' encoding='%s'?>\n" % (
+                    declared_encoding,))
+            if method == "text":
+                _serialize_text(write, self._root)
             else:
-                file = open(file_or_filename, "w")
-        if encoding != "unicode":
-            def write(text):
-                try:
-                    return file.write(text.encode(encoding,
-                                                  "xmlcharrefreplace"))
-                except (TypeError, AttributeError):
-                    _raise_serialization_error(text)
-        else:
-            write = file.write
-        if method == "xml" and (xml_declaration or
-                (xml_declaration is None and
-                 encoding not in ("utf-8", "us-ascii", "unicode"))):
-            declared_encoding = encoding
-            if encoding == "unicode":
-                # Retrieve the default encoding for the xml declaration
-                import locale
-                declared_encoding = locale.getpreferredencoding()
-            write("<?xml version='1.0' encoding='%s'?>\n" % declared_encoding)
-        if method == "text":
-            _serialize_text(write, self._root)
-        else:
-            qnames, namespaces = _namespaces(self._root, default_namespace)
-            serialize = _serialize[method]
-            serialize(write, self._root, qnames, namespaces)
-        if file_or_filename is not file:
-            file.close()
+                qnames, namespaces = _namespaces(self._root, default_namespace)
+                serialize = _serialize[method]
+                serialize(write, self._root, qnames, namespaces)
 
     def write_c14n(self, file):
         # lxml.etree compatibility.  use output method instead
@@ -853,6 +834,58 @@ class ElementTree:
 # --------------------------------------------------------------------
 # serialization support
 
+@contextlib.contextmanager
+def _get_writer(file_or_filename, encoding):
+    # returns text write method and release all resourses after using
+    try:
+        write = file_or_filename.write
+    except AttributeError:
+        # file_or_filename is a file name
+        if encoding == "unicode":
+            file = open(file_or_filename, "w")
+        else:
+            file = open(file_or_filename, "w", encoding=encoding,
+                        errors="xmlcharrefreplace")
+        with file:
+            yield file.write
+    else:
+        # file_or_filename is a file-like object
+        # encoding determines if it is a text or binary writer
+        if encoding == "unicode":
+            # use a text writer as is
+            yield write
+        else:
+            # wrap a binary writer with TextIOWrapper
+            with contextlib.ExitStack() as stack:
+                if isinstance(file_or_filename, io.BufferedIOBase):
+                    file = file_or_filename
+                elif isinstance(file_or_filename, io.RawIOBase):
+                    file = io.BufferedWriter(file_or_filename)
+                    # Keep the original file open when the BufferedWriter is
+                    # destroyed
+                    stack.callback(file.detach)
+                else:
+                    # This is to handle passed objects that aren't in the
+                    # IOBase hierarchy, but just have a write method
+                    file = io.BufferedIOBase()
+                    file.writable = lambda: True
+                    file.write = write
+                    try:
+                        # TextIOWrapper uses this methods to determine
+                        # if BOM (for UTF-16, etc) should be added
+                        file.seekable = file_or_filename.seekable
+                        file.tell = file_or_filename.tell
+                    except AttributeError:
+                        pass
+                file = io.TextIOWrapper(file,
+                                        encoding=encoding,
+                                        errors="xmlcharrefreplace",
+                                        newline="\n")
+                # Keep the original file open when the TextIOWrapper is
+                # destroyed
+                stack.callback(file.detach)
+                yield file.write
+
 def _namespaces(elem, default_namespace=None):
     # identify namespaces used in this tree
 
@@ -1134,22 +1167,13 @@ def _escape_attrib_html(text):
 # @defreturn string
 
 def tostring(element, encoding=None, method=None):
-    class dummy:
-        pass
-    data = []
-    file = dummy()
-    file.write = data.append
-    ElementTree(element).write(file, encoding, method=method)
-    if encoding in (str, "unicode"):
-        return "".join(data)
-    else:
-        return b"".join(data)
+    stream = io.StringIO() if encoding == 'unicode' else io.BytesIO()
+    ElementTree(element).write(stream, encoding, method=method)
+    return stream.getvalue()
 
 ##
 # Generates a string representation of an XML element, including all
-# subelements.  If encoding is False, the string is returned as a
-# sequence of string fragments; otherwise it is a sequence of
-# bytestrings.
+# subelements.
 #
 # @param element An Element instance.
 # @keyparam encoding Optional output encoding (default is US-ASCII).
@@ -1161,13 +1185,15 @@ def tostring(element, encoding=None, method=None):
 # @since 1.3
 
 def tostringlist(element, encoding=None, method=None):
-    class dummy:
-        pass
     data = []
-    file = dummy()
-    file.write = data.append
-    ElementTree(element).write(file, encoding, method=method)
-    # FIXME: merge small fragments into larger parts
+    class DataStream(io.BufferedIOBase):
+        def writable(self):
+            return True
+
+        def write(self, b):
+            data.append(b)
+
+    ElementTree(element).write(DataStream(), encoding, method=method)
     return data
 
 ##