]> granicus.if.org Git - python/commitdiff
Deleting cyclic object comparison.
authorArmin Rigo <arigo@tunes.org>
Tue, 28 Oct 2003 12:05:48 +0000 (12:05 +0000)
committerArmin Rigo <arigo@tunes.org>
Tue, 28 Oct 2003 12:05:48 +0000 (12:05 +0000)
SF patch 825639
http://mail.python.org/pipermail/python-dev/2003-October/039445.html

Include/ceval.h
Lib/test/pickletester.py
Lib/test/test_builtin.py
Lib/test/test_copy.py
Lib/test/test_richcmp.py
Misc/NEWS
Objects/classobject.c
Objects/object.c
Python/ceval.c

index 411cf3e97bd4bf6ed2812f6223828471d9cbb5c9..dc3864b8eb3dbf0ef37c495de51a67a834d255a7 100644 (file)
@@ -43,9 +43,23 @@ PyAPI_FUNC(int) Py_FlushLine(void);
 PyAPI_FUNC(int) Py_AddPendingCall(int (*func)(void *), void *arg);
 PyAPI_FUNC(int) Py_MakePendingCalls(void);
 
+/* Protection against deeply nested recursive calls */
 PyAPI_FUNC(void) Py_SetRecursionLimit(int);
 PyAPI_FUNC(int) Py_GetRecursionLimit(void);
 
+#define Py_EnterRecursiveCall(where)                                    \
+           (_Py_MakeRecCheck(PyThreadState_GET()->recursion_depth) &&  \
+            _Py_CheckRecursiveCall(where))
+#define Py_LeaveRecursiveCall()                                \
+           (--PyThreadState_GET()->recursion_depth)
+PyAPI_FUNC(int) _Py_CheckRecursiveCall(char *where);
+PyAPI_DATA(int) _Py_CheckRecursionLimit;
+#ifdef USE_STACKCHECK
+#  define _Py_MakeRecCheck(x)  (++(x) > --_Py_CheckRecursionLimit)
+#else
+#  define _Py_MakeRecCheck(x)  (++(x) > _Py_CheckRecursionLimit)
+#endif
+
 PyAPI_FUNC(char *) PyEval_GetFuncName(PyObject *);
 PyAPI_FUNC(char *) PyEval_GetFuncDesc(PyObject *);
 
index cf1bb373ab7c181c6cbcb66caf8468bfd27701de..6e6d97de5c74ed6cf61ba4fe1d01dd1daa8339b6 100644 (file)
@@ -424,9 +424,8 @@ class AbstractPickleTests(unittest.TestCase):
         for proto in protocols:
             s = self.dumps(l, proto)
             x = self.loads(s)
-            self.assertEqual(x, l)
-            self.assertEqual(x, x[0])
-            self.assertEqual(id(x), id(x[0]))
+            self.assertEqual(len(x), 1)
+            self.assert_(x is x[0])
 
     def test_recursive_dict(self):
         d = {}
@@ -434,9 +433,8 @@ class AbstractPickleTests(unittest.TestCase):
         for proto in protocols:
             s = self.dumps(d, proto)
             x = self.loads(s)
-            self.assertEqual(x, d)
-            self.assertEqual(x[1], x)
-            self.assertEqual(id(x[1]), id(x))
+            self.assertEqual(x.keys(), [1])
+            self.assert_(x[1] is x)
 
     def test_recursive_inst(self):
         i = C()
@@ -444,9 +442,8 @@ class AbstractPickleTests(unittest.TestCase):
         for proto in protocols:
             s = self.dumps(i, 2)
             x = self.loads(s)
-            self.assertEqual(x, i)
-            self.assertEqual(x.attr, x)
-            self.assertEqual(id(x.attr), id(x))
+            self.assertEqual(dir(x), dir(i))
+            self.assert_(x.attr is x)
 
     def test_recursive_multi(self):
         l = []
@@ -457,12 +454,10 @@ class AbstractPickleTests(unittest.TestCase):
         for proto in protocols:
             s = self.dumps(l, proto)
             x = self.loads(s)
-            self.assertEqual(x, l)
-            self.assertEqual(x[0], i)
-            self.assertEqual(x[0].attr, d)
-            self.assertEqual(x[0].attr[1], x)
-            self.assertEqual(x[0].attr[1][0], i)
-            self.assertEqual(x[0].attr[1][0].attr, d)
+            self.assertEqual(len(x), 1)
+            self.assertEqual(dir(x[0]), dir(i))
+            self.assertEqual(x[0].attr.keys(), [1])
+            self.assert_(x[0].attr[1] is x)
 
     def test_garyp(self):
         self.assertRaises(self.error, self.loads, 'garyp')
index 652163411e839e785f5b3f03824f4a0662172c3d..e84cfbd50efb7ec1e12c02c7932ad5609803fd17 100644 (file)
@@ -167,16 +167,16 @@ class BuiltinTest(unittest.TestCase):
         self.assertEqual(cmp(-1, 1), -1)
         self.assertEqual(cmp(1, -1), 1)
         self.assertEqual(cmp(1, 1), 0)
-        # verify that circular objects are handled
+        # verify that circular objects are not handled
         a = []; a.append(a)
         b = []; b.append(b)
         from UserList import UserList
         c = UserList(); c.append(c)
-        self.assertEqual(cmp(a, b), 0)
-        self.assertEqual(cmp(b, c), 0)
-        self.assertEqual(cmp(c, a), 0)
-        self.assertEqual(cmp(a, c), 0)
-        # okay, now break the cycles
+        self.assertRaises(RuntimeError, cmp, a, b)
+        self.assertRaises(RuntimeError, cmp, b, c)
+        self.assertRaises(RuntimeError, cmp, c, a)
+        self.assertRaises(RuntimeError, cmp, a, c)
+       # okay, now break the cycles
         a.pop(); b.pop(); c.pop()
         self.assertRaises(TypeError, cmp)
 
index 3d44304db9cd57776dff2ac692c414e862c521ba..6e32ddd88f7273a6cac4d93179583e31e43be105 100644 (file)
@@ -272,10 +272,10 @@ class TestCopy(unittest.TestCase):
         x = []
         x.append(x)
         y = copy.deepcopy(x)
-        self.assertEqual(y, x)
+        self.assertRaises(RuntimeError, cmp, y, x)
         self.assert_(y is not x)
-        self.assert_(y[0] is not x[0])
-        self.assert_(y is y[0])
+        self.assert_(y[0] is y)
+        self.assertEqual(len(y), 1)
 
     def test_deepcopy_tuple(self):
         x = ([1, 2], 3)
@@ -288,7 +288,7 @@ class TestCopy(unittest.TestCase):
         x = ([],)
         x[0].append(x)
         y = copy.deepcopy(x)
-        self.assertEqual(y, x)
+        self.assertRaises(RuntimeError, cmp, y, x)
         self.assert_(y is not x)
         self.assert_(y[0] is not x[0])
         self.assert_(y[0][0] is y)
@@ -304,10 +304,10 @@ class TestCopy(unittest.TestCase):
         x = {}
         x['foo'] = x
         y = copy.deepcopy(x)
-        self.assertEqual(y, x)
+        self.assertRaises(RuntimeError, cmp, y, x)
         self.assert_(y is not x)
         self.assert_(y['foo'] is y)
-        self.assertEqual(y, {'foo': y})
+        self.assertEqual(len(y), 1)
 
     def test_deepcopy_keepalive(self):
         memo = {}
index 5ade8ede5699915715beb0723bcda67139cb2858..006b1528c947d6d7b376a9061935564dbade774c 100644 (file)
@@ -224,57 +224,36 @@ class MiscTest(unittest.TestCase):
             self.assertRaises(Exc, func, Bad())
 
     def test_recursion(self):
-        # Check comparison for recursive objects
+        # Check that comparison for recursive objects fails gracefully
         from UserList import UserList
-        a = UserList(); a.append(a)
-        b = UserList(); b.append(b)
-
-        self.assert_(a == b)
-        self.assert_(not a != b)
-        a.append(1)
-        self.assert_(a == a[0])
-        self.assert_(not a != a[0])
-        self.assert_(a != b)
-        self.assert_(not a == b)
-        b.append(0)
-        self.assert_(a != b)
-        self.assert_(not a == b)
-        a[1] = -1
-        self.assert_(a != b)
-        self.assert_(not a == b)
-
         a = UserList()
         b = UserList()
         a.append(b)
         b.append(a)
-        self.assert_(a == b)
-        self.assert_(not a != b)
+        self.assertRaises(RuntimeError, operator.eq, a, b)
+        self.assertRaises(RuntimeError, operator.ne, a, b)
+        self.assertRaises(RuntimeError, operator.lt, a, b)
+        self.assertRaises(RuntimeError, operator.le, a, b)
+        self.assertRaises(RuntimeError, operator.gt, a, b)
+        self.assertRaises(RuntimeError, operator.ge, a, b)
 
         b.append(17)
+        # Even recursive lists of different lengths are different,
+        # but they cannot be ordered
+        self.assert_(not (a == b))
         self.assert_(a != b)
-        self.assert_(not a == b)
+        self.assertRaises(RuntimeError, operator.lt, a, b)
+        self.assertRaises(RuntimeError, operator.le, a, b)
+        self.assertRaises(RuntimeError, operator.gt, a, b)
+        self.assertRaises(RuntimeError, operator.ge, a, b)
         a.append(17)
-        self.assert_(a == b)
-        self.assert_(not a != b)
-
-    def test_recursion2(self):
-        # This test exercises the circular structure handling code
-        # in PyObject_RichCompare()
-        class Weird(object):
-            def __eq__(self, other):
-                return self != other
-            def __ne__(self, other):
-                return self == other
-            def __lt__(self, other):
-                return self > other
-            def __gt__(self, other):
-                return self < other
-
-        self.assert_(Weird() == Weird())
-        self.assert_(not (Weird() != Weird()))
-
-        for op in opmap["lt"]:
-            self.assertRaises(ValueError, op, Weird(), Weird())
+        self.assertRaises(RuntimeError, operator.eq, a, b)
+        self.assertRaises(RuntimeError, operator.ne, a, b)
+        a.insert(0, 11)
+        b.insert(0, 12)
+        self.assert_(not (a == b))
+        self.assert_(a != b)
+        self.assert_(a < b)
 
 class DictTest(unittest.TestCase):
 
index 74096c0289539213531ae5023b6665a796fdddb1..a16f1195d63f3d54752ec3ce3506c79d378ac5cb 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -47,6 +47,10 @@ Core and builtins
 - obj.__contains__() now returns True/False instead of 1/0.  SF patch
   820195.
 
+- Python no longer tries to be smart about recursive comparisons.
+  When comparing containers with cyclic references to themselves it
+  will now just hit the recursion limit.  See SF patch 825639.
+
 Extension modules
 -----------------
 
index b0e19347d44ca700dba7df60e773588e9a2642c5..84b297cad95c2ab0bc27117a34bc7252610da6c3 100644 (file)
@@ -1970,7 +1970,6 @@ instance_iternext(PyInstanceObject *self)
 static PyObject *
 instance_call(PyObject *func, PyObject *arg, PyObject *kw)
 {
-       PyThreadState *tstate = PyThreadState_GET();
        PyObject *res, *call = PyObject_GetAttrString(func, "__call__");
        if (call == NULL) {
                PyInstanceObject *inst = (PyInstanceObject*) func;
@@ -1990,14 +1989,13 @@ instance_call(PyObject *func, PyObject *arg, PyObject *kw)
               a() # infinite recursion
           This bounces between instance_call() and PyObject_Call() without
           ever hitting eval_frame() (which has the main recursion check). */
-       if (tstate->recursion_depth++ > Py_GetRecursionLimit()) {
-               PyErr_SetString(PyExc_RuntimeError,
-                               "maximum __call__ recursion depth exceeded");
+       if (Py_EnterRecursiveCall(" in __call__")) {
                res = NULL;
        }
-       else
+       else {
                res = PyObject_Call(call, arg, kw);
-       tstate->recursion_depth--;
+               Py_LeaveRecursiveCall();
+       }
        Py_DECREF(call);
        return res;
 }
index 8c4bd0ef6742b44341b35c4a27c6955418c2f585..d85f697fa4e4c3ccead499d84e775db6c39f0a14 100644 (file)
@@ -740,120 +740,6 @@ do_cmp(PyObject *v, PyObject *w)
        return default_3way_compare(v, w);
 }
 
-/* compare_nesting is incremented before calling compare (for
-   some types) and decremented on exit.  If the count exceeds the
-   nesting limit, enable code to detect circular data structures.
-
-   This is a tunable parameter that should only affect the performance
-   of comparisons, nothing else.  Setting it high makes comparing deeply
-   nested non-cyclical data structures faster, but makes comparing cyclical
-   data structures slower.
-*/
-#define NESTING_LIMIT 20
-
-static int compare_nesting = 0;
-
-static PyObject*
-get_inprogress_dict(void)
-{
-       static PyObject *key;
-       PyObject *tstate_dict, *inprogress;
-
-       if (key == NULL) {
-               key = PyString_InternFromString("cmp_state");
-               if (key == NULL)
-                       return NULL;
-       }
-
-       tstate_dict = PyThreadState_GetDict();
-       if (tstate_dict == NULL) {
-               PyErr_BadInternalCall();
-               return NULL;
-       }
-
-       inprogress = PyDict_GetItem(tstate_dict, key);
-       if (inprogress == NULL) {
-               inprogress = PyDict_New();
-               if (inprogress == NULL)
-                       return NULL;
-               if (PyDict_SetItem(tstate_dict, key, inprogress) == -1) {
-                   Py_DECREF(inprogress);
-                   return NULL;
-               }
-               Py_DECREF(inprogress);
-       }
-
-       return inprogress;
-}
-
-/* If the comparison "v op w" is already in progress in this thread, returns
- * a borrowed reference to Py_None (the caller must not decref).
- * If it's not already in progress, returns "a token" which must eventually
- * be passed to delete_token().  The caller must not decref this either
- * (delete_token decrefs it).  The token must not survive beyond any point
- * where v or w may die.
- * If an error occurs (out-of-memory), returns NULL.
- */
-static PyObject *
-check_recursion(PyObject *v, PyObject *w, int op)
-{
-       PyObject *inprogress;
-       PyObject *token;
-       Py_uintptr_t iv = (Py_uintptr_t)v;
-       Py_uintptr_t iw = (Py_uintptr_t)w;
-       PyObject *x, *y, *z;
-
-       inprogress = get_inprogress_dict();
-       if (inprogress == NULL)
-               return NULL;
-
-       token = PyTuple_New(3);
-       if (token == NULL)
-               return NULL;
-
-       if (iv <= iw) {
-               PyTuple_SET_ITEM(token, 0, x = PyLong_FromVoidPtr((void *)v));
-               PyTuple_SET_ITEM(token, 1, y = PyLong_FromVoidPtr((void *)w));
-               if (op >= 0)
-                       op = swapped_op[op];
-       } else {
-               PyTuple_SET_ITEM(token, 0, x = PyLong_FromVoidPtr((void *)w));
-               PyTuple_SET_ITEM(token, 1, y = PyLong_FromVoidPtr((void *)v));
-       }
-       PyTuple_SET_ITEM(token, 2, z = PyInt_FromLong((long)op));
-       if (x == NULL || y == NULL || z == NULL) {
-               Py_DECREF(token);
-               return NULL;
-       }
-
-       if (PyDict_GetItem(inprogress, token) != NULL) {
-               Py_DECREF(token);
-               return Py_None; /* Without INCREF! */
-       }
-
-       if (PyDict_SetItem(inprogress, token, token) < 0) {
-               Py_DECREF(token);
-               return NULL;
-       }
-
-       return token;
-}
-
-static void
-delete_token(PyObject *token)
-{
-       PyObject *inprogress;
-
-       if (token == NULL || token == Py_None)
-               return;
-       inprogress = get_inprogress_dict();
-       if (inprogress == NULL)
-               PyErr_Clear();
-       else
-               PyDict_DelItem(inprogress, token);
-       Py_DECREF(token);
-}
-
 /* Compare v to w.  Return
    -1 if v <  w or exception (PyErr_Occurred() true in latter case).
     0 if v == w.
@@ -867,12 +753,6 @@ PyObject_Compare(PyObject *v, PyObject *w)
        PyTypeObject *vtp;
        int result;
 
-#if defined(USE_STACKCHECK)
-       if (PyOS_CheckStack()) {
-               PyErr_SetString(PyExc_MemoryError, "Stack overflow");
-               return -1;
-       }
-#endif
        if (v == NULL || w == NULL) {
                PyErr_BadInternalCall();
                return -1;
@@ -880,31 +760,10 @@ PyObject_Compare(PyObject *v, PyObject *w)
        if (v == w)
                return 0;
        vtp = v->ob_type;
-       compare_nesting++;
-       if (compare_nesting > NESTING_LIMIT &&
-           (vtp->tp_as_mapping || vtp->tp_as_sequence) &&
-           !PyString_CheckExact(v) &&
-           !PyTuple_CheckExact(v)) {
-               /* try to detect circular data structures */
-               PyObject *token = check_recursion(v, w, -1);
-
-               if (token == NULL) {
-                       result = -1;
-               }
-               else if (token == Py_None) {
-                       /* already comparing these objects.  assume
-                          they're equal until shown otherwise */
-                        result = 0;
-               }
-               else {
-                       result = do_cmp(v, w);
-                       delete_token(token);
-               }
-       }
-       else {
-               result = do_cmp(v, w);
-       }
-       compare_nesting--;
+       if (Py_EnterRecursiveCall(" in cmp"))
+               return -1;
+       result = do_cmp(v, w);
+       Py_LeaveRecursiveCall();
        return result < 0 ? -1 : result;
 }
 
@@ -975,41 +834,10 @@ PyObject_RichCompare(PyObject *v, PyObject *w, int op)
        PyObject *res;
 
        assert(Py_LT <= op && op <= Py_GE);
+       if (Py_EnterRecursiveCall(" in cmp"))
+               return NULL;
 
-       compare_nesting++;
-       if (compare_nesting > NESTING_LIMIT &&
-           (v->ob_type->tp_as_mapping || v->ob_type->tp_as_sequence) &&
-           !PyString_CheckExact(v) &&
-           !PyTuple_CheckExact(v)) {
-               /* try to detect circular data structures */
-               PyObject *token = check_recursion(v, w, op);
-               if (token == NULL) {
-                       res = NULL;
-                       goto Done;
-               }
-               else if (token == Py_None) {
-                       /* already comparing these objects with this operator.
-                          assume they're equal until shown otherwise */
-                       if (op == Py_EQ)
-                               res = Py_True;
-                       else if (op == Py_NE)
-                               res = Py_False;
-                       else {
-                               PyErr_SetString(PyExc_ValueError,
-                                       "can't order recursive values");
-                               res = NULL;
-                       }
-                       Py_XINCREF(res);
-               }
-               else {
-                       res = do_richcmp(v, w, op);
-                       delete_token(token);
-               }
-               goto Done;
-       }
-
-       /* No nesting extremism.
-          If the types are equal, and not old-style instances, try to
+       /* If the types are equal, and not old-style instances, try to
           get out cheap (don't bother with coercions etc.). */
        if (v->ob_type == w->ob_type && !PyInstance_Check(v)) {
                cmpfunc fcmp;
@@ -1041,7 +869,7 @@ PyObject_RichCompare(PyObject *v, PyObject *w, int op)
        /* Fast path not taken, or couldn't deliver a useful result. */
        res = do_richcmp(v, w, op);
 Done:
-       compare_nesting--;
+       Py_LeaveRecursiveCall();
        return res;
 }
 
index e6b742499bad01d784933eb55134d0e29103c907..fe8aca5a1dcfa669acadaa47b5fc48e527c218a8 100644 (file)
@@ -497,6 +497,7 @@ Py_MakePendingCalls(void)
 /* The interpreter's recursion limit */
 
 static int recursion_limit = 1000;
+int _Py_CheckRecursionLimit = 1000;
 
 int
 Py_GetRecursionLimit(void)
@@ -508,8 +509,38 @@ void
 Py_SetRecursionLimit(int new_limit)
 {
        recursion_limit = new_limit;
+        _Py_CheckRecursionLimit = recursion_limit;
 }
 
+/* the macro Py_EnterRecursiveCall() only calls _Py_CheckRecursiveCall()
+   if the recursion_depth reaches _Py_CheckRecursionLimit.
+   If USE_STACKCHECK, the macro decrements _Py_CheckRecursionLimit
+   to guarantee that _Py_CheckRecursiveCall() is regularly called.
+   Without USE_STACKCHECK, there is no need for this. */
+int
+_Py_CheckRecursiveCall(char *where)
+{
+       PyThreadState *tstate = PyThreadState_GET();
+
+#ifdef USE_STACKCHECK
+       if (PyOS_CheckStack()) {
+               --tstate->recursion_depth;
+               PyErr_SetString(PyExc_MemoryError, "Stack overflow");
+               return -1;
+       }
+#endif
+       if (tstate->recursion_depth > recursion_limit) {
+               --tstate->recursion_depth;
+               PyErr_Format(PyExc_RuntimeError,
+                            "maximum recursion depth exceeded%s",
+                            where);
+               return -1;
+       }
+        _Py_CheckRecursionLimit = recursion_limit;
+       return 0;
+}
+
+
 /* Status code for main loop (reason for stack unwind) */
 
 enum why_code {
@@ -674,21 +705,9 @@ eval_frame(PyFrameObject *f)
        if (f == NULL)
                return NULL;
 
-#ifdef USE_STACKCHECK
-       if (tstate->recursion_depth%10 == 0 && PyOS_CheckStack()) {
-               PyErr_SetString(PyExc_MemoryError, "Stack overflow");
-               return NULL;
-       }
-#endif
-
        /* push frame */
-       if (++tstate->recursion_depth > recursion_limit) {
-               --tstate->recursion_depth;
-               PyErr_SetString(PyExc_RuntimeError,
-                               "maximum recursion depth exceeded");
-               tstate->frame = f->f_back;
+       if (Py_EnterRecursiveCall(""))
                return NULL;
-       }
 
        tstate->frame = f;
 
@@ -710,9 +729,7 @@ eval_frame(PyFrameObject *f)
                        if (call_trace(tstate->c_tracefunc, tstate->c_traceobj,
                                       f, PyTrace_CALL, Py_None)) {
                                /* Trace function raised an error */
-                               --tstate->recursion_depth;
-                               tstate->frame = f->f_back;
-                               return NULL;
+                               goto exit_eval_frame;
                        }
                }
                if (tstate->c_profilefunc != NULL) {
@@ -722,9 +739,7 @@ eval_frame(PyFrameObject *f)
                                       tstate->c_profileobj,
                                       f, PyTrace_CALL, Py_None)) {
                                /* Profile function raised an error */
-                               --tstate->recursion_depth;
-                               tstate->frame = f->f_back;
-                               return NULL;
+                               goto exit_eval_frame;
                        }
                }
        }
@@ -2428,7 +2443,8 @@ eval_frame(PyFrameObject *f)
        reset_exc_info(tstate);
 
        /* pop frame */
-       --tstate->recursion_depth;
+    exit_eval_frame:
+       Py_LeaveRecursiveCall();
        tstate->frame = f->f_back;
 
        return retval;