]> granicus.if.org Git - python/commitdiff
Issue #14128: Exposing Element as an actual type from _elementtree, rather than a...
authorEli Bendersky <eliben@gmail.com>
Sun, 4 Mar 2012 05:14:03 +0000 (07:14 +0200)
committerEli Bendersky <eliben@gmail.com>
Sun, 4 Mar 2012 05:14:03 +0000 (07:14 +0200)
This makes the C implementation more aligned with the Python implementation.
Also added some tests to ensure that Element is now a type and that it can
be subclassed.

Lib/test/test_xml_etree.py
Lib/test/test_xml_etree_c.py
Lib/xml/etree/ElementTree.py
Modules/_elementtree.c

index 58fdcd4c5c60e87df2887459d2820bc564b4ef1a..869a1597f70a1d44e89e86bf717c22116109cf87 100644 (file)
@@ -1901,16 +1901,51 @@ class CleanContext(object):
 class TestAcceleratorNotImported(unittest.TestCase):
     # Test that the C accelerator was not imported for pyET
     def test_correct_import_pyET(self):
-        self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree')
+        self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree')
+
+
+class TestElementClass(unittest.TestCase):
+    def test_Element_is_a_type(self):
+        self.assertIsInstance(ET.Element, type)
+
+    def test_Element_subclass_trivial(self):
+        class MyElement(ET.Element):
+            pass
+
+        mye = MyElement('foo')
+        self.assertIsInstance(mye, ET.Element)
+        self.assertIsInstance(mye, MyElement)
+        self.assertEqual(mye.tag, 'foo')
+
+    def test_Element_subclass_constructor(self):
+        class MyElement(ET.Element):
+            def __init__(self, tag, attrib={}, **extra):
+                super(MyElement, self).__init__(tag + '__', attrib, **extra)
+
+        mye = MyElement('foo', {'a': 1, 'b': 2}, c=3, d=4)
+        self.assertEqual(mye.tag, 'foo__')
+        self.assertEqual(sorted(mye.items()),
+            [('a', 1), ('b', 2), ('c', 3), ('d', 4)])
+
+    def test_Element_subclass_new_method(self):
+        class MyElement(ET.Element):
+            def newmethod(self):
+                return self.tag
+
+        mye = MyElement('joe')
+        self.assertEqual(mye.newmethod(), 'joe')
 
 
 def test_main(module=pyET):
     from test import test_xml_etree
 
+    # Run the tests specific to the Python implementation
+    support.run_unittest(TestAcceleratorNotImported)
+
     # The same doctests are used for both the Python and the C implementations
     test_xml_etree.ET = module
 
-    support.run_unittest(TestAcceleratorNotImported)
+    support.run_unittest(TestElementClass)
 
     # XXX the C module should give the same warnings as the Python module
     with CleanContext(quiet=(module is not pyET)):
index a73d0c4b82bbbb167cd1f8b2cfdeb232946baae4..cfd18ee66bfa2f0e1b2412c86d88900be56354ec 100644 (file)
@@ -46,14 +46,22 @@ class MiscTests(unittest.TestCase):
         finally:
             data = None
 
+@unittest.skipUnless(cET, 'requires _elementtree')
+class TestAliasWorking(unittest.TestCase):
+    # Test that the cET alias module is alive
+    def test_alias_working(self):
+        e = cET_alias.Element('foo')
+        self.assertEqual(e.tag, 'foo')
+        
+
 @unittest.skipUnless(cET, 'requires _elementtree')
 class TestAcceleratorImported(unittest.TestCase):
     # Test that the C accelerator was imported, as expected
     def test_correct_import_cET(self):
-        self.assertEqual(cET.Element.__module__, '_elementtree')
+        self.assertEqual(cET.SubElement.__module__, '_elementtree')
 
     def test_correct_import_cET_alias(self):
-        self.assertEqual(cET_alias.Element.__module__, '_elementtree')
+        self.assertEqual(cET_alias.SubElement.__module__, '_elementtree')
 
 
 def test_main():
@@ -61,13 +69,15 @@ def test_main():
 
     # Run the tests specific to the C implementation
     support.run_doctest(test_xml_etree_c, verbosity=True)
-
-    support.run_unittest(MiscTests, TestAcceleratorImported)
+    support.run_unittest(
+        MiscTests,
+        TestAliasWorking,
+        TestAcceleratorImported
+        )
 
     # Run the same test suite as the Python module
     test_xml_etree.test_main(module=cET)
-    # Exercise the deprecated alias
-    test_xml_etree.test_main(module=cET_alias)
+
 
 if __name__ == '__main__':
     test_main()
index defef0d0ebca36dc35f85c1d22490763feffe736..a864fa5263d62cadbc12f038cdde47aff5660f35 100644 (file)
@@ -101,7 +101,6 @@ import sys
 import re
 import warnings
 
-
 class _SimpleElementPath:
     # emulate pre-1.2 find/findtext/findall behaviour
     def find(self, element, tag, namespaces=None):
index 78d8177d68bbf9edb076b942e0ebd9ac989f2fb6..179cadc5f5916a53d5da8d6c00c7dc9db41546be 100644 (file)
@@ -191,7 +191,7 @@ list_join(PyObject* list)
 }
 
 /* -------------------------------------------------------------------- */
-/* the element type */
+/* the Element type */
 
 typedef struct {
 
@@ -236,10 +236,10 @@ static PyTypeObject Element_Type;
 #define Element_CheckExact(op) (Py_TYPE(op) == &Element_Type)
 
 /* -------------------------------------------------------------------- */
-/* element constructor and destructor */
+/* Element constructors and destructor */
 
 LOCAL(int)
-element_new_extra(ElementObject* self, PyObject* attrib)
+create_extra(ElementObject* self, PyObject* attrib)
 {
     self->extra = PyObject_Malloc(sizeof(ElementObjectExtra));
     if (!self->extra)
@@ -259,7 +259,7 @@ element_new_extra(ElementObject* self, PyObject* attrib)
 }
 
 LOCAL(void)
-element_dealloc_extra(ElementObject* self)
+dealloc_extra(ElementObject* self)
 {
     int i;
 
@@ -274,8 +274,11 @@ element_dealloc_extra(ElementObject* self)
     PyObject_Free(self->extra);
 }
 
+/* Convenience internal function to create new Element objects with the given
+ * tag and attributes.
+*/
 LOCAL(PyObject*)
-element_new(PyObject* tag, PyObject* attrib)
+create_new_element(PyObject* tag, PyObject* attrib)
 {
     ElementObject* self;
 
@@ -290,16 +293,10 @@ element_new(PyObject* tag, PyObject* attrib)
     self->extra = NULL;
 
     if (attrib != Py_None) {
-
-        if (element_new_extra(self, attrib) < 0) {
+        if (create_extra(self, attrib) < 0) {
             PyObject_Del(self);
             return NULL;
         }
-
-        self->extra->length = 0;
-        self->extra->allocated = STATIC_CHILDREN;
-        self->extra->children = self->extra->_children;
-
     }
 
     Py_INCREF(tag);
@@ -316,6 +313,86 @@ element_new(PyObject* tag, PyObject* attrib)
     return (PyObject*) self;
 }
 
+static PyObject *
+element_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
+{
+    ElementObject *e = (ElementObject *)type->tp_alloc(type, 0);
+    if (e != NULL) {
+        Py_INCREF(Py_None);
+        e->tag = Py_None;
+
+        Py_INCREF(Py_None);
+        e->text = Py_None;
+
+        Py_INCREF(Py_None);
+        e->tail = Py_None;
+
+        e->extra = NULL;
+    }
+    return (PyObject *)e;
+}
+
+static int
+element_init(PyObject *self, PyObject *args, PyObject *kwds)
+{
+    PyObject *tag;
+    PyObject *tmp;
+    PyObject *attrib = NULL;
+    ElementObject *self_elem;
+
+    if (!PyArg_ParseTuple(args, "O|O!:Element", &tag, &PyDict_Type, &attrib))
+        return -1;
+
+    if (attrib || kwds) {
+        attrib = (attrib) ? PyDict_Copy(attrib) : PyDict_New();
+        if (!attrib)
+            return -1;
+        if (kwds)
+            PyDict_Update(attrib, kwds);
+    } else {
+        Py_INCREF(Py_None);
+        attrib = Py_None;
+    }
+
+    self_elem = (ElementObject *)self;
+
+    /* Use None for empty dictionaries */
+    if (PyDict_CheckExact(attrib) && PyDict_Size(attrib) == 0) {
+        Py_INCREF(Py_None);
+        attrib = Py_None;
+    }
+
+    if (attrib != Py_None) {
+        if (create_extra(self_elem, attrib) < 0) {
+            PyObject_Del(self_elem);
+            return -1;
+        }
+    }
+
+    /* If create_extra needed attrib, it took a reference to it, so we can
+     * release ours anyway.
+    */
+    Py_DECREF(attrib);
+
+    /* Replace the objects already pointed to by tag, text and tail. */
+    tmp = self_elem->tag;
+    self_elem->tag = tag;
+    Py_INCREF(tag);
+    Py_DECREF(tmp);
+
+    tmp = self_elem->text;
+    self_elem->text = Py_None;
+    Py_INCREF(Py_None);
+    Py_DECREF(JOIN_OBJ(tmp));
+
+    tmp = self_elem->tail;
+    self_elem->tail = Py_None;
+    Py_INCREF(Py_None);
+    Py_DECREF(JOIN_OBJ(tmp));
+
+    return 0;
+}
+
 LOCAL(int)
 element_resize(ElementObject* self, int extra)
 {
@@ -326,7 +403,7 @@ element_resize(ElementObject* self, int extra)
        elements.  set an exception and return -1 if allocation failed */
 
     if (!self->extra)
-        element_new_extra(self, NULL);
+        create_extra(self, NULL);
 
     size = self->extra->length + extra;
 
@@ -443,35 +520,6 @@ element_get_tail(ElementObject* self)
     return res;
 }
 
-static PyObject*
-element(PyObject* self, PyObject* args, PyObject* kw)
-{
-    PyObject* elem;
-
-    PyObject* tag;
-    PyObject* attrib = NULL;
-    if (!PyArg_ParseTuple(args, "O|O!:Element", &tag,
-                          &PyDict_Type, &attrib))
-        return NULL;
-
-    if (attrib || kw) {
-        attrib = (attrib) ? PyDict_Copy(attrib) : PyDict_New();
-        if (!attrib)
-            return NULL;
-        if (kw)
-            PyDict_Update(attrib, kw);
-    } else {
-        Py_INCREF(Py_None);
-        attrib = Py_None;
-    }
-
-    elem = element_new(tag, attrib);
-
-    Py_DECREF(attrib);
-
-    return elem;
-}
-
 static PyObject*
 subelement(PyObject* self, PyObject* args, PyObject* kw)
 {
@@ -496,7 +544,7 @@ subelement(PyObject* self, PyObject* args, PyObject* kw)
         attrib = Py_None;
     }
 
-    elem = element_new(tag, attrib);
+    elem = create_new_element(tag, attrib);
 
     Py_DECREF(attrib);
 
@@ -512,7 +560,7 @@ static void
 element_dealloc(ElementObject* self)
 {
     if (self->extra)
-        element_dealloc_extra(self);
+        dealloc_extra(self);
 
     /* discard attributes */
     Py_DECREF(self->tag);
@@ -521,7 +569,7 @@ element_dealloc(ElementObject* self)
 
     RELEASE(sizeof(ElementObject), "destroy element");
 
-    PyObject_Del(self);
+    Py_TYPE(self)->tp_free((PyObject *)self);
 }
 
 /* -------------------------------------------------------------------- */
@@ -547,7 +595,7 @@ element_clear(ElementObject* self, PyObject* args)
         return NULL;
 
     if (self->extra) {
-        element_dealloc_extra(self);
+        dealloc_extra(self);
         self->extra = NULL;
     }
 
@@ -571,7 +619,7 @@ element_copy(ElementObject* self, PyObject* args)
     if (!PyArg_ParseTuple(args, ":__copy__"))
         return NULL;
 
-    element = (ElementObject*) element_new(
+    element = (ElementObject*) create_new_element(
         self->tag, (self->extra) ? self->extra->attrib : Py_None
         );
     if (!element)
@@ -634,7 +682,7 @@ element_deepcopy(ElementObject* self, PyObject* args)
         attrib = Py_None;
     }
 
-    element = (ElementObject*) element_new(tag, attrib);
+    element = (ElementObject*) create_new_element(tag, attrib);
 
     Py_DECREF(tag);
     Py_DECREF(attrib);
@@ -1029,7 +1077,7 @@ element_insert(ElementObject* self, PyObject* args)
         return NULL;
 
     if (!self->extra)
-        element_new_extra(self, NULL);
+        create_extra(self, NULL);
 
     if (index < 0) {
         index += self->extra->length;
@@ -1100,7 +1148,7 @@ element_makeelement(PyObject* self, PyObject* args, PyObject* kw)
     if (!attrib)
         return NULL;
 
-    elem = element_new(tag, attrib);
+    elem = create_new_element(tag, attrib);
 
     Py_DECREF(attrib);
 
@@ -1154,7 +1202,10 @@ element_remove(ElementObject* self, PyObject* args)
 static PyObject*
 element_repr(ElementObject* self)
 {
-    return PyUnicode_FromFormat("<Element %R at %p>", self->tag, self);
+    if (self->tag)
+        return PyUnicode_FromFormat("<Element %R at %p>", self->tag, self);
+    else
+        return PyUnicode_FromFormat("<Element at %p>", self);
 }
 
 static PyObject*
@@ -1168,7 +1219,7 @@ element_set(ElementObject* self, PyObject* args)
         return NULL;
 
     if (!self->extra)
-        element_new_extra(self, NULL);
+        create_extra(self, NULL);
 
     attrib = element_get_attrib(self);
     if (!attrib)
@@ -1284,7 +1335,7 @@ element_ass_subscr(PyObject* self_, PyObject* item, PyObject* value)
         PyObject* seq = NULL;
 
         if (!self->extra)
-            element_new_extra(self, NULL);
+            create_extra(self, NULL);
 
         if (PySlice_GetIndicesEx(item,
                 self->extra->length,
@@ -1448,7 +1499,7 @@ element_getattro(ElementObject* self, PyObject* nameobj)
     } else if (strcmp(name, "attrib") == 0) {
         PyErr_Clear();
         if (!self->extra)
-            element_new_extra(self, NULL);
+            create_extra(self, NULL);
         res = element_get_attrib(self);
     }
 
@@ -1484,7 +1535,7 @@ element_setattr(ElementObject* self, const char* name, PyObject* value)
         Py_INCREF(self->tail);
     } else if (strcmp(name, "attrib") == 0) {
         if (!self->extra)
-            element_new_extra(self, NULL);
+            create_extra(self, NULL);
         Py_DECREF(self->extra->attrib);
         self->extra->attrib = value;
         Py_INCREF(self->extra->attrib);
@@ -1516,31 +1567,41 @@ static PyTypeObject Element_Type = {
     PyVarObject_HEAD_INIT(NULL, 0)
     "Element", sizeof(ElementObject), 0,
     /* methods */
-    (destructor)element_dealloc, /* tp_dealloc */
-    0, /* tp_print */
-    0, /* tp_getattr */
-    (setattrfunc)element_setattr, /* tp_setattr */
-    0, /* tp_reserved */
-    (reprfunc)element_repr, /* tp_repr */
-    0, /* tp_as_number */
-    &element_as_sequence, /* tp_as_sequence */
-    &element_as_mapping, /* tp_as_mapping */
-    0, /* tp_hash */
-    0, /* tp_call */
-    0, /* tp_str */
-    (getattrofunc)element_getattro, /* tp_getattro */
-    0, /* tp_setattro */
-    0, /* tp_as_buffer */
-    Py_TPFLAGS_DEFAULT, /* tp_flags */
-    0, /* tp_doc */
-    0, /* tp_traverse */
-    0, /* tp_clear */
-    0, /* tp_richcompare */
-    0, /* tp_weaklistoffset */
-    0, /* tp_iter */
-    0, /* tp_iternext */
-    element_methods, /* tp_methods */
-    0, /* tp_members */
+    (destructor)element_dealloc,                    /* tp_dealloc */
+    0,                                              /* tp_print */
+    0,                                              /* tp_getattr */
+    (setattrfunc)element_setattr,                   /* tp_setattr */
+    0,                                              /* tp_reserved */
+    (reprfunc)element_repr,                         /* tp_repr */
+    0,                                              /* tp_as_number */
+    &element_as_sequence,                           /* tp_as_sequence */
+    &element_as_mapping,                            /* tp_as_mapping */
+    0,                                              /* tp_hash */
+    0,                                              /* tp_call */
+    0,                                              /* tp_str */
+    (getattrofunc)element_getattro,                 /* tp_getattro */
+    0,                                              /* tp_setattro */
+    0,                                              /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,       /* tp_flags */
+    0,                                              /* tp_doc */
+    0,                                              /* tp_traverse */
+    0,                                              /* tp_clear */
+    0,                                              /* tp_richcompare */
+    0,                                              /* tp_weaklistoffset */
+    0,                                              /* tp_iter */
+    0,                                              /* tp_iternext */
+    element_methods,                                /* tp_methods */
+    0,                                              /* tp_members */
+    0,                                              /* tp_getset */
+    0,                                              /* tp_base */
+    0,                                              /* tp_dict */
+    0,                                              /* tp_descr_get */
+    0,                                              /* tp_descr_set */
+    0,                                              /* tp_dictoffset */
+    (initproc)element_init,                         /* tp_init */
+    PyType_GenericAlloc,                            /* tp_alloc */
+    element_new,                                    /* tp_new */
+    0,                                              /* tp_free */
 };
 
 /* ==================================================================== */
@@ -1666,7 +1727,7 @@ treebuilder_handle_start(TreeBuilderObject* self, PyObject* tag,
         self->data = NULL;
     }
 
-    node = element_new(tag, attrib);
+    node = create_new_element(tag, attrib);
     if (!node)
         return NULL;
 
@@ -2801,7 +2862,6 @@ static PyTypeObject XMLParser_Type = {
 /* python module interface */
 
 static PyMethodDef _functions[] = {
-    {"Element", (PyCFunction) element, METH_VARARGS|METH_KEYWORDS},
     {"SubElement", (PyCFunction) subelement, METH_VARARGS|METH_KEYWORDS},
     {"TreeBuilder", (PyCFunction) treebuilder, METH_VARARGS},
 #if defined(USE_EXPAT)
@@ -2911,5 +2971,8 @@ PyInit__elementtree(void)
     Py_INCREF(elementtree_parseerror_obj);
     PyModule_AddObject(m, "ParseError", elementtree_parseerror_obj);
 
+    Py_INCREF((PyObject *)&Element_Type);
+    PyModule_AddObject(m, "Element", (PyObject *)&Element_Type);
+
     return m;
 }