]> granicus.if.org Git - python/commitdiff
Derivative of patch #102549, "simpler, faster(!) implementation of string.join".
authorTim Peters <tim.peters@gmail.com>
Fri, 19 Jan 2001 03:03:47 +0000 (03:03 +0000)
committerTim Peters <tim.peters@gmail.com>
Fri, 19 Jan 2001 03:03:47 +0000 (03:03 +0000)
Also fixes two long-standing bugs (present in 2.0):
1. .join() didn't check that the result size fit in an int.
2. string.join(s) when len(s)==1 returned s[0] regardless of s[0]'s
   type; e.g., "".join([3]) returned 3 (overly optimistic optimization).
I resisted a keen temptation to make .join() apply str() automagically.

Objects/stringobject.c

index eed4687d369fa141f24947ef0673942c61e2ac27..df3ab492f7b77e9e6d7c8791d87ba8d3f8bf3dd0 100644 (file)
@@ -794,46 +794,55 @@ static PyObject *
 string_join(PyStringObject *self, PyObject *args)
 {
        char *sep = PyString_AS_STRING(self);
-       int seplen = PyString_GET_SIZE(self);
+       const int seplen = PyString_GET_SIZE(self);
        PyObject *res = NULL;
-       int reslen = 0;
        char *p;
        int seqlen = 0;
-       int sz = 100;
-       int i, slen, sz_incr;
+       size_t sz = 0;
+       int i;
        PyObject *orig, *seq, *item;
 
        if (!PyArg_ParseTuple(args, "O:join", &orig))
                return NULL;
 
-       if (!(seq = PySequence_Fast(orig, ""))) {
+       seq = PySequence_Fast(orig, "");
+       if (seq == NULL) {
                if (PyErr_ExceptionMatches(PyExc_TypeError))
                        PyErr_Format(PyExc_TypeError,
                                     "sequence expected, %.80s found",
                                     orig->ob_type->tp_name);
                return NULL;
        }
-       /* From here on out, errors go through finally: for proper
-        * reference count manipulations.
-        */
+
        seqlen = PySequence_Size(seq);
+       if (seqlen == 0) {
+               Py_DECREF(seq);
+               return PyString_FromString("");
+       }
        if (seqlen == 1) {
                item = PySequence_Fast_GET_ITEM(seq, 0);
+               if (!PyString_Check(item) && !PyUnicode_Check(item)) {
+                       PyErr_Format(PyExc_TypeError,
+                                    "sequence item 0: expected string,"
+                                    " %.80s found",
+                                    item->ob_type->tp_name);
+                       Py_DECREF(seq);
+                       return NULL;
+               }
                Py_INCREF(item);
                Py_DECREF(seq);
                return item;
        }
 
-       if (!(res = PyString_FromStringAndSize((char*)NULL, sz)))
-               goto finally;
-
-       p = PyString_AS_STRING(res);
-
+       /* There are at least two things to join.  Do a pre-pass to figure out
+        * the total amount of space we'll need (sz), see whether any argument
+        * is absurd, and defer to the Unicode join if appropriate.
+        */
        for (i = 0; i < seqlen; i++) {
+               const size_t old_sz = sz;
                item = PySequence_Fast_GET_ITEM(seq, i);
                if (!PyString_Check(item)){
                        if (PyUnicode_Check(item)) {
-                               Py_DECREF(res);
                                Py_DECREF(seq);
                                return PyUnicode_Join((PyObject *)self, orig);
                        }
@@ -841,40 +850,45 @@ string_join(PyStringObject *self, PyObject *args)
                                     "sequence item %i: expected string,"
                                     " %.80s found",
                                     i, item->ob_type->tp_name);
-                       goto finally;
+                       Py_DECREF(seq);
+                       return NULL;
                }
-               slen = PyString_GET_SIZE(item);
-               while (reslen + slen + seplen >= sz) {
-                       /* at least double the size of the string */
-                       sz_incr = slen + seplen > sz ? slen + seplen : sz;
-                       if (_PyString_Resize(&res, sz + sz_incr)) {
-                               goto finally;
-                       }
-                       sz += sz_incr;
-                       p = PyString_AS_STRING(res) + reslen;
+               sz += PyString_GET_SIZE(item);
+               if (i != 0)
+                       sz += seplen;
+               if (sz < old_sz || sz > INT_MAX) {
+                       PyErr_SetString(PyExc_OverflowError,
+                               "join() is too long for a Python string");
+                       Py_DECREF(seq);
+                       return NULL;
                }
-               if (i > 0) {
+       }
+
+       /* Allocate result space. */
+       res = PyString_FromStringAndSize((char*)NULL, (int)sz);
+       if (res == NULL) {
+               Py_DECREF(seq);
+               return NULL;
+       }
+
+       /* Catenate everything. */
+       p = PyString_AS_STRING(res);
+       for (i = 0; i < seqlen; ++i) {
+               size_t n;
+               item = PySequence_Fast_GET_ITEM(seq, i);
+               n = PyString_GET_SIZE(item);
+               memcpy(p, PyString_AS_STRING(item), n);
+               p += n;
+               if (i < seqlen - 1) {
                        memcpy(p, sep, seplen);
                        p += seplen;
-                       reslen += seplen;
                }
-               memcpy(p, PyString_AS_STRING(item), slen);
-               p += slen;
-               reslen += slen;
        }
-       if (_PyString_Resize(&res, reslen))
-               goto finally;
-       Py_DECREF(seq);
-       return res;
 
-  finally:
        Py_DECREF(seq);
-       Py_XDECREF(res);
-       return NULL;
+       return res;
 }
 
-
-
 static long
 string_find_internal(PyStringObject *self, PyObject *args, int dir)
 {