Skip to content

gh-76785: Module-level Fixes for test.support.interpreters #110236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 2, 2023
30 changes: 25 additions & 5 deletions Lib/test/support/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# aliases:
from _xxsubinterpreters import is_shareable
from _xxinterpchannels import (
ChannelError, ChannelNotFoundError, ChannelEmptyError,
ChannelError, ChannelNotFoundError, ChannelClosedError,
ChannelEmptyError, ChannelNotEmptyError,
)


Expand Down Expand Up @@ -117,10 +118,16 @@ def list_all_channels():
class _ChannelEnd:
"""The base class for RecvChannel and SendChannel."""

def __init__(self, id):
if not isinstance(id, (int, _channels.ChannelID)):
raise TypeError(f'id must be an int, got {id!r}')
self._id = id
_end = None

def __init__(self, cid):
if self._end == 'send':
cid = _channels._channel_id(cid, send=True, force=True)
elif self._end == 'recv':
cid = _channels._channel_id(cid, recv=True, force=True)
else:
raise NotImplementedError(self._end)
self._id = cid

def __repr__(self):
return f'{type(self).__name__}(id={int(self._id)})'
Expand All @@ -147,6 +154,8 @@ def id(self):
class RecvChannel(_ChannelEnd):
"""The receiving end of a cross-interpreter channel."""

_end = 'recv'

def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds
"""Return the next object from the channel.

Expand All @@ -171,10 +180,15 @@ def recv_nowait(self, default=_NOT_SET):
else:
return _channels.recv(self._id, default)

def close(self):
_channels.close(self._id, recv=True)


class SendChannel(_ChannelEnd):
"""The sending end of a cross-interpreter channel."""

_end = 'send'

def send(self, obj):
"""Send the object (i.e. its data) to the channel's receiving end.

Expand All @@ -196,3 +210,9 @@ def send_nowait(self, obj):
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj)

def close(self):
_channels.close(self._id, send=True)


_channels._register_end_types(SendChannel, RecvChannel)
16 changes: 16 additions & 0 deletions Lib/test/test_interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,22 @@ def test_list_all(self):
after = set(interpreters.list_all_channels())
self.assertEqual(after, created)

def test_shareable(self):
rch, sch = interpreters.create_channel()

self.assertTrue(
interpreters.is_shareable(rch))
self.assertTrue(
interpreters.is_shareable(sch))

sch.send_nowait(rch)
sch.send_nowait(sch)
rch2 = rch.recv()
sch2 = rch.recv()

self.assertEqual(rch2, rch)
self.assertEqual(sch2, sch)


class TestRecvChannelAttrs(TestBase):

Expand Down
185 changes: 165 additions & 20 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
/* module state *************************************************************/

typedef struct {
PyTypeObject *send_channel_type;
PyTypeObject *recv_channel_type;

/* heap types */
PyTypeObject *ChannelIDType;

Expand All @@ -218,6 +221,21 @@ get_module_state(PyObject *mod)
return state;
}

static module_state *
_get_current_module_state(void)
{
PyObject *mod = _get_current_module();
if (mod == NULL) {
// XXX import it?
PyErr_SetString(PyExc_RuntimeError,
MODULE_NAME " module not imported yet");
return NULL;
}
module_state *state = get_module_state(mod);
Py_DECREF(mod);
return state;
}

static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
Expand All @@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
static int
clear_module_state(module_state *state)
{
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);

/* heap types */
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
Expand Down Expand Up @@ -1529,17 +1550,20 @@ typedef struct channelid {
struct channel_id_converter_data {
PyObject *module;
int64_t cid;
int end;
};

static int
channel_id_converter(PyObject *arg, void *ptr)
{
int64_t cid;
int end = 0;
struct channel_id_converter_data *data = ptr;
module_state *state = get_module_state(data->module);
assert(state != NULL);
if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
cid = ((channelid *)arg)->id;
end = ((channelid *)arg)->end;
}
else if (PyIndex_Check(arg)) {
cid = PyLong_AsLongLong(arg);
Expand All @@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr)
return 0;
}
data->cid = cid;
data->end = end;
return 1;
}

Expand Down Expand Up @@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
{
static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
int64_t cid;
int end;
struct channel_id_converter_data cid_data = {
.module = mod,
};
Expand All @@ -1614,21 +1640,25 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
return NULL;
}
cid = cid_data.cid;
end = cid_data.end;

// Handle "send" and "recv".
if (send == 0 && recv == 0) {
PyErr_SetString(PyExc_ValueError,
"'send' and 'recv' cannot both be False");
return NULL;
}

int end = 0;
if (send == 1) {
else if (send == 1) {
if (recv == 0 || recv == -1) {
end = CHANNEL_SEND;
}
else {
assert(recv == 1);
end = 0;
}
}
else if (recv == 1) {
assert(send == 0 || send == -1);
end = CHANNEL_RECV;
}

Expand Down Expand Up @@ -1773,21 +1803,12 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
return res;
}

static PyTypeObject * _get_current_channel_end_type(int end);

static PyObject *
_channel_from_cid(PyObject *cid, int end)
{
PyObject *highlevel = PyImport_ImportModule("interpreters");
if (highlevel == NULL) {
PyErr_Clear();
highlevel = PyImport_ImportModule("test.support.interpreters");
if (highlevel == NULL) {
return NULL;
}
}
const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" :
"SendChannel";
PyObject *cls = PyObject_GetAttrString(highlevel, clsname);
Py_DECREF(highlevel);
PyObject *cls = (PyObject *)_get_current_channel_end_type(end);
if (cls == NULL) {
return NULL;
}
Expand Down Expand Up @@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = {
};


/* SendChannel and RecvChannel classes */

// XXX Use a new __xid__ protocol instead?

static PyTypeObject *
_get_current_channel_end_type(int end)
{
module_state *state = _get_current_module_state();
if (state == NULL) {
return NULL;
}
PyTypeObject *cls;
if (end == CHANNEL_SEND) {
cls = state->send_channel_type;
}
else {
assert(end == CHANNEL_RECV);
cls = state->recv_channel_type;
}
if (cls == NULL) {
PyObject *highlevel = PyImport_ImportModule("interpreters");
if (highlevel == NULL) {
PyErr_Clear();
highlevel = PyImport_ImportModule("test.support.interpreters");
if (highlevel == NULL) {
return NULL;
}
}
if (end == CHANNEL_SEND) {
cls = state->send_channel_type;
}
else {
cls = state->recv_channel_type;
}
assert(cls != NULL);
}
return cls;
}

static PyObject *
_channel_end_from_xid(_PyCrossInterpreterData *data)
{
channelid *cid = (channelid *)_channelid_from_xid(data);
if (cid == NULL) {
return NULL;
}
PyTypeObject *cls = _get_current_channel_end_type(cid->end);
if (cls == NULL) {
return NULL;
}
PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid);
Py_DECREF(cid);
return obj;
}

static int
_channel_end_shared(PyThreadState *tstate, PyObject *obj,
_PyCrossInterpreterData *data)
{
PyObject *cidobj = PyObject_GetAttrString(obj, "_id");
if (cidobj == NULL) {
return -1;
}
if (_channelid_shared(tstate, cidobj, data) < 0) {
return -1;
}
data->new_object = _channel_end_from_xid;
return 0;
}

static int
set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
{
module_state *state = get_module_state(mod);
if (state == NULL) {
return -1;
}

if (state->send_channel_type != NULL
|| state->recv_channel_type != NULL)
{
PyErr_SetString(PyExc_TypeError, "already registered");
return -1;
}
state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);

if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) {
return -1;
}
if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) {
return -1;
}

return 0;
}

/* module level code ********************************************************/

/* globals is the process-global state for the module. It holds all
Expand Down Expand Up @@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
PyTypeObject *cls = state->ChannelIDType;
PyObject *mod = get_module_from_owned_type(cls);
if (mod == NULL) {
assert(get_module_from_owned_type(cls) == self);

return _channelid_new(self, cls, args, kwds);
}

static PyObject *
channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"send", "recv", NULL};
PyObject *send;
PyObject *recv;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
"OO:_register_end_types", kwlist,
&send, &recv)) {
return NULL;
}
PyObject *cid = _channelid_new(mod, cls, args, kwds);
Py_DECREF(mod);
return cid;
if (!PyType_Check(send)) {
PyErr_SetString(PyExc_TypeError, "expected a type for 'send'");
return NULL;
}
if (!PyType_Check(recv)) {
PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'");
return NULL;
}
PyTypeObject *cls_send = (PyTypeObject *)send;
PyTypeObject *cls_recv = (PyTypeObject *)recv;

if (set_channel_end_types(self, cls_send, cls_recv) < 0) {
return NULL;
}

Py_RETURN_NONE;
}

static PyMethodDef module_functions[] = {
Expand All @@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = {
METH_VARARGS | METH_KEYWORDS, channel_release_doc},
{"_channel_id", _PyCFunction_CAST(channel__channel_id),
METH_VARARGS | METH_KEYWORDS, NULL},
{"_register_end_types", _PyCFunction_CAST(channel__register_end_types),
METH_VARARGS | METH_KEYWORDS, NULL},

{NULL, NULL} /* sentinel */
};
Expand Down