granicus.if.org Git - python/commitdiff
SF Patch #1007087: Return new string for single subclass joins (Bug #1001011)
author Raymond Hettinger <python@rcn.com>
Mon, 23 Aug 2004 23:23:54 +0000 (23:23 +0000)
committer Raymond Hettinger <python@rcn.com>
Mon, 23 Aug 2004 23:23:54 +0000 (23:23 +0000)
(Patch contributed by Nick Coghlan.)

Now joining string subtypes will always return a string.
Formerly, if there were only one item, it was returned unchanged.

Lib/test/test_string.py
Objects/stringobject.c

index 859dd4ea15801b4e7697d94c026780e50709a73d..ba9d9d397232b6114ae41cd29a811eb27bb1ba62 100644 (file)
@@ -52,6 +52,29 @@ class StringTest(
         self.checkraises(TypeError, string_tests.BadSeq1(), 'join', ' ')
         self.checkequal('a b c', string_tests.BadSeq2(), 'join', ' ')
 
+    def test_bug1001011(self):
+        # Make sure join returns a NEW object for single item sequences
+        # involving a subclass
+        # Make sure that it is of the appropriate type
+        # Check the optimisation still occurs for standard objects
+        class str_subclass(str): pass
+        s1 = str_subclass('abcd')
+        s2 = ''.join([s1])
+        self.failIf(s1 is s2)
+        self.assertEqual(type(s2), type(''))
+        s3 = 'abcd'
+        s4 = ''.join([s3])
+        self.failUnless(s3 is s4)
+        if test_support.have_unicode:
+            class unicode_subclass(unicode): pass
+            u1 = unicode_subclass(u'abcd')
+            u2 = ''.join([u1])
+            self.failIf(u1 is u2)
+            self.assertEqual(type(u2), type(u''))
+            u3 = u'abcd'
+            u4 = ''.join([u3])
+            self.failUnless(u3 is u4)
+
 class ModuleTest(unittest.TestCase):
 
     def test_attrs(self):
index b40351afcbac96bcca806e131707eff0ab412b37..c87b68800941adcdb46967a657a1204f2492d175 100644 (file)
@@ -1618,22 +1618,18 @@ string_join(PyStringObject *self, PyObject *orig)
        }
        if (seqlen == 1) {
                item = PySequence_Fast_GET_ITEM(seq, 0);
-               if (!PyString_Check(item) && !PyUnicode_Check(item)) {
-                       PyErr_Format(PyExc_TypeError,
-                                    "sequence item 0: expected string,"
-                                    " %.80s found",
-                                    item->ob_type->tp_name);
+               if (PyString_CheckExact(item) || PyUnicode_CheckExact(item)) {
+                       Py_INCREF(item);
                        Py_DECREF(seq);
-                       return NULL;
+                       return item;
                }
-               Py_INCREF(item);
-               Py_DECREF(seq);
-               return item;
        }
 
-       /* There are at least two things to join.  Do a pre-pass to figure out
-        * the total amount of space we'll need (sz), see whether any argument
-        * is absurd, and defer to the Unicode join if appropriate.
+       /* There are at least two things to join, or else we have a subclass
+        * of the builtin types in the sequence.  
+        * Do a pre-pass to figure out the total amount of space we'll
+        * need (sz), see whether any argument is absurd, and defer to
+        * the Unicode join if appropriate.
         */
        for (i = 0; i < seqlen; i++) {
                const size_t old_sz = sz;