]> granicus.if.org Git - python/commitdiff
Make sure filter() never returns tuple, str or unicode
authorWalter Dörwald <walter@livinglogic.de>
Tue, 4 Feb 2003 20:24:45 +0000 (20:24 +0000)
committerWalter Dörwald <walter@livinglogic.de>
Tue, 4 Feb 2003 20:24:45 +0000 (20:24 +0000)
subclasses. (Discussed in SF patch #665835)

Lib/test/test_builtin.py
Python/bltinmodule.c

index 2b0c01797075edd50eae29f96b71ef5aa689b56d..92e44d5f64d14e5d0f31f148fc17fe57f6fa15ed 100644 (file)
@@ -408,6 +408,29 @@ class BuiltinTest(unittest.TestCase):
                 unicode("345")
             )
 
+    def test_filter_subclasses(self):
+        # test, that filter() never returns tuple, str or unicode subclasses
+        funcs = (None, lambda x: True)
+        class tuple2(tuple):
+            pass
+        class str2(str):
+            pass
+        inputs = {
+            tuple2: [(), (1,2,3)],
+            str2:   ["", "123"]
+        }
+        if have_unicode:
+            class unicode2(unicode):
+                pass
+            inputs[unicode2] = [unicode(), unicode("123")]
+
+        for func in funcs:
+            for (cls, inps) in inputs.iteritems():
+                for inp in inps:
+                    out = filter(func, cls(inp))
+                    self.assertEqual(inp, out)
+                    self.assert_(not isinstance(out, cls))
+
     def test_float(self):
         self.assertEqual(float(3.14), 3.14)
         self.assertEqual(float(314), 314.0)
index 2383b4fdd072ffc1296d0639822ef1e24352a28e..b74e09c26acca9f0717c33dfe252df33aeb2456b 100644 (file)
@@ -1838,7 +1838,10 @@ filtertuple(PyObject *func, PyObject *tuple)
        int len = PyTuple_Size(tuple);
 
        if (len == 0) {
-               Py_INCREF(tuple);
+               if (PyTuple_CheckExact(tuple))
+                       Py_INCREF(tuple);
+               else
+                       tuple = PyTuple_New(0);
                return tuple;
        }
 
@@ -1895,8 +1898,15 @@ filterstring(PyObject *func, PyObject *strobj)
        int outlen = len;
 
        if (func == Py_None) {
-               /* No character is ever false -- share input string */
-               Py_INCREF(strobj);
+               /* No character is ever false -- share input string
+                * (if it's not a subclass) */
+               if (PyString_CheckExact(strobj))
+                       Py_INCREF(strobj);
+               else
+                       strobj = PyString_FromStringAndSize(
+                               PyString_AS_STRING(strobj),
+                               len
+                       );
                return strobj;
        }
        if ((result = PyString_FromStringAndSize(NULL, len)) == NULL)
@@ -1980,8 +1990,15 @@ filterunicode(PyObject *func, PyObject *strobj)
        int outlen = len;
 
        if (func == Py_None) {
-               /* No character is ever false -- share input string */
-               Py_INCREF(strobj);
+               /* No character is ever false -- share input string
+                * (it if's not a subclass) */
+               if (PyUnicode_CheckExact(strobj))
+                       Py_INCREF(strobj);
+               else
+                       strobj = PyUnicode_FromUnicode(
+                               PyUnicode_AS_UNICODE(strobj),
+                               len
+                       );
                return strobj;
        }
        if ((result = PyUnicode_FromUnicode(NULL, len)) == NULL)