]> granicus.if.org Git - python/commitdiff
Fix PR#7 comparisons of recursive objects
authorJeremy Hylton <jeremy@alum.mit.edu>
Fri, 14 Apr 2000 19:13:24 +0000 (19:13 +0000)
committerJeremy Hylton <jeremy@alum.mit.edu>
Fri, 14 Apr 2000 19:13:24 +0000 (19:13 +0000)
Note that comparisons of deeply nested objects can still dump core in
extreme cases.

Include/object.h
Lib/test/test_b1.py
Objects/object.c
Python/pythonrun.c

index 77f5c55b974e216ba445c864244d794788958351..fabf0b6829e839683b17436f77a47b7f43222c4b 100644 (file)
@@ -284,6 +284,9 @@ extern DL_IMPORT(int) PyNumber_CoerceEx Py_PROTO((PyObject **, PyObject **));
 extern DL_IMPORT(int) Py_ReprEnter Py_PROTO((PyObject *));
 extern DL_IMPORT(void) Py_ReprLeave Py_PROTO((PyObject *));
 
+/* tstate dict key for PyObject_Compare helper */
+extern PyObject *_PyCompareState_Key;
+
 /* Flag bits for printing: */
 #define Py_PRINT_RAW   1       /* No string quotes etc. */
 
index 6a89d2209aba55d43a61b86cea4012fca21f7a42..b063e5aa7187c19f04652d888dd0b5d44b0b9eba 100644 (file)
@@ -63,6 +63,15 @@ print 'cmp'
 if cmp(-1, 1) <> -1: raise TestFailed, 'cmp(-1, 1)'
 if cmp(1, -1) <> 1: raise TestFailed, 'cmp(1, -1)'
 if cmp(1, 1) <> 0: raise TestFailed, 'cmp(1, 1)'
+# verify that circular objects are handled
+a = []; a.append(a)
+b = []; b.append(b)
+from UserList import UserList
+c = UserList(); c.append(c)
+if cmp(a, b) != 0: raise TestFailed, "cmp(%s, %s)" % (a, b)
+if cmp(b, c) != 0: raise TestFailed, "cmp(%s, %s)" % (b, c)
+if cmp(c, a) != 0: raise TestFailed, "cmp(%s, %s)" % (c, a)
+if cmp(a, c) != 0: raise TestFailed, "cmp(%s, %s)" % (a, c)
 
 print 'coerce'
 if fcmp(coerce(1, 1.1), (1.0, 1.1)): raise TestFailed, 'coerce(1, 1.1)'
index 968fdd0f23c6e4942702c7bb52a6680b2a239b5c..bd1d17fb727ec6bd26e6144699e8f5a4302b7a2e 100644 (file)
@@ -298,11 +298,67 @@ do_cmp(v, w)
        return PyInt_FromLong(c);
 }
 
+PyObject *_PyCompareState_Key;
+
+/* _PyCompareState_nesting is incremented beforing call compare (for
+   some types) and decremented on exit.  If the count exceeds the
+   nesting limit, enable code to detect circular data structures.
+*/
+#define NESTING_LIMIT 500
+int _PyCompareState_nesting = 0;
+
+static PyObject*
+get_inprogress_dict()
+{
+       PyObject *tstate_dict, *inprogress;
+
+       tstate_dict = PyThreadState_GetDict();
+       if (tstate_dict == NULL) {
+               PyErr_BadInternalCall();
+               return NULL;
+       } 
+       inprogress = PyDict_GetItem(tstate_dict, _PyCompareState_Key); 
+       if (inprogress == NULL) {
+               PyErr_Clear();
+               inprogress = PyDict_New();
+               if (inprogress == NULL)
+                       return NULL;
+               if (PyDict_SetItem(tstate_dict, _PyCompareState_Key,
+                                  inprogress) == -1) {
+                   Py_DECREF(inprogress);
+                   return NULL;
+               }
+       }
+       return inprogress;
+}
+
+static PyObject *
+make_pair(v, w)
+       PyObject *v, *w;
+{
+       PyObject *pair;
+
+       pair = PyTuple_New(2);
+       if (pair == NULL) {
+               return NULL;
+       }
+       if ((long)v <= (long)w) {
+               PyTuple_SET_ITEM(pair, 0, PyLong_FromVoidPtr((void *)v));
+               PyTuple_SET_ITEM(pair, 1, PyLong_FromVoidPtr((void *)w));
+       } else {
+               PyTuple_SET_ITEM(pair, 0, PyLong_FromVoidPtr((void *)w));
+               PyTuple_SET_ITEM(pair, 1, PyLong_FromVoidPtr((void *)v));
+       }
+       return pair;
+}
+
 int
 PyObject_Compare(v, w)
        PyObject *v, *w;
 {
        PyTypeObject *vtp, *wtp;
+       int result;
+
        if (v == NULL || w == NULL) {
                PyErr_BadInternalCall();
                return -1;
@@ -314,7 +370,32 @@ PyObject_Compare(v, w)
                int c;
                if (!PyInstance_Check(v))
                        return -PyObject_Compare(w, v);
-               res = do_cmp(v, w);
+               if (++_PyCompareState_nesting > NESTING_LIMIT) {
+                       PyObject *inprogress, *pair;
+
+                       inprogress = get_inprogress_dict();
+                       if (inprogress == NULL) {
+                               return -1;
+                       }
+                       pair = make_pair(v, w);
+                       if (PyDict_GetItem(inprogress, pair)) {
+                               /* already comparing these objects.  assume
+                                  they're equal until shown otherwise */
+                               Py_DECREF(pair);
+                               --_PyCompareState_nesting;
+                               return 0;
+                       }
+                       if (PyDict_SetItem(inprogress, pair, pair) == -1) {
+                               return -1;
+                       }
+                       res = do_cmp(v, w);
+                       _PyCompareState_nesting--;
+                       /* XXX DelItem shouldn't fail */
+                       PyDict_DelItem(inprogress, pair);
+                       Py_DECREF(pair);
+               } else {
+                       res = do_cmp(v, w);
+               }
                if (res == NULL)
                        return -1;
                if (!PyInt_Check(res)) {
@@ -369,9 +450,37 @@ PyObject_Compare(v, w)
                /* Numerical types compare smaller than all other types */
                return strcmp(vname, wname);
        }
-       if (vtp->tp_compare == NULL)
+       if (vtp->tp_compare == NULL) {
                return (v < w) ? -1 : 1;
-       return (*vtp->tp_compare)(v, w);
+       }
+       if (++_PyCompareState_nesting > NESTING_LIMIT
+           && (vtp->tp_as_mapping 
+               || (vtp->tp_as_sequence && !PyString_Check(v)))) {
+               PyObject *inprogress, *pair;
+
+               inprogress = get_inprogress_dict();
+               if (inprogress == NULL) {
+                       return -1;
+               }
+               pair = make_pair(v, w);
+               if (PyDict_GetItem(inprogress, pair)) {
+                       /* already comparing these objects.  assume
+                          they're equal until shown otherwise */
+                       _PyCompareState_nesting--;
+                       Py_DECREF(pair);
+                       return 0;
+               }
+               if (PyDict_SetItem(inprogress, pair, pair) == -1) {
+                       return -1;
+               }
+               result = (*vtp->tp_compare)(v, w);
+               _PyCompareState_nesting--;
+               PyDict_DelItem(inprogress, pair); /* XXX shouldn't fail */
+               Py_DECREF(pair);
+       } else {
+               result = (*vtp->tp_compare)(v, w);
+       }
+       return result;
 }
 
 long
index eb93d47ebe5c5120a99a902f94869c677e4c2ae0..0ae15fafa268f706447d04b65af9c979a1680398 100644 (file)
@@ -149,6 +149,8 @@ Py_Initialize()
        /* Init Unicode implementation; relies on the codec registry */
        _PyUnicode_Init();
 
+       _PyCompareState_Key = PyString_InternFromString("cmp_state");
+
        bimod = _PyBuiltin_Init_1();
        if (bimod == NULL)
                Py_FatalError("Py_Initialize: can't initialize __builtin__");