]> granicus.if.org Git - python/commitdiff
bpo-32604: Fix memory leaks in the new _xxsubinterpreters module. (#5507)
authorEric Snow <ericsnowcurrently@gmail.com>
Sat, 3 Feb 2018 04:49:49 +0000 (21:49 -0700)
committerGitHub <noreply@github.com>
Sat, 3 Feb 2018 04:49:49 +0000 (21:49 -0700)
Lib/test/test__xxsubinterpreters.py
Modules/_xxsubinterpretersmodule.c
Python/pystate.c

index 2b170443a3b63895d0bdebc153b84088b837f45e..8d72ca20021486071f8b645b07295e613d514b0d 100644 (file)
@@ -362,13 +362,15 @@ class DestroyTests(TestBase):
     def test_from_current(self):
         main, = interpreters.list_all()
         id = interpreters.create()
-        script = dedent("""
+        script = dedent(f"""
             import _xxsubinterpreters as _interpreters
-            _interpreters.destroy({})
-            """).format(id)
+            try:
+                _interpreters.destroy({id})
+            except RuntimeError:
+                pass
+            """)
 
-        with self.assertRaises(RuntimeError):
-            interpreters.run_string(id, script)
+        interpreters.run_string(id, script)
         self.assertEqual(set(interpreters.list_all()), {main, id})
 
     def test_from_sibling(self):
@@ -761,12 +763,12 @@ class ChannelIDTests(TestBase):
         self.assertEqual(int(cid), 10)
 
     def test_bad_id(self):
-        ids = [-1, 2**64, "spam"]
-        for cid in ids:
+        for cid in [-1, 'spam']:
             with self.subTest(cid):
                 with self.assertRaises(ValueError):
                     interpreters._channel_id(cid)
-
+        with self.assertRaises(OverflowError):
+            interpreters._channel_id(2**64)
         with self.assertRaises(TypeError):
             interpreters._channel_id(object())
 
index d2b5f26fae1d0907d01e8e608d3fde48bbdf87d0..7829b4cd951156648f1cb16f1e43a3ee288b9270 100644 (file)
@@ -7,6 +7,22 @@
 #include "internal/pystate.h"
 
 
+static char *
+_copy_raw_string(PyObject *strobj)
+{
+    const char *str = PyUnicode_AsUTF8(strobj);
+    if (str == NULL) {
+        return NULL;
+    }
+    char *copied = PyMem_Malloc(strlen(str)+1);
+    if (str == NULL) {
+        PyErr_NoMemory();
+        return NULL;
+    }
+    strcpy(copied, str);
+    return copied;
+}
+
 static PyInterpreterState *
 _get_current(void)
 {
@@ -31,10 +47,13 @@ _coerce_id(PyObject *id)
         }
         return -1;
     }
-    long long cid = PyLong_AsLongLong(id);
+    int64_t cid = PyLong_AsLongLong(id);
+    Py_DECREF(id);
     if (cid == -1 && PyErr_Occurred() != NULL) {
-        PyErr_SetString(PyExc_ValueError,
-                        "'id' must be a non-negative int");
+        if (!PyErr_ExceptionMatches(PyExc_OverflowError)) {
+            PyErr_SetString(PyExc_ValueError,
+                            "'id' must be a non-negative int");
+        }
         return -1;
     }
     if (cid < 0) {
@@ -42,92 +61,131 @@ _coerce_id(PyObject *id)
                         "'id' must be a non-negative int");
         return -1;
     }
-    if (cid > INT64_MAX) {
-        PyErr_SetString(PyExc_ValueError,
-                        "'id' too large (must be 64-bit int)");
-        return -1;
-    }
     return cid;
 }
 
+
 /* data-sharing-specific code ***********************************************/
 
-typedef struct _shareditem {
-    Py_UNICODE *name;
-    Py_ssize_t namelen;
+struct _sharednsitem {
+    char *name;
     _PyCrossInterpreterData data;
-} _shareditem;
+};
+
+static int
+_sharednsitem_init(struct _sharednsitem *item, PyObject *key, PyObject *value)
+{
+    item->name = _copy_raw_string(key);
+    if (item->name == NULL) {
+        return -1;
+    }
+    if (_PyObject_GetCrossInterpreterData(value, &item->data) != 0) {
+        return -1;
+    }
+    return 0;
+}
+
+static void
+_sharednsitem_clear(struct _sharednsitem *item)
+{
+    if (item->name != NULL) {
+        PyMem_Free(item->name);
+    }
+    _PyCrossInterpreterData_Release(&item->data);
+}
+
+static int
+_sharednsitem_apply(struct _sharednsitem *item, PyObject *ns)
+{
+    PyObject *name = PyUnicode_FromString(item->name);
+    if (name == NULL) {
+        return -1;
+    }
+    PyObject *value = _PyCrossInterpreterData_NewObject(&item->data);
+    if (value == NULL) {
+        Py_DECREF(name);
+        return -1;
+    }
+    int res = PyDict_SetItem(ns, name, value);
+    Py_DECREF(name);
+    Py_DECREF(value);
+    return res;
+}
+
+typedef struct _sharedns {
+    Py_ssize_t len;
+    struct _sharednsitem* items;
+} _sharedns;
+
+static _sharedns *
+_sharedns_new(Py_ssize_t len)
+{
+    _sharedns *shared = PyMem_NEW(_sharedns, 1);
+    if (shared == NULL) {
+        PyErr_NoMemory();
+        return NULL;
+    }
+    shared->len = len;
+    shared->items = PyMem_NEW(struct _sharednsitem, len);
+    if (shared->items == NULL) {
+        PyErr_NoMemory();
+        PyMem_Free(shared);
+        return NULL;
+    }
+    return shared;
+}
 
-void
-_sharedns_clear(_shareditem *shared)
+static void
+_sharedns_free(_sharedns *shared)
 {
-    for (_shareditem *item=shared; item->name != NULL; item += 1) {
-        _PyCrossInterpreterData_Release(&item->data);
+    for (Py_ssize_t i=0; i < shared->len; i++) {
+        _sharednsitem_clear(&shared->items[i]);
     }
+    PyMem_Free(shared->items);
+    PyMem_Free(shared);
 }
 
-static _shareditem *
-_get_shared_ns(PyObject *shareable, Py_ssize_t *lenp)
+static _sharedns *
+_get_shared_ns(PyObject *shareable)
 {
     if (shareable == NULL || shareable == Py_None) {
-        *lenp = 0;
         return NULL;
     }
     Py_ssize_t len = PyDict_Size(shareable);
-    *lenp = len;
     if (len == 0) {
         return NULL;
     }
 
-    _shareditem *shared = PyMem_NEW(_shareditem, len+1);
+    _sharedns *shared = _sharedns_new(len);
     if (shared == NULL) {
         return NULL;
     }
-    for (Py_ssize_t i=0; i < len; i++) {
-        *(shared + i) = (_shareditem){0};
-    }
     Py_ssize_t pos = 0;
     for (Py_ssize_t i=0; i < len; i++) {
         PyObject *key, *value;
         if (PyDict_Next(shareable, &pos, &key, &value) == 0) {
             break;
         }
-        _shareditem *item = shared + i;
-
-        if (_PyObject_GetCrossInterpreterData(value, &item->data) != 0) {
+        if (_sharednsitem_init(&shared->items[i], key, value) != 0) {
             break;
         }
-        item->name = PyUnicode_AsUnicodeAndSize(key, &item->namelen);
-        if (item->name == NULL) {
-            _PyCrossInterpreterData_Release(&item->data);
-            break;
-        }
-        (item + 1)->name = NULL;  // Mark the next one as the last.
     }
     if (PyErr_Occurred()) {
-        _sharedns_clear(shared);
-        PyMem_Free(shared);
+        _sharedns_free(shared);
         return NULL;
     }
     return shared;
 }
 
 static int
-_shareditem_apply(_shareditem *item, PyObject *ns)
+_sharedns_apply(_sharedns *shared, PyObject *ns)
 {
-    PyObject *name = PyUnicode_FromUnicode(item->name, item->namelen);
-    if (name == NULL) {
-        return 1;
-    }
-    PyObject *value = _PyCrossInterpreterData_NewObject(&item->data);
-    if (value == NULL) {
-        Py_DECREF(name);
-        return 1;
+    for (Py_ssize_t i=0; i < shared->len; i++) {
+        if (_sharednsitem_apply(&shared->items[i], ns) != 0) {
+            return -1;
+        }
     }
-    int res = PyDict_SetItem(ns, name, value);
-    Py_DECREF(name);
-    Py_DECREF(value);
-    return res;
+    return 0;
 }
 
 // Ultimately we'd like to preserve enough information about the
@@ -136,65 +194,117 @@ _shareditem_apply(_shareditem *item, PyObject *ns)
 // of the exception in the calling interpreter.
 
 typedef struct _sharedexception {
+    char *name;
     char *msg;
 } _sharedexception;
 
 static _sharedexception *
-_get_shared_exception(void)
+_sharedexception_new(void)
 {
     _sharedexception *err = PyMem_NEW(_sharedexception, 1);
     if (err == NULL) {
+        PyErr_NoMemory();
         return NULL;
     }
-    PyObject *exc;
-    PyObject *value;
-    PyObject *tb;
-    PyErr_Fetch(&exc, &value, &tb);
-    PyObject *msg;
-    if (value == NULL) {
-        msg = PyUnicode_FromFormat("%S", exc);
-    }
-    else {
-        msg = PyUnicode_FromFormat("%S: %S", exc, value);
-    }
-    if (msg == NULL) {
-        err->msg = "unable to format exception";
-        return err;
+    err->name = NULL;
+    err->msg = NULL;
+    return err;
+}
+
+static void
+_sharedexception_clear(_sharedexception *exc)
+{
+    if (exc->name != NULL) {
+        PyMem_Free(exc->name);
     }
-    err->msg = (char *)PyUnicode_AsUTF8(msg);
-    if (err->msg == NULL) {
-        err->msg = "unable to encode exception";
+    if (exc->msg != NULL) {
+        PyMem_Free(exc->msg);
     }
-    return err;
 }
 
-static PyObject * RunFailedError;
+static void
+_sharedexception_free(_sharedexception *exc)
+{
+    _sharedexception_clear(exc);
+    PyMem_Free(exc);
+}
 
-static int
-interp_exceptions_init(PyObject *ns)
+static _sharedexception *
+_sharedexception_bind(PyObject *exctype, PyObject *exc, PyObject *tb)
 {
-    // XXX Move the exceptions into per-module memory?
+    assert(exctype != NULL);
+    char *failure = NULL;
 
-    // An uncaught exception came out of interp_run_string().
-    RunFailedError = PyErr_NewException("_xxsubinterpreters.RunFailedError",
-                                        PyExc_RuntimeError, NULL);
-    if (RunFailedError == NULL) {
-        return -1;
+    _sharedexception *err = _sharedexception_new();
+    if (err == NULL) {
+        goto finally;
     }
-    if (PyDict_SetItemString(ns, "RunFailedError", RunFailedError) != 0) {
-        return -1;
+
+    PyObject *name = PyUnicode_FromFormat("%S", exctype);
+    if (name == NULL) {
+        failure = "unable to format exception type name";
+        goto finally;
+    }
+    err->name = _copy_raw_string(name);
+    Py_DECREF(name);
+    if (err->name == NULL) {
+        if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
+            failure = "out of memory copying exception type name";
+        }
+        failure = "unable to encode and copy exception type name";
+        goto finally;
     }
 
-    return 0;
+    if (exc != NULL) {
+        PyObject *msg = PyUnicode_FromFormat("%S", exc);
+        if (msg == NULL) {
+            failure = "unable to format exception message";
+            goto finally;
+        }
+        err->msg = _copy_raw_string(msg);
+        Py_DECREF(msg);
+        if (err->msg == NULL) {
+            if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
+                failure = "out of memory copying exception message";
+            }
+            failure = "unable to encode and copy exception message";
+            goto finally;
+        }
+    }
+
+finally:
+    if (failure != NULL) {
+        PyErr_Clear();
+        if (err->name != NULL) {
+            PyMem_Free(err->name);
+            err->name = NULL;
+        }
+        err->msg = failure;
+    }
+    return err;
 }
 
 static void
-_apply_shared_exception(_sharedexception *exc)
+_sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass)
 {
-    PyErr_SetString(RunFailedError, exc->msg);
+    if (exc->name != NULL) {
+        if (exc->msg != NULL) {
+            PyErr_Format(wrapperclass, "%s: %s",  exc->name, exc->msg);
+        }
+        else {
+            PyErr_SetString(wrapperclass, exc->name);
+        }
+    }
+    else if (exc->msg != NULL) {
+        PyErr_SetString(wrapperclass, exc->msg);
+    }
+    else {
+        PyErr_SetNone(wrapperclass);
+    }
 }
 
-/* channel-specific code */
+
+/* channel-specific code ****************************************************/
 
 static PyObject *ChannelError;
 static PyObject *ChannelNotFoundError;
@@ -249,6 +359,139 @@ channel_exceptions_init(PyObject *ns)
     return 0;
 }
 
+/* the channel queue */
+
+struct _channelitem;
+
+typedef struct _channelitem {
+    _PyCrossInterpreterData *data;
+    struct _channelitem *next;
+} _channelitem;
+
+static _channelitem *
+_channelitem_new(void)
+{
+    _channelitem *item = PyMem_NEW(_channelitem, 1);
+    if (item == NULL) {
+        PyErr_NoMemory();
+        return NULL;
+    }
+    item->data = NULL;
+    item->next = NULL;
+    return item;
+}
+
+static void
+_channelitem_clear(_channelitem *item)
+{
+    if (item->data != NULL) {
+        _PyCrossInterpreterData_Release(item->data);
+        PyMem_Free(item->data);
+        item->data = NULL;
+    }
+    item->next = NULL;
+}
+
+static void
+_channelitem_free(_channelitem *item)
+{
+    _channelitem_clear(item);
+    PyMem_Free(item);
+}
+
+static void
+_channelitem_free_all(_channelitem *item)
+{
+    while (item != NULL) {
+        _channelitem *last = item;
+        item = item->next;
+        _channelitem_free(last);
+    }
+}
+
+static _PyCrossInterpreterData *
+_channelitem_popped(_channelitem *item)
+{
+    _PyCrossInterpreterData *data = item->data;
+    item->data = NULL;
+    _channelitem_free(item);
+    return data;
+}
+
+typedef struct _channelqueue {
+    int64_t count;
+    _channelitem *first;
+    _channelitem *last;
+} _channelqueue;
+
+static _channelqueue *
+_channelqueue_new(void)
+{
+    _channelqueue *queue = PyMem_NEW(_channelqueue, 1);
+    if (queue == NULL) {
+        PyErr_NoMemory();
+        return NULL;
+    }
+    queue->count = 0;
+    queue->first = NULL;
+    queue->last = NULL;
+    return queue;
+}
+
+static void
+_channelqueue_clear(_channelqueue *queue)
+{
+    _channelitem_free_all(queue->first);
+    queue->count = 0;
+    queue->first = NULL;
+    queue->last = NULL;
+}
+
+static void
+_channelqueue_free(_channelqueue *queue)
+{
+    _channelqueue_clear(queue);
+    PyMem_Free(queue);
+}
+
+static int
+_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data)
+{
+    _channelitem *item = _channelitem_new();
+    if (item == NULL) {
+        return -1;
+    }
+    item->data = data;
+
+    queue->count += 1;
+    if (queue->first == NULL) {
+        queue->first = item;
+    }
+    else {
+        queue->last->next = item;
+    }
+    queue->last = item;
+    return 0;
+}
+
+static _PyCrossInterpreterData *
+_channelqueue_get(_channelqueue *queue)
+{
+    _channelitem *item = queue->first;
+    if (item == NULL) {
+        return NULL;
+    }
+    queue->first = item->next;
+    if (queue->last == item) {
+        queue->last = NULL;
+    }
+    queue->count -= 1;
+
+    return _channelitem_popped(item);
+}
+
+/* channel-interpreter associations */
+
 struct _channelend;
 
 typedef struct _channelend {
@@ -262,23 +505,28 @@ _channelend_new(int64_t interp)
 {
     _channelend *end = PyMem_NEW(_channelend, 1);
     if (end == NULL) {
+        PyErr_NoMemory();
         return NULL;
     }
-
     end->next = NULL;
     end->interp = interp;
-
     end->open = 1;
-
     return end;
 }
 
 static void
-_channelend_free_all(_channelend *end) {
+_channelend_free(_channelend *end)
+{
+    PyMem_Free(end);
+}
+
+static void
+_channelend_free_all(_channelend *end)
+{
     while (end != NULL) {
         _channelend *last = end;
         end = end->next;
-        PyMem_Free(last);
+        _channelend_free(last);
     }
 }
 
@@ -300,24 +548,7 @@ _channelend_find(_channelend *first, int64_t interp, _channelend **pprev)
     return end;
 }
 
-struct _channelitem;
-
-typedef struct _channelitem {
-    _PyCrossInterpreterData *data;
-    struct _channelitem *next;
-} _channelitem;
-
-struct _channel;
-
-typedef struct _channel {
-    PyThread_type_lock mutex;
-
-    int open;
-
-    int64_t count;
-    _channelitem *first;
-    _channelitem *last;
-
+typedef struct _channelassociations {
     // Note that the list entries are never removed for interpreter
     // for which the channel is closed.  This should be a problem in
     // practice.  Also, a channel isn't automatically closed when an
@@ -326,39 +557,43 @@ typedef struct _channel {
     int64_t numrecvopen;
     _channelend *send;
     _channelend *recv;
-} _PyChannelState;
+} _channelends;
 
-static _PyChannelState *
-_channel_new(void)
+static _channelends *
+_channelends_new(void)
 {
-    _PyChannelState *chan = PyMem_NEW(_PyChannelState, 1);
-    if (chan == NULL) {
+    _channelends *ends = PyMem_NEW(_channelends, 1);
+    if (ends== NULL) {
         return NULL;
     }
-    chan->mutex = PyThread_allocate_lock();
-    if (chan->mutex == NULL) {
-        PyMem_Free(chan);
-        PyErr_SetString(ChannelError,
-                        "can't initialize mutex for new channel");
-        return NULL;
-    }
-
-    chan->open = 1;
+    ends->numsendopen = 0;
+    ends->numrecvopen = 0;
+    ends->send = NULL;
+    ends->recv = NULL;
+    return ends;
+}
 
-    chan->count = 0;
-    chan->first = NULL;
-    chan->last = NULL;
+static void
+_channelends_clear(_channelends *ends)
+{
+    _channelend_free_all(ends->send);
+    ends->send = NULL;
+    ends->numsendopen = 0;
 
-    chan->numsendopen = 0;
-    chan->numrecvopen = 0;
-    chan->send = NULL;
-    chan->recv = NULL;
+    _channelend_free_all(ends->recv);
+    ends->recv = NULL;
+    ends->numrecvopen = 0;
+}
 
-    return chan;
+static void
+_channelends_free(_channelends *ends)
+{
+    _channelends_clear(ends);
+    PyMem_Free(ends);
 }
 
 static _channelend *
-_channel_add_end(_PyChannelState *chan, _channelend *prev, int64_t interp,
+_channelends_add(_channelends *ends, _channelend *prev, int64_t interp,
                  int send)
 {
     _channelend *end = _channelend_new(interp);
@@ -368,137 +603,163 @@ _channel_add_end(_PyChannelState *chan, _channelend *prev, int64_t interp,
 
     if (prev == NULL) {
         if (send) {
-            chan->send = end;
+            ends->send = end;
         }
         else {
-            chan->recv = end;
+            ends->recv = end;
         }
     }
     else {
         prev->next = end;
     }
     if (send) {
-        chan->numsendopen += 1;
+        ends->numsendopen += 1;
     }
     else {
-        chan->numrecvopen += 1;
+        ends->numrecvopen += 1;
     }
     return end;
 }
 
-static _channelend *
-_channel_associate_end(_PyChannelState *chan, int64_t interp, int send)
+static int
+_channelends_associate(_channelends *ends, int64_t interp, int send)
 {
-    if (!chan->open) {
-        PyErr_SetString(ChannelClosedError, "channel closed");
-        return NULL;
-    }
-
     _channelend *prev;
-    _channelend *end = _channelend_find(send ? chan->send : chan->recv,
+    _channelend *end = _channelend_find(send ? ends->send : ends->recv,
                                         interp, &prev);
     if (end != NULL) {
         if (!end->open) {
             PyErr_SetString(ChannelClosedError, "channel already closed");
-            return NULL;
+            return -1;
         }
         // already associated
-        return end;
+        return 0;
+    }
+    if (_channelends_add(ends, prev, interp, send) == NULL) {
+        return -1;
     }
-    return _channel_add_end(chan, prev, interp, send);
+    return 0;
+}
+
+static int
+_channelends_is_open(_channelends *ends)
+{
+    if (ends->numsendopen != 0 || ends->numrecvopen != 0) {
+        return 1;
+    }
+    if (ends->send == NULL && ends->recv == NULL) {
+        return 1;
+    }
+    return 0;
 }
 
 static void
-_channel_close_channelend(_PyChannelState *chan, _channelend *end, int send)
+_channelends_close_end(_channelends *ends, _channelend *end, int send)
 {
     end->open = 0;
     if (send) {
-        chan->numsendopen -= 1;
+        ends->numsendopen -= 1;
     }
     else {
-        chan->numrecvopen -= 1;
+        ends->numrecvopen -= 1;
     }
 }
 
 static int
-_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
+_channelends_close_interpreter(_channelends *ends, int64_t interp, int which)
 {
-    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
-
-    int res = -1;
-    if (!chan->open) {
-        PyErr_SetString(ChannelClosedError, "channel already closed");
-        goto done;
-    }
-
     _channelend *prev;
     _channelend *end;
     if (which >= 0) {  // send/both
-        end = _channelend_find(chan->send, interp, &prev);
+        end = _channelend_find(ends->send, interp, &prev);
         if (end == NULL) {
             // never associated so add it
-            end = _channel_add_end(chan, prev, interp, 1);
+            end = _channelends_add(ends, prev, interp, 1);
             if (end == NULL) {
-                goto done;
+                return -1;
             }
         }
-        _channel_close_channelend(chan, end, 1);
+        _channelends_close_end(ends, end, 1);
     }
     if (which <= 0) {  // recv/both
-        end = _channelend_find(chan->recv, interp, &prev);
+        end = _channelend_find(ends->recv, interp, &prev);
         if (end == NULL) {
             // never associated so add it
-            end = _channel_add_end(chan, prev, interp, 0);
+            end = _channelends_add(ends, prev, interp, 0);
             if (end == NULL) {
-                goto done;
+                return -1;
             }
         }
-        _channel_close_channelend(chan, end, 0);
-    }
-
-    if (chan->numsendopen == 0 && chan->numrecvopen == 0) {
-        if (chan->send != NULL || chan->recv != NULL) {
-            chan->open = 0;
-        }
+        _channelends_close_end(ends, end, 0);
     }
-
-    res = 0;
-done:
-    PyThread_release_lock(chan->mutex);
-    return res;
+    return 0;
 }
 
-static int
-_channel_close_all(_PyChannelState *chan)
+static void
+_channelends_close_all(_channelends *ends)
 {
-    int res = -1;
-    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+    // Ensure all the "send"-associated interpreters are closed.
+    _channelend *end;
+    for (end = ends->send; end != NULL; end = end->next) {
+        _channelends_close_end(ends, end, 1);
+    }
 
-    if (!chan->open) {
-        PyErr_SetString(ChannelClosedError, "channel already closed");
-        goto done;
+    // Ensure all the "recv"-associated interpreters are closed.
+    for (end = ends->recv; end != NULL; end = end->next) {
+        _channelends_close_end(ends, end, 0);
     }
+}
 
-    chan->open = 0;
+/* channels */
 
-    // We *could* also just leave these in place, since we've marked
-    // the channel as closed already.
+struct _channel;
 
-    // Ensure all the "send"-associated interpreters are closed.
-    _channelend *end;
-    for (end = chan->send; end != NULL; end = end->next) {
-        _channel_close_channelend(chan, end, 1);
-    }
+typedef struct _channel {
+    PyThread_type_lock mutex;
+    _channelqueue *queue;
+    _channelends *ends;
+    int open;
+} _PyChannelState;
 
-    // Ensure all the "recv"-associated interpreters are closed.
-    for (end = chan->recv; end != NULL; end = end->next) {
-        _channel_close_channelend(chan, end, 0);
+static _PyChannelState *
+_channel_new(void)
+{
+    _PyChannelState *chan = PyMem_NEW(_PyChannelState, 1);
+    if (chan == NULL) {
+        return NULL;
+    }
+    chan->mutex = PyThread_allocate_lock();
+    if (chan->mutex == NULL) {
+        PyMem_Free(chan);
+        PyErr_SetString(ChannelError,
+                        "can't initialize mutex for new channel");
+        return NULL;
+    }
+    chan->queue = _channelqueue_new();
+    if (chan->queue == NULL) {
+        PyMem_Free(chan);
+        return NULL;
+    }
+    chan->ends = _channelends_new();
+    if (chan->ends == NULL) {
+        _channelqueue_free(chan->queue);
+        PyMem_Free(chan);
+        return NULL;
     }
+    chan->open = 1;
+    return chan;
+}
 
-    res = 0;
-done:
+static void
+_channel_free(_PyChannelState *chan)
+{
+    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+    _channelqueue_free(chan->queue);
+    _channelends_free(chan->ends);
     PyThread_release_lock(chan->mutex);
-    return res;
+
+    PyThread_free_lock(chan->mutex);
+    PyMem_Free(chan);
 }
 
 static int
@@ -506,24 +767,19 @@ _channel_add(_PyChannelState *chan, int64_t interp,
              _PyCrossInterpreterData *data)
 {
     int res = -1;
-
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
-    if (_channel_associate_end(chan, interp, 1) == NULL) {
+
+    if (!chan->open) {
+        PyErr_SetString(ChannelClosedError, "channel closed");
         goto done;
     }
-
-    _channelitem *item = PyMem_NEW(_channelitem, 1);
-    if (item == NULL) {
+    if (_channelends_associate(chan->ends, interp, 1) != 0) {
         goto done;
     }
-    item->data = data;
-    item->next = NULL;
 
-    chan->count += 1;
-    if (chan->first == NULL) {
-        chan->first = item;
+    if (_channelqueue_put(chan->queue, data) != 0) {
+        goto done;
     }
-    chan->last = item;
 
     res = 0;
 done:
@@ -536,56 +792,68 @@ _channel_next(_PyChannelState *chan, int64_t interp)
 {
     _PyCrossInterpreterData *data = NULL;
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
-    if (_channel_associate_end(chan, interp, 0) == NULL) {
-        goto done;
-    }
 
-    _channelitem *item = chan->first;
-    if (item == NULL) {
+    if (!chan->open) {
+        PyErr_SetString(ChannelClosedError, "channel closed");
         goto done;
     }
-    chan->first = item->next;
-    if (chan->last == item) {
-        chan->last = NULL;
+    if (_channelends_associate(chan->ends, interp, 0) != 0) {
+        goto done;
     }
-    chan->count -= 1;
-
-    data = item->data;
-    PyMem_Free(item);
 
+    data = _channelqueue_get(chan->queue);
 done:
     PyThread_release_lock(chan->mutex);
     return data;
 }
 
-static void
-_channel_clear(_PyChannelState *chan)
+static int
+_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
 {
-    _channelitem *item = chan->first;
-    while (item != NULL) {
-        _PyCrossInterpreterData_Release(item->data);
-        PyMem_Free(item->data);
-        _channelitem *last = item;
-        item = item->next;
-        PyMem_Free(last);
+    PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+
+    int res = -1;
+    if (!chan->open) {
+        PyErr_SetString(ChannelClosedError, "channel already closed");
+        goto done;
     }
-    chan->first = NULL;
-    chan->last = NULL;
+
+    if (_channelends_close_interpreter(chan->ends, interp, which) != 0) {
+        goto done;
+    }
+    chan->open = _channelends_is_open(chan->ends);
+
+    res = 0;
+done:
+    PyThread_release_lock(chan->mutex);
+    return res;
 }
 
-static void
-_channel_free(_PyChannelState *chan)
+static int
+_channel_close_all(_PyChannelState *chan)
 {
+    int res = -1;
     PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
-    _channel_clear(chan);
-    _channelend_free_all(chan->send);
-    _channelend_free_all(chan->recv);
-    PyThread_release_lock(chan->mutex);
 
-    PyThread_free_lock(chan->mutex);
-    PyMem_Free(chan);
+    if (!chan->open) {
+        PyErr_SetString(ChannelClosedError, "channel already closed");
+        goto done;
+    }
+
+    chan->open = 0;
+
+    // We *could* also just leave these in place, since we've marked
+    // the channel as closed already.
+    _channelends_close_all(chan->ends);
+
+    res = 0;
+done:
+    PyThread_release_lock(chan->mutex);
+    return res;
 }
 
+/* the set of channels */
+
 struct _channelref;
 
 typedef struct _channelref {
@@ -609,6 +877,22 @@ _channelref_new(int64_t id, _PyChannelState *chan)
     return ref;
 }
 
+//static void
+//_channelref_clear(_channelref *ref)
+//{
+//    ref->id = -1;
+//    ref->chan = NULL;
+//    ref->next = NULL;
+//    ref->objcount = 0;
+//}
+
+static void
+_channelref_free(_channelref *ref)
+{
+    //_channelref_clear(ref);
+    PyMem_Free(ref);
+}
+
 static _channelref *
 _channelref_find(_channelref *first, int64_t id, _channelref **pprev)
 {
@@ -640,7 +924,6 @@ _channels_init(_channels *channels)
     if (channels->mutex == NULL) {
         channels->mutex = PyThread_allocate_lock();
         if (channels->mutex == NULL) {
-            PyMem_Free(channels);
             PyErr_SetString(ChannelError,
                             "can't initialize mutex for channel management");
             return -1;
@@ -752,6 +1035,9 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
         if (pchan != NULL) {
             *pchan = ref->chan;
         }
+        else {
+            _channel_free(ref->chan);
+        }
         ref->chan = NULL;
     }
 
@@ -776,7 +1062,7 @@ _channels_remove_ref(_channels *channels, _channelref *ref, _channelref *prev,
     if (pchan != NULL) {
         *pchan = ref->chan;
     }
-    PyMem_Free(ref);
+    _channelref_free(ref);
 }
 
 static int
@@ -974,6 +1260,7 @@ _channel_recv(_channels *channels, int64_t id)
         return NULL;
     }
     _PyCrossInterpreterData_Release(data);
+    PyMem_Free(data);
 
     return obj;
 }
@@ -995,7 +1282,7 @@ _channel_drop(_channels *channels, int64_t id, int send, int recv)
     // Past this point we are responsible for releasing the mutex.
 
     // Close one or both of the two ends.
-    int res =_channel_close_interpreter(chan, interp->id, send-recv);
+    int res = _channel_close_interpreter(chan, interp->id, send-recv);
     PyThread_release_lock(mutex);
     return res;
 }
@@ -1078,6 +1365,7 @@ channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds)
                         "'send' and 'recv' cannot both be False");
         return NULL;
     }
+
     int end = 0;
     if (send == 1) {
         if (recv == 0 || recv == -1) {
@@ -1176,7 +1464,9 @@ channelid_hash(PyObject *self)
     if (id == NULL) {
         return -1;
     }
-    return PyObject_Hash(id);
+    Py_hash_t hash = PyObject_Hash(id);
+    Py_DECREF(id);
+    return hash;
 }
 
 static PyObject *
@@ -1208,11 +1498,11 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
             Py_RETURN_NOTIMPLEMENTED;
         }
         int64_t othercid = PyLong_AsLongLong(other);
-        // XXX decref other here?
+        Py_DECREF(other);
         if (othercid == -1 && PyErr_Occurred() != NULL) {
             return NULL;
         }
-        if (othercid < 0 || othercid > INT64_MAX) {
+        if (othercid < 0) {
             equal = 0;
         }
         else {
@@ -1338,7 +1628,30 @@ static PyTypeObject ChannelIDtype = {
     NULL,                           /* tp_new */
 };
 
-/* interpreter-specific functions *******************************************/
+
+/* interpreter-specific code ************************************************/
+
+static PyObject * RunFailedError = NULL;
+
+static int
+interp_exceptions_init(PyObject *ns)
+{
+    // XXX Move the exceptions into per-module memory?
+
+    if (RunFailedError == NULL) {
+        // An uncaught exception came out of interp_run_string().
+        RunFailedError = PyErr_NewException("_xxsubinterpreters.RunFailedError",
+                                            PyExc_RuntimeError, NULL);
+        if (RunFailedError == NULL) {
+            return -1;
+        }
+        if (PyDict_SetItemString(ns, "RunFailedError", RunFailedError) != 0) {
+            return -1;
+        }
+    }
+
+    return 0;
+}
 
 static PyInterpreterState *
 _look_up(PyObject *requested_id)
@@ -1396,10 +1709,12 @@ _ensure_not_running(PyInterpreterState *interp)
 
 static int
 _run_script(PyInterpreterState *interp, const char *codestr,
-            _shareditem *shared, Py_ssize_t num_shared,
-            _sharedexception **exc)
+            _sharedns *shared, _sharedexception **exc)
 {
-    assert(num_shared >= 0);
+    PyObject *exctype = NULL;
+    PyObject *excval = NULL;
+    PyObject *tb = NULL;
+
     PyObject *main_mod = PyMapping_GetItemString(interp->modules, "__main__");
     if (main_mod == NULL) {
         goto error;
@@ -1413,12 +1728,9 @@ _run_script(PyInterpreterState *interp, const char *codestr,
 
     // Apply the cross-interpreter data.
     if (shared != NULL) {
-        for (Py_ssize_t i=0; i < num_shared; i++) {
-            _shareditem *item = &shared[i];
-            if (_shareditem_apply(item, ns) != 0) {
-                Py_DECREF(ns);
-                goto error;
-            }
+        if (_sharedns_apply(shared, ns) != 0) {
+            Py_DECREF(ns);
+            goto error;
         }
     }
 
@@ -1432,11 +1744,25 @@ _run_script(PyInterpreterState *interp, const char *codestr,
         Py_DECREF(result);  // We throw away the result.
     }
 
+    *exc = NULL;
     return 0;
 
 error:
-    *exc = _get_shared_exception();
-    PyErr_Clear();
+    PyErr_Fetch(&exctype, &excval, &tb);
+
+    _sharedexception *sharedexc = _sharedexception_bind(exctype, excval, tb);
+    Py_XDECREF(exctype);
+    Py_XDECREF(excval);
+    Py_XDECREF(tb);
+    if (sharedexc == NULL) {
+        fprintf(stderr, "RunFailedError: script raised an uncaught exception");
+        PyErr_Clear();
+        sharedexc = NULL;
+    }
+    else {
+        assert(!PyErr_Occurred());
+    }
+    *exc = sharedexc;
     return -1;
 }
 
@@ -1448,8 +1774,7 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr,
         return -1;
     }
 
-    Py_ssize_t num_shared = -1;
-    _shareditem *shared = _get_shared_ns(shareables, &num_shared);
+    _sharedns *shared = _get_shared_ns(shareables);
     if (shared == NULL && PyErr_Occurred()) {
         return -1;
     }
@@ -1460,7 +1785,7 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr,
 
     // Run the script.
     _sharedexception *exc = NULL;
-    int result = _run_script(interp, codestr, shared, num_shared, &exc);
+    int result = _run_script(interp, codestr, shared, &exc);
 
     // Switch back.
     if (save_tstate != NULL) {
@@ -1469,8 +1794,8 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr,
 
     // Propagate any exception out to the caller.
     if (exc != NULL) {
-        _apply_shared_exception(exc);
-        PyMem_Free(exc);
+        _sharedexception_apply(exc, RunFailedError);
+        _sharedexception_free(exc);
     }
     else if (result != 0) {
         // We were unable to allocate a shared exception.
@@ -1478,8 +1803,7 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr,
     }
 
     if (shared != NULL) {
-        _sharedns_clear(shared);
-        PyMem_Free(shared);
+        _sharedns_free(shared);
     }
 
     return result;
@@ -1612,7 +1936,9 @@ interp_list_all(PyObject *self)
             return NULL;
         }
         // insert at front of list
-        if (PyList_Insert(ids, 0, id) < 0) {
+        int res = PyList_Insert(ids, 0, id);
+        Py_DECREF(id);
+        if (res < 0) {
             Py_DECREF(ids);
             return NULL;
         }
@@ -1822,11 +2148,11 @@ channel_list_all(PyObject *self)
     }
     PyObject *ids = PyList_New((Py_ssize_t)count);
     if (ids == NULL) {
-        // XXX free cids
-        return NULL;
+        goto finally;
     }
-    for (int64_t i=0; i < count; cids++, i++) {
-        PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, *cids, 0,
+    int64_t *cur = cids;
+    for (int64_t i=0; i < count; cur++, i++) {
+        PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, *cur, 0,
                                                 &_globals.channels, 0);
         if (id == NULL) {
             Py_DECREF(ids);
@@ -1835,7 +2161,9 @@ channel_list_all(PyObject *self)
         }
         PyList_SET_ITEM(ids, i, id);
     }
-    // XXX free cids
+
+finally:
+    PyMem_Free(cids);
     return ids;
 }
 
index a474549a8c730dc7807b6dd8723e856ca9e87756..8dbda73de7015df743ef3830a01ea4f5ef554fcb 100644 (file)
@@ -1242,6 +1242,7 @@ _PyCrossInterpreterData_Lookup(PyObject *obj)
             break;
         }
     }
+    Py_DECREF(cls);
     PyThread_release_lock(_PyRuntime.xidregistry.mutex);
     return getdata;
 }