]> granicus.if.org Git - python/commitdiff
Changing diapers reminded Guido that he wanted to allow for some measure
authorTim Peters <tim.peters@gmail.com>
Wed, 14 Nov 2001 23:32:33 +0000 (23:32 +0000)
committerTim Peters <tim.peters@gmail.com>
Wed, 14 Nov 2001 23:32:33 +0000 (23:32 +0000)
of multiple inheritance from a mix of new- and classic-style classes.
This is his patch, plus a start at some test cases from me.  Will check
in more, plus a NEWS blurb, later tonight.

Lib/test/test_descr.py
Objects/typeobject.c

index ca09ca97e80720799f35c152edde8961ff5ed958..cf7ef2d64a55bdce2c473c5755e436bc67dfe73f 100644 (file)
@@ -829,6 +829,53 @@ def multi():
     vereq(Frag().__int__(), 42)
     vereq(int(Frag()), 42)
 
+    # MI mixing classic and new-style classes.
+    class C:
+        def cmethod(self):
+            return "C a"
+        def all_method(self):
+            return "C b"
+
+    class M1(C, object):
+        def m1method(self):
+            return "M1 a"
+        def all_method(self):
+            return "M1 b"
+
+    vereq(M1.__mro__, (M1, C, object))
+    m = M1()
+    vereq(m.cmethod(), "C a")
+    vereq(m.m1method(), "M1 a")
+    vereq(m.all_method(), "M1 b")
+
+    class D(C):
+        def dmethod(self):
+            return "D a"
+        def all_method(self):
+            return "D b"
+
+    class M2(object, D):
+        def m2method(self):
+            return "M2 a"
+        def all_method(self):
+            return "M2 b"
+
+    vereq(M2.__mro__, (M2, object, D, C))
+    m = M2()
+    vereq(m.cmethod(), "C a")
+    vereq(m.dmethod(), "D a")
+    vereq(m.m2method(), "M2 a")
+    vereq(m.all_method(), "M2 b")
+
+    class M3(M1, object, M2):
+        def m3method(self):
+            return "M3 a"
+        def all_method(self):
+            return "M3 b"
+    # XXX Expected this (the commented-out result):
+    # vereq(M3.__mro__, (M3, M1, M2, object, D, C))
+    vereq(M3.__mro__, (M3, M1, M2, D, C, object))  # XXX ?
+
 def diamond():
     if verbose: print "Testing multiple inheritance special cases..."
     class A(object):
@@ -1016,14 +1063,6 @@ def errors():
     class Classic:
         pass
 
-    try:
-        class C(object, Classic):
-            pass
-    except TypeError:
-        pass
-    else:
-        verify(0, "inheritance from object and Classic should be illegal")
-
     try:
         class C(type(len)):
             pass
index 39214e7a6f18e158698d4a6b84458a35ee062a1a..07032f4e24e35823edf3ae5438ef56b4a20f0df8 100644 (file)
@@ -577,6 +577,47 @@ serious_order_disagreements(PyObject *left, PyObject *right)
        return 0; /* XXX later -- for now, we cheat: "don't do that" */
 }
 
+static int
+fill_classic_mro(PyObject *mro, PyObject *cls)
+{
+       PyObject *bases, *base;
+       int i, n;
+
+       assert(PyList_Check(mro));
+       assert(PyClass_Check(cls));
+       i = PySequence_Contains(mro, cls);
+       if (i < 0)
+               return -1;
+       if (!i) {
+               if (PyList_Append(mro, cls) < 0)
+                       return -1;
+       }
+       bases = ((PyClassObject *)cls)->cl_bases;
+       assert(bases && PyTuple_Check(bases));
+       n = PyTuple_GET_SIZE(bases);
+       for (i = 0; i < n; i++) {
+               base = PyTuple_GET_ITEM(bases, i);
+               if (fill_classic_mro(mro, base) < 0)
+                       return -1;
+       }
+       return 0;
+}
+
+static PyObject *
+classic_mro(PyObject *cls)
+{
+       PyObject *mro;
+
+       assert(PyClass_Check(cls));
+       mro = PyList_New(0);
+       if (mro != NULL) {
+               if (fill_classic_mro(mro, cls) == 0)
+                       return mro;
+               Py_DECREF(mro);
+       }
+       return NULL;
+}
+
 static PyObject *
 mro_implementation(PyTypeObject *type)
 {
@@ -589,9 +630,13 @@ mro_implementation(PyTypeObject *type)
        if (result == NULL)
                return NULL;
        for (i = 0; i < n; i++) {
-               PyTypeObject *base =
-                       (PyTypeObject *) PyTuple_GET_ITEM(bases, i);
-               PyObject *parentMRO = PySequence_List(base->tp_mro);
+               PyObject *base = PyTuple_GET_ITEM(bases, i);
+               PyObject *parentMRO;
+               if (PyType_Check(base))
+                       parentMRO = PySequence_List(
+                               ((PyTypeObject*)base)->tp_mro);
+               else
+                       parentMRO = classic_mro(base);
                if (parentMRO == NULL) {
                        Py_DECREF(result);
                        return NULL;
@@ -651,26 +696,34 @@ best_base(PyObject *bases)
 {
        int i, n;
        PyTypeObject *base, *winner, *candidate, *base_i;
+       PyObject *base_proto;
 
        assert(PyTuple_Check(bases));
        n = PyTuple_GET_SIZE(bases);
        assert(n > 0);
-       base = (PyTypeObject *)PyTuple_GET_ITEM(bases, 0);
-       winner = &PyBaseObject_Type;
+       base = NULL;
+       winner = NULL;
        for (i = 0; i < n; i++) {
-               base_i = (PyTypeObject *)PyTuple_GET_ITEM(bases, i);
-               if (!PyType_Check((PyObject *)base_i)) {
+               base_proto = PyTuple_GET_ITEM(bases, i);
+               if (PyClass_Check(base_proto))
+                       continue;
+               if (!PyType_Check(base_proto)) {
                        PyErr_SetString(
                                PyExc_TypeError,
                                "bases must be types");
                        return NULL;
                }
+               base_i = (PyTypeObject *)base_proto;
                if (base_i->tp_dict == NULL) {
                        if (PyType_Ready(base_i) < 0)
                                return NULL;
                }
                candidate = solid_base(base_i);
-               if (PyType_IsSubtype(winner, candidate))
+               if (winner == NULL) {
+                       winner = candidate;
+                       base = base_i;
+               }
+               else if (PyType_IsSubtype(winner, candidate))
                        ;
                else if (PyType_IsSubtype(candidate, winner)) {
                        winner = candidate;
@@ -827,6 +880,8 @@ type_new(PyTypeObject *metatype, PyObject *args, PyObject *kwds)
        for (i = 0; i < nbases; i++) {
                tmp = PyTuple_GET_ITEM(bases, i);
                tmptype = tmp->ob_type;
+               if (tmptype == &PyClass_Type)
+                       continue; /* Special case classic classes */
                if (PyType_IsSubtype(winner, tmptype))
                        continue;
                if (PyType_IsSubtype(tmptype, winner)) {
@@ -1079,16 +1134,20 @@ PyObject *
 _PyType_Lookup(PyTypeObject *type, PyObject *name)
 {
        int i, n;
-       PyObject *mro, *res, *dict;
+       PyObject *mro, *res, *base, *dict;
 
        /* Look in tp_dict of types in MRO */
        mro = type->tp_mro;
        assert(PyTuple_Check(mro));
        n = PyTuple_GET_SIZE(mro);
        for (i = 0; i < n; i++) {
-               type = (PyTypeObject *) PyTuple_GET_ITEM(mro, i);
-               assert(PyType_Check(type));
-               dict = type->tp_dict;
+               base = PyTuple_GET_ITEM(mro, i);
+               if (PyClass_Check(base))
+                       dict = ((PyClassObject *)base)->cl_dict;
+               else {
+                       assert(PyType_Check(base));
+                       dict = ((PyTypeObject *)base)->tp_dict;
+               }
                assert(dict && PyDict_Check(dict));
                res = PyDict_GetItem(dict, name);
                if (res != NULL)
@@ -1920,9 +1979,9 @@ PyType_Ready(PyTypeObject *type)
        assert(PyTuple_Check(bases));
        n = PyTuple_GET_SIZE(bases);
        for (i = 1; i < n; i++) {
-               base = (PyTypeObject *)PyTuple_GET_ITEM(bases, i);
-               assert(PyType_Check(base));
-               inherit_slots(type, base);
+               PyObject *b = PyTuple_GET_ITEM(bases, i);
+               if (PyType_Check(b))
+                       inherit_slots(type, (PyTypeObject *)b);
        }
 
        /* Some more special stuff */
@@ -1940,8 +1999,9 @@ PyType_Ready(PyTypeObject *type)
        bases = type->tp_bases;
        n = PyTuple_GET_SIZE(bases);
        for (i = 0; i < n; i++) {
-               base = (PyTypeObject *) PyTuple_GET_ITEM(bases, i);
-               if (add_subclass((PyTypeObject *)base, type) < 0)
+               PyObject *b = PyTuple_GET_ITEM(bases, i);
+               if (PyType_Check(b) &&
+                   add_subclass((PyTypeObject *)b, type) < 0)
                        goto error;
        }
 
@@ -3665,7 +3725,6 @@ fixup_slot_dispatchers(PyTypeObject *type)
 {
        slotdef *p;
        PyObject *mro, *descr;
-       PyTypeObject *base;
        PyWrapperDescrObject *d;
        int i, n, offset;
        void **ptr;
@@ -3690,13 +3749,18 @@ fixup_slot_dispatchers(PyTypeObject *type)
                do {
                        descr = NULL;
                        for (i = 0; i < n; i++) {
-                               base = (PyTypeObject *)
-                                       PyTuple_GET_ITEM(mro, i);
-                               assert(PyType_Check(base));
-                               descr = PyDict_GetItem(
-                                       base->tp_dict, p->name_strobj);
-                               if (descr != NULL)
-                                       break;
+                               PyObject *b = PyTuple_GET_ITEM(mro, i);
+                               PyObject *dict = NULL;
+                               if (PyType_Check(b))
+                                       dict = ((PyTypeObject *)b)->tp_dict;
+                               else if (PyClass_Check(b))
+                                       dict = ((PyClassObject *)b)->cl_dict;
+                               if (dict != NULL) {
+                                       descr = PyDict_GetItem(
+                                               dict, p->name_strobj);
+                                       if (descr != NULL)
+                                               break;
+                               }
                        }
                        if (descr == NULL)
                                continue;
@@ -3825,7 +3889,7 @@ super_getattro(PyObject *self, PyObject *name)
        superobject *su = (superobject *)self;
 
        if (su->obj != NULL) {
-               PyObject *mro, *res, *tmp;
+               PyObject *mro, *res, *tmp, *dict;
                descrgetfunc f;
                int i, n;
 
@@ -3858,9 +3922,13 @@ super_getattro(PyObject *self, PyObject *name)
                res = NULL;
                for (; i < n; i++) {
                        tmp = PyTuple_GET_ITEM(mro, i);
-                       assert(PyType_Check(tmp));
-                       res = PyDict_GetItem(
-                               ((PyTypeObject *)tmp)->tp_dict, name);
+                       if (PyType_Check(tmp))
+                               dict = ((PyTypeObject *)tmp)->tp_dict;
+                       else if (PyClass_Check(tmp))
+                               dict = ((PyClassObject *)tmp)->cl_dict;
+                       else
+                               continue;
+                       res = PyDict_GetItem(dict, name);
                        if (res != NULL) {
                                Py_INCREF(res);
                                f = res->ob_type->tp_descr_get;