with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_close(cid)
- def test_close_with_unused_items(self):
+ def test_close_empty(self):
+ tests = [
+ (False, False),
+ (True, False),
+ (False, True),
+ (True, True),
+ ]
+ for send, recv in tests:
+ with self.subTest((send, recv)):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_recv(cid)
+ interpreters.channel_close(cid, send=send, recv=recv)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_defaults_with_unused_items(self):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+
+ with self.assertRaises(interpreters.ChannelNotEmptyError):
+ interpreters.channel_close(cid)
+ interpreters.channel_recv(cid)
+ interpreters.channel_send(cid, b'eggs')
+
+ def test_close_recv_with_unused_items_unforced(self):
cid = interpreters.channel_create()
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'ham')
- interpreters.channel_close(cid)
+
+ with self.assertRaises(interpreters.ChannelNotEmptyError):
+ interpreters.channel_close(cid, recv=True)
+ interpreters.channel_recv(cid)
+ interpreters.channel_send(cid, b'eggs')
+ interpreters.channel_recv(cid)
+ interpreters.channel_recv(cid)
+ interpreters.channel_close(cid, recv=True)
+
+ def test_close_send_with_unused_items_unforced(self):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, send=True)
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ interpreters.channel_recv(cid)
+ interpreters.channel_recv(cid)
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_both_with_unused_items_unforced(self):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+
+ with self.assertRaises(interpreters.ChannelNotEmptyError):
+ interpreters.channel_close(cid, recv=True, send=True)
+ interpreters.channel_recv(cid)
+ interpreters.channel_send(cid, b'eggs')
+ interpreters.channel_recv(cid)
+ interpreters.channel_recv(cid)
+ interpreters.channel_close(cid, recv=True)
+
+ def test_close_recv_with_unused_items_forced(self):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, recv=True, force=True)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_send_with_unused_items_forced(self):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, send=True, force=True)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_recv(cid)
+
+ def test_close_both_with_unused_items_forced(self):
+ cid = interpreters.channel_create()
+ interpreters.channel_send(cid, b'spam')
+ interpreters.channel_send(cid, b'ham')
+ interpreters.channel_close(cid, send=True, recv=True, force=True)
+
+ with self.assertRaises(interpreters.ChannelClosedError):
+ interpreters.channel_send(cid, b'eggs')
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid)
interp = interpreters.create()
interpreters.run_string(interp, dedent(f"""
import _xxsubinterpreters as _interpreters
- _interpreters.channel_close({cid})
+ _interpreters.channel_close({cid}, force=True)
"""))
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_recv(cid)
interpreters.channel_send(cid, b'spam')
interpreters.channel_send(cid, b'spam')
interpreters.channel_recv(cid)
- interpreters.channel_close(cid)
+ interpreters.channel_close(cid, force=True)
with self.assertRaises(interpreters.ChannelClosedError):
interpreters.channel_send(cid, b'eggs')
/* channel-specific code ****************************************************/
+#define CHANNEL_SEND 1
+#define CHANNEL_BOTH 0
+#define CHANNEL_RECV -1
+
static PyObject *ChannelError;
static PyObject *ChannelNotFoundError;
static PyObject *ChannelClosedError;
static PyObject *ChannelEmptyError;
+static PyObject *ChannelNotEmptyError;
static int
channel_exceptions_init(PyObject *ns)
return -1;
}
+ // An operation tried to close a non-empty channel.
+ ChannelNotEmptyError = PyErr_NewException(
+ "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL);
+ if (ChannelNotEmptyError == NULL) {
+ return -1;
+ }
+ if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) {
+ return -1;
+ }
+
return 0;
}
}
static void
-_channelends_close_all(_channelends *ends)
+_channelends_close_all(_channelends *ends, int which, int force)
{
+ // XXX Handle the ends.
+ // XXX Handle force is True.
+
// Ensure all the "send"-associated interpreters are closed.
_channelend *end;
for (end = ends->send; end != NULL; end = end->next) {
/* channels */
struct _channel;
+struct _channel_closing;
+static void _channel_clear_closing(struct _channel *);
+static void _channel_finish_closing(struct _channel *);
typedef struct _channel {
PyThread_type_lock mutex;
_channelqueue *queue;
_channelends *ends;
int open;
+ struct _channel_closing *closing;
} _PyChannelState;
static _PyChannelState *
return NULL;
}
chan->open = 1;
+ chan->closing = NULL;
return chan;
}
static void
_channel_free(_PyChannelState *chan)
{
+ _channel_clear_closing(chan);
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
_channelqueue_free(chan->queue);
_channelends_free(chan->ends);
}
data = _channelqueue_get(chan->queue);
+ if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
+ chan->open = 0;
+ }
+
done:
PyThread_release_lock(chan->mutex);
+ if (chan->queue->count == 0) {
+ _channel_finish_closing(chan);
+ }
return data;
}
static int
-_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int which)
+_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
{
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
goto done;
}
- if (_channelends_close_interpreter(chan->ends, interp, which) != 0) {
+ if (_channelends_close_interpreter(chan->ends, interp, end) != 0) {
goto done;
}
chan->open = _channelends_is_open(chan->ends);
}
static int
-_channel_close_all(_PyChannelState *chan)
+_channel_close_all(_PyChannelState *chan, int end, int force)
{
int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
goto done;
}
+ if (!force && chan->queue->count > 0) {
+ PyErr_SetString(ChannelNotEmptyError,
+ "may not be closed if not empty (try force=True)");
+ goto done;
+ }
+
chan->open = 0;
// We *could* also just leave these in place, since we've marked
// the channel as closed already.
- _channelends_close_all(chan->ends);
+ _channelends_close_all(chan->ends, end, force);
res = 0;
done:
static void
_channelref_free(_channelref *ref)
{
+ if (ref->chan != NULL) {
+ _channel_clear_closing(ref->chan);
+ }
//_channelref_clear(ref);
PyMem_Free(ref);
}
return cid;
}
+/* forward */
+static int _channel_set_closing(struct _channelref *, PyThread_type_lock);
+
static int
-_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan)
+_channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan,
+ int end, int force)
{
int res = -1;
PyThread_acquire_lock(channels->mutex, WAIT_LOCK);
PyErr_Format(ChannelClosedError, "channel %d closed", cid);
goto done;
}
+ else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) {
+ PyErr_Format(ChannelClosedError, "channel %d closed", cid);
+ goto done;
+ }
else {
- if (_channel_close_all(ref->chan) != 0) {
+ if (_channel_close_all(ref->chan, end, force) != 0) {
+ if (end == CHANNEL_SEND &&
+ PyErr_ExceptionMatches(ChannelNotEmptyError)) {
+ if (ref->chan->closing != NULL) {
+ PyErr_Format(ChannelClosedError, "channel %d closed", cid);
+ goto done;
+ }
+ // Mark the channel as closing and return. The channel
+ // will be cleaned up in _channel_next().
+ PyErr_Clear();
+ if (_channel_set_closing(ref, channels->mutex) != 0) {
+ goto done;
+ }
+ if (pchan != NULL) {
+ *pchan = ref->chan;
+ }
+ res = 0;
+ }
goto done;
}
if (pchan != NULL) {
*pchan = ref->chan;
}
- else {
+ else {
_channel_free(ref->chan);
}
ref->chan = NULL;
return cids;
}
+/* support for closing non-empty channels */
+
+struct _channel_closing {
+ struct _channelref *ref;
+};
+
+static int
+_channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) {
+ struct _channel *chan = ref->chan;
+ if (chan == NULL) {
+ // already closed
+ return 0;
+ }
+ int res = -1;
+ PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+ if (chan->closing != NULL) {
+ PyErr_SetString(ChannelClosedError, "channel closed");
+ goto done;
+ }
+ chan->closing = PyMem_NEW(struct _channel_closing, 1);
+ if (chan->closing == NULL) {
+ goto done;
+ }
+ chan->closing->ref = ref;
+
+ res = 0;
+done:
+ PyThread_release_lock(chan->mutex);
+ return res;
+}
+
+static void
+_channel_clear_closing(struct _channel *chan) {
+ PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+ if (chan->closing != NULL) {
+ PyMem_Free(chan->closing);
+ chan->closing = NULL;
+ }
+ PyThread_release_lock(chan->mutex);
+}
+
+static void
+_channel_finish_closing(struct _channel *chan) {
+ struct _channel_closing *closing = chan->closing;
+ if (closing == NULL) {
+ return;
+ }
+ _channelref *ref = closing->ref;
+ _channel_clear_closing(chan);
+ // Do the things that would have been done in _channels_close().
+ ref->chan = NULL;
+ _channel_free(chan);
+};
+
/* "high"-level channel-related functions */
static int64_t
}
// Past this point we are responsible for releasing the mutex.
+ if (chan->closing != NULL) {
+ PyErr_Format(ChannelClosedError, "channel %d closed", id);
+ PyThread_release_lock(mutex);
+ return -1;
+ }
+
// Convert the object to cross-interpreter data.
_PyCrossInterpreterData *data = PyMem_NEW(_PyCrossInterpreterData, 1);
if (data == NULL) {
}
static int
-_channel_close(_channels *channels, int64_t id)
+_channel_close(_channels *channels, int64_t id, int end, int force)
{
- return _channels_close(channels, id, NULL);
+ return _channels_close(channels, id, NULL, end, force);
}
/* ChannelID class */
-#define CHANNEL_SEND 1
-#define CHANNEL_RECV -1
-
static PyTypeObject ChannelIDtype;
typedef struct channelid {
if (cid < 0) {
return NULL;
}
- if (send == 0 && recv == 0) {
- send = 1;
- recv = 1;
- }
-
- // XXX Handle the ends.
- // XXX Handle force is True.
- if (_channel_close(&_globals.channels, cid) != 0) {
+ if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) {
return NULL;
}
Py_RETURN_NONE;