]> granicus.if.org Git - python/commitdiff
SF patch #1390657:
authorArmin Rigo <arigo@tunes.org>
Thu, 29 Dec 2005 15:59:19 +0000 (15:59 +0000)
committerArmin Rigo <arigo@tunes.org>
Thu, 29 Dec 2005 15:59:19 +0000 (15:59 +0000)
* set sq_repeat and sq_concat to NULL for user-defined new-style
  classes, as a way to fix a number of related problems.  See
  test_descr.notimplemented()).  One of these problems was fixed
  in r25556 and r25557 but many more existed; this is a general
  fix and thus reverts r25556-r25557.

* to avoid having PySequence_Repeat()/PySequence_Concat() failing
  on user-defined classes, they now fall back to nb_add/nb_mul if
  sq_concat/sq_repeat are not defined and the arguments appear to
  be sequences.

* added tests.

Backport candidate.

Lib/test/test_descr.py
Lib/test/test_operator.py
Objects/abstract.c
Objects/typeobject.c

index f594ca88750dc4bec937604bbb1610a272c40535..2ea8186846e8470b411a152ac2cedc0dcb56b523 100644 (file)
@@ -3990,6 +3990,77 @@ def methodwrapper():
     verify(l.__add__.__objclass__ is list)
     vereq(l.__add__.__doc__, list.__add__.__doc__)
 
+def notimplemented():
+    # all binary methods should be able to return a NotImplemented
+    if verbose:
+        print "Testing NotImplemented..."
+
+    import sys
+    import types
+    import operator
+
+    def specialmethod(self, other):
+        return NotImplemented
+
+    def check(expr, x, y):
+        try:
+            exec expr in {'x': x, 'y': y, 'operator': operator}
+        except TypeError:
+            pass
+        else:
+            raise TestFailed("no TypeError from %r" % (expr,))
+
+    N1 = sys.maxint + 1L    # might trigger OverflowErrors instead of TypeErrors
+    N2 = sys.maxint         # if sizeof(int) < sizeof(long), might trigger
+                            #   ValueErrors instead of TypeErrors
+    for metaclass in [type, types.ClassType]:
+        for name, expr, iexpr in [
+                ('__add__',      'x + y',                   'x += y'),
+                ('__sub__',      'x - y',                   'x -= y'),
+                ('__mul__',      'x * y',                   'x *= y'),
+                ('__truediv__',  'operator.truediv(x, y)',  None),
+                ('__floordiv__', 'operator.floordiv(x, y)', None),
+                ('__div__',      'x / y',                   'x /= y'),
+                ('__mod__',      'x % y',                   'x %= y'),
+                ('__divmod__',   'divmod(x, y)',            None),
+                ('__pow__',      'x ** y',                  'x **= y'),
+                ('__lshift__',   'x << y',                  'x <<= y'),
+                ('__rshift__',   'x >> y',                  'x >>= y'),
+                ('__and__',      'x & y',                   'x &= y'),
+                ('__or__',       'x | y',                   'x |= y'),
+                ('__xor__',      'x ^ y',                   'x ^= y'),
+                ('__coerce__',   'coerce(x, y)',            None)]:
+            if name == '__coerce__':
+                rname = name
+            else:
+                rname = '__r' + name[2:]
+            A = metaclass('A', (), {name: specialmethod})
+            B = metaclass('B', (), {rname: specialmethod})
+            a = A()
+            b = B()
+            check(expr, a, a)
+            check(expr, a, b)
+            check(expr, b, a)
+            check(expr, b, b)
+            check(expr, a, N1)
+            check(expr, a, N2)
+            check(expr, N1, b)
+            check(expr, N2, b)
+            if iexpr:
+                check(iexpr, a, a)
+                check(iexpr, a, b)
+                check(iexpr, b, a)
+                check(iexpr, b, b)
+                check(iexpr, a, N1)
+                check(iexpr, a, N2)
+                iname = '__i' + name[2:]
+                C = metaclass('C', (), {iname: specialmethod})
+                c = C()
+                check(iexpr, c, a)
+                check(iexpr, c, b)
+                check(iexpr, c, N1)
+                check(iexpr, c, N2)
+
 def test_main():
     weakref_segfault() # Must be first, somehow
     do_this_first()
@@ -4084,6 +4155,7 @@ def test_main():
     vicious_descriptor_nonsense()
     test_init()
     methodwrapper()
+    notimplemented()
 
     if verbose: print "All OK"
 
index 725b2d947750f6f979a2dc82974a0d68b1a80e83..6cc7945f3d7d2f1027261251536ab3fa1730ec95 100644 (file)
@@ -3,6 +3,34 @@ import unittest
 
 from test import test_support
 
+class Seq1:
+    def __init__(self, lst):
+        self.lst = lst
+    def __len__(self):
+        return len(self.lst)
+    def __getitem__(self, i):
+        return self.lst[i]
+    def __add__(self, other):
+        return self.lst + other.lst
+    def __mul__(self, other):
+        return self.lst * other
+    def __rmul__(self, other):
+        return other * self.lst
+
+class Seq2(object):
+    def __init__(self, lst):
+        self.lst = lst
+    def __len__(self):
+        return len(self.lst)
+    def __getitem__(self, i):
+        return self.lst[i]
+    def __add__(self, other):
+        return self.lst + other.lst
+    def __mul__(self, other):
+        return self.lst * other
+    def __rmul__(self, other):
+        return other * self.lst
+
 
 class OperatorTestCase(unittest.TestCase):
     def test_lt(self):
@@ -92,6 +120,9 @@ class OperatorTestCase(unittest.TestCase):
         self.failUnlessRaises(TypeError, operator.concat, None, None)
         self.failUnless(operator.concat('py', 'thon') == 'python')
         self.failUnless(operator.concat([1, 2], [3, 4]) == [1, 2, 3, 4])
+        self.failUnless(operator.concat(Seq1([5, 6]), Seq1([7])) == [5, 6, 7])
+        self.failUnless(operator.concat(Seq2([5, 6]), Seq2([7])) == [5, 6, 7])
+        self.failUnlessRaises(TypeError, operator.concat, 13, 29)
 
     def test_countOf(self):
         self.failUnlessRaises(TypeError, operator.countOf)
@@ -246,6 +277,15 @@ class OperatorTestCase(unittest.TestCase):
         self.failUnless(operator.repeat(a, 2) == a+a)
         self.failUnless(operator.repeat(a, 1) == a)
         self.failUnless(operator.repeat(a, 0) == '')
+        a = Seq1([4, 5, 6])
+        self.failUnless(operator.repeat(a, 2) == [4, 5, 6, 4, 5, 6])
+        self.failUnless(operator.repeat(a, 1) == [4, 5, 6])
+        self.failUnless(operator.repeat(a, 0) == [])
+        a = Seq2([4, 5, 6])
+        self.failUnless(operator.repeat(a, 2) == [4, 5, 6, 4, 5, 6])
+        self.failUnless(operator.repeat(a, 1) == [4, 5, 6])
+        self.failUnless(operator.repeat(a, 0) == [])
+        self.failUnlessRaises(TypeError, operator.repeat, 6, 7)
 
     def test_rshift(self):
         self.failUnlessRaises(TypeError, operator.rshift)
index 1f8feb52e721d17214c4c4836c6bbd7a03a94938..6e070a9099b86e75a28b3b7f53fa533b3e4c3fb8 100644 (file)
@@ -635,14 +635,11 @@ PyNumber_Add(PyObject *v, PyObject *w)
        PyObject *result = binary_op1(v, w, NB_SLOT(nb_add));
        if (result == Py_NotImplemented) {
                PySequenceMethods *m = v->ob_type->tp_as_sequence;
+               Py_DECREF(result);
                if (m && m->sq_concat) {
-                       Py_DECREF(result);
-                       result = (*m->sq_concat)(v, w);
+                       return (*m->sq_concat)(v, w);
                }
-               if (result == Py_NotImplemented) {
-                       Py_DECREF(result);
-                       return binop_type_error(v, w, "+");
-                }
+               result = binop_type_error(v, w, "+");
        }
        return result;
 }
@@ -1144,6 +1141,15 @@ PySequence_Concat(PyObject *s, PyObject *o)
        if (m && m->sq_concat)
                return m->sq_concat(s, o);
 
+       /* Instances of user classes defining an __add__() method only
+          have an nb_add slot, not an sq_concat slot.  So we fall back
+          to nb_add if both arguments appear to be sequences. */
+       if (PySequence_Check(s) && PySequence_Check(o)) {
+               PyObject *result = binary_op1(s, o, NB_SLOT(nb_add));
+               if (result != Py_NotImplemented)
+                       return result;
+               Py_DECREF(result);
+       }
        return type_error("object can't be concatenated");
 }
 
@@ -1159,6 +1165,20 @@ PySequence_Repeat(PyObject *o, int count)
        if (m && m->sq_repeat)
                return m->sq_repeat(o, count);
 
+       /* Instances of user classes defining a __mul__() method only
+          have an nb_multiply slot, not an sq_repeat slot. so we fall back
+          to nb_multiply if o appears to be a sequence. */
+       if (PySequence_Check(o)) {
+               PyObject *n, *result;
+               n = PyInt_FromLong(count);
+               if (n == NULL)
+                       return NULL;
+               result = binary_op1(o, n, NB_SLOT(nb_multiply));
+               Py_DECREF(n);
+               if (result != Py_NotImplemented)
+                       return result;
+               Py_DECREF(result);
+       }
        return type_error("object can't be repeated");
 }
 
@@ -1176,6 +1196,13 @@ PySequence_InPlaceConcat(PyObject *s, PyObject *o)
        if (m && m->sq_concat)
                return m->sq_concat(s, o);
 
+       if (PySequence_Check(s) && PySequence_Check(o)) {
+               PyObject *result = binary_iop1(s, o, NB_SLOT(nb_inplace_add),
+                                              NB_SLOT(nb_add));
+               if (result != Py_NotImplemented)
+                       return result;
+               Py_DECREF(result);
+       }
        return type_error("object can't be concatenated");
 }
 
@@ -1193,6 +1220,18 @@ PySequence_InPlaceRepeat(PyObject *o, int count)
        if (m && m->sq_repeat)
                return m->sq_repeat(o, count);
 
+       if (PySequence_Check(o)) {
+               PyObject *n, *result;
+               n = PyInt_FromLong(count);
+               if (n == NULL)
+                       return NULL;
+               result = binary_iop1(o, n, NB_SLOT(nb_inplace_multiply),
+                                    NB_SLOT(nb_multiply));
+               Py_DECREF(n);
+               if (result != Py_NotImplemented)
+                       return result;
+               Py_DECREF(result);
+       }
        return type_error("object can't be repeated");
 }
 
index 7c36ba4f402a72f666d48b348dac5d9b0b870b55..b74fa1ad9f9450e8eb2b314fe8b38f86ef3ffb1d 100644 (file)
@@ -4095,9 +4095,6 @@ slot_sq_length(PyObject *self)
        return len;
 }
 
-SLOT1(slot_sq_concat, "__add__", PyObject *, "O")
-SLOT1(slot_sq_repeat, "__mul__", int, "i")
-
 /* Super-optimized version of slot_sq_item.
    Other slots could do the same... */
 static PyObject *
@@ -4211,9 +4208,6 @@ slot_sq_contains(PyObject *self, PyObject *value)
        return result;
 }
 
-SLOT1(slot_sq_inplace_concat, "__iadd__", PyObject *, "O")
-SLOT1(slot_sq_inplace_repeat, "__imul__", int, "i")
-
 #define slot_mp_length slot_sq_length
 
 SLOT1(slot_mp_subscript, "__getitem__", PyObject *, "O")
@@ -4926,12 +4920,17 @@ typedef struct wrapperbase slotdef;
 static slotdef slotdefs[] = {
        SQSLOT("__len__", sq_length, slot_sq_length, wrap_inquiry,
               "x.__len__() <==> len(x)"),
-       SQSLOT("__add__", sq_concat, slot_sq_concat, wrap_binaryfunc,
-              "x.__add__(y) <==> x+y"),
-       SQSLOT("__mul__", sq_repeat, slot_sq_repeat, wrap_intargfunc,
-              "x.__mul__(n) <==> x*n"),
-       SQSLOT("__rmul__", sq_repeat, slot_sq_repeat, wrap_intargfunc,
-              "x.__rmul__(n) <==> n*x"),
+       /* Heap types defining __add__/__mul__ have sq_concat/sq_repeat == NULL.
+          The logic in abstract.c always falls back to nb_add/nb_multiply in
+          this case.  Defining both the nb_* and the sq_* slots to call the
+          user-defined methods has unexpected side-effects, as shown by
+          test_descr.notimplemented() */
+       SQSLOT("__add__", sq_concat, NULL, wrap_binaryfunc,
+          "x.__add__(y) <==> x+y"),
+       SQSLOT("__mul__", sq_repeat, NULL, wrap_intargfunc,
+          "x.__mul__(n) <==> x*n"),
+       SQSLOT("__rmul__", sq_repeat, NULL, wrap_intargfunc,
+          "x.__rmul__(n) <==> n*x"),
        SQSLOT("__getitem__", sq_item, slot_sq_item, wrap_sq_item,
               "x.__getitem__(y) <==> x[y]"),
        SQSLOT("__getslice__", sq_slice, slot_sq_slice, wrap_intintargfunc,
@@ -4953,10 +4952,10 @@ static slotdef slotdefs[] = {
                Use of negative indices is not supported."),
        SQSLOT("__contains__", sq_contains, slot_sq_contains, wrap_objobjproc,
               "x.__contains__(y) <==> y in x"),
-       SQSLOT("__iadd__", sq_inplace_concat, slot_sq_inplace_concat,
-              wrap_binaryfunc, "x.__iadd__(y) <==> x+=y"),
-       SQSLOT("__imul__", sq_inplace_repeat, slot_sq_inplace_repeat,
-              wrap_intargfunc, "x.__imul__(y) <==> x*=y"),
+       SQSLOT("__iadd__", sq_inplace_concat, NULL,
+          wrap_binaryfunc, "x.__iadd__(y) <==> x+=y"),
+       SQSLOT("__imul__", sq_inplace_repeat, NULL,
+          wrap_intargfunc, "x.__imul__(y) <==> x*=y"),
 
        MPSLOT("__len__", mp_length, slot_mp_length, wrap_inquiry,
               "x.__len__() <==> len(x)"),