@@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
198
198
/* module state *************************************************************/
199
199
200
200
typedef struct {
201
+ PyTypeObject * send_channel_type ;
202
+ PyTypeObject * recv_channel_type ;
203
+
201
204
/* heap types */
202
205
PyTypeObject * ChannelIDType ;
203
206
@@ -218,6 +221,21 @@ get_module_state(PyObject *mod)
218
221
return state ;
219
222
}
220
223
224
+ static module_state *
225
+ _get_current_module_state (void )
226
+ {
227
+ PyObject * mod = _get_current_module ();
228
+ if (mod == NULL ) {
229
+ // XXX import it?
230
+ PyErr_SetString (PyExc_RuntimeError ,
231
+ MODULE_NAME " module not imported yet" );
232
+ return NULL ;
233
+ }
234
+ module_state * state = get_module_state (mod );
235
+ Py_DECREF (mod );
236
+ return state ;
237
+ }
238
+
221
239
static int
222
240
traverse_module_state (module_state * state , visitproc visit , void * arg )
223
241
{
@@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
237
255
static int
238
256
clear_module_state (module_state * state )
239
257
{
258
+ Py_CLEAR (state -> send_channel_type );
259
+ Py_CLEAR (state -> recv_channel_type );
260
+
240
261
/* heap types */
241
262
if (state -> ChannelIDType != NULL ) {
242
263
(void )_PyCrossInterpreterData_UnregisterClass (state -> ChannelIDType );
@@ -1529,17 +1550,20 @@ typedef struct channelid {
1529
1550
struct channel_id_converter_data {
1530
1551
PyObject * module ;
1531
1552
int64_t cid ;
1553
+ int end ;
1532
1554
};
1533
1555
1534
1556
static int
1535
1557
channel_id_converter (PyObject * arg , void * ptr )
1536
1558
{
1537
1559
int64_t cid ;
1560
+ int end = 0 ;
1538
1561
struct channel_id_converter_data * data = ptr ;
1539
1562
module_state * state = get_module_state (data -> module );
1540
1563
assert (state != NULL );
1541
1564
if (PyObject_TypeCheck (arg , state -> ChannelIDType )) {
1542
1565
cid = ((channelid * )arg )-> id ;
1566
+ end = ((channelid * )arg )-> end ;
1543
1567
}
1544
1568
else if (PyIndex_Check (arg )) {
1545
1569
cid = PyLong_AsLongLong (arg );
@@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr)
1559
1583
return 0 ;
1560
1584
}
1561
1585
data -> cid = cid ;
1586
+ data -> end = end ;
1562
1587
return 1 ;
1563
1588
}
1564
1589
@@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
1600
1625
{
1601
1626
static char * kwlist [] = {"id" , "send" , "recv" , "force" , "_resolve" , NULL };
1602
1627
int64_t cid ;
1628
+ int end ;
1603
1629
struct channel_id_converter_data cid_data = {
1604
1630
.module = mod ,
1605
1631
};
@@ -1614,21 +1640,25 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
1614
1640
return NULL ;
1615
1641
}
1616
1642
cid = cid_data .cid ;
1643
+ end = cid_data .end ;
1617
1644
1618
1645
// Handle "send" and "recv".
1619
1646
if (send == 0 && recv == 0 ) {
1620
1647
PyErr_SetString (PyExc_ValueError ,
1621
1648
"'send' and 'recv' cannot both be False" );
1622
1649
return NULL ;
1623
1650
}
1624
-
1625
- int end = 0 ;
1626
- if (send == 1 ) {
1651
+ else if (send == 1 ) {
1627
1652
if (recv == 0 || recv == -1 ) {
1628
1653
end = CHANNEL_SEND ;
1629
1654
}
1655
+ else {
1656
+ assert (recv == 1 );
1657
+ end = 0 ;
1658
+ }
1630
1659
}
1631
1660
else if (recv == 1 ) {
1661
+ assert (send == 0 || send == -1 );
1632
1662
end = CHANNEL_RECV ;
1633
1663
}
1634
1664
@@ -1773,21 +1803,12 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
1773
1803
return res ;
1774
1804
}
1775
1805
1806
+ static PyTypeObject * _get_current_channel_end_type (int end );
1807
+
1776
1808
static PyObject *
1777
1809
_channel_from_cid (PyObject * cid , int end )
1778
1810
{
1779
- PyObject * highlevel = PyImport_ImportModule ("interpreters" );
1780
- if (highlevel == NULL ) {
1781
- PyErr_Clear ();
1782
- highlevel = PyImport_ImportModule ("test.support.interpreters" );
1783
- if (highlevel == NULL ) {
1784
- return NULL ;
1785
- }
1786
- }
1787
- const char * clsname = (end == CHANNEL_RECV ) ? "RecvChannel" :
1788
- "SendChannel" ;
1789
- PyObject * cls = PyObject_GetAttrString (highlevel , clsname );
1790
- Py_DECREF (highlevel );
1811
+ PyObject * cls = (PyObject * )_get_current_channel_end_type (end );
1791
1812
if (cls == NULL ) {
1792
1813
return NULL ;
1793
1814
}
@@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = {
1943
1964
};
1944
1965
1945
1966
1967
+ /* SendChannel and RecvChannel classes */
1968
+
1969
+ // XXX Use a new __xid__ protocol instead?
1970
+
1971
+ static PyTypeObject *
1972
+ _get_current_channel_end_type (int end )
1973
+ {
1974
+ module_state * state = _get_current_module_state ();
1975
+ if (state == NULL ) {
1976
+ return NULL ;
1977
+ }
1978
+ PyTypeObject * cls ;
1979
+ if (end == CHANNEL_SEND ) {
1980
+ cls = state -> send_channel_type ;
1981
+ }
1982
+ else {
1983
+ assert (end == CHANNEL_RECV );
1984
+ cls = state -> recv_channel_type ;
1985
+ }
1986
+ if (cls == NULL ) {
1987
+ PyObject * highlevel = PyImport_ImportModule ("interpreters" );
1988
+ if (highlevel == NULL ) {
1989
+ PyErr_Clear ();
1990
+ highlevel = PyImport_ImportModule ("test.support.interpreters" );
1991
+ if (highlevel == NULL ) {
1992
+ return NULL ;
1993
+ }
1994
+ }
1995
+ if (end == CHANNEL_SEND ) {
1996
+ cls = state -> send_channel_type ;
1997
+ }
1998
+ else {
1999
+ cls = state -> recv_channel_type ;
2000
+ }
2001
+ assert (cls != NULL );
2002
+ }
2003
+ return cls ;
2004
+ }
2005
+
2006
+ static PyObject *
2007
+ _channel_end_from_xid (_PyCrossInterpreterData * data )
2008
+ {
2009
+ channelid * cid = (channelid * )_channelid_from_xid (data );
2010
+ if (cid == NULL ) {
2011
+ return NULL ;
2012
+ }
2013
+ PyTypeObject * cls = _get_current_channel_end_type (cid -> end );
2014
+ if (cls == NULL ) {
2015
+ return NULL ;
2016
+ }
2017
+ PyObject * obj = PyObject_CallOneArg ((PyObject * )cls , (PyObject * )cid );
2018
+ Py_DECREF (cid );
2019
+ return obj ;
2020
+ }
2021
+
2022
+ static int
2023
+ _channel_end_shared (PyThreadState * tstate , PyObject * obj ,
2024
+ _PyCrossInterpreterData * data )
2025
+ {
2026
+ PyObject * cidobj = PyObject_GetAttrString (obj , "_id" );
2027
+ if (cidobj == NULL ) {
2028
+ return -1 ;
2029
+ }
2030
+ if (_channelid_shared (tstate , cidobj , data ) < 0 ) {
2031
+ return -1 ;
2032
+ }
2033
+ data -> new_object = _channel_end_from_xid ;
2034
+ return 0 ;
2035
+ }
2036
+
2037
+ static int
2038
+ set_channel_end_types (PyObject * mod , PyTypeObject * send , PyTypeObject * recv )
2039
+ {
2040
+ module_state * state = get_module_state (mod );
2041
+ if (state == NULL ) {
2042
+ return -1 ;
2043
+ }
2044
+
2045
+ if (state -> send_channel_type != NULL
2046
+ || state -> recv_channel_type != NULL )
2047
+ {
2048
+ PyErr_SetString (PyExc_TypeError , "already registered" );
2049
+ return -1 ;
2050
+ }
2051
+ state -> send_channel_type = (PyTypeObject * )Py_NewRef (send );
2052
+ state -> recv_channel_type = (PyTypeObject * )Py_NewRef (recv );
2053
+
2054
+ if (_PyCrossInterpreterData_RegisterClass (send , _channel_end_shared )) {
2055
+ return -1 ;
2056
+ }
2057
+ if (_PyCrossInterpreterData_RegisterClass (recv , _channel_end_shared )) {
2058
+ return -1 ;
2059
+ }
2060
+
2061
+ return 0 ;
2062
+ }
2063
+
1946
2064
/* module level code ********************************************************/
1947
2065
1948
2066
/* globals is the process-global state for the module. It holds all
@@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
2346
2464
return NULL ;
2347
2465
}
2348
2466
PyTypeObject * cls = state -> ChannelIDType ;
2349
- PyObject * mod = get_module_from_owned_type (cls );
2350
- if (mod == NULL ) {
2467
+ assert (get_module_from_owned_type (cls ) == self );
2468
+
2469
+ return _channelid_new (self , cls , args , kwds );
2470
+ }
2471
+
2472
+ static PyObject *
2473
+ channel__register_end_types (PyObject * self , PyObject * args , PyObject * kwds )
2474
+ {
2475
+ static char * kwlist [] = {"send" , "recv" , NULL };
2476
+ PyObject * send ;
2477
+ PyObject * recv ;
2478
+ if (!PyArg_ParseTupleAndKeywords (args , kwds ,
2479
+ "OO:_register_end_types" , kwlist ,
2480
+ & send , & recv )) {
2351
2481
return NULL ;
2352
2482
}
2353
- PyObject * cid = _channelid_new (mod , cls , args , kwds );
2354
- Py_DECREF (mod );
2355
- return cid ;
2483
+ if (!PyType_Check (send )) {
2484
+ PyErr_SetString (PyExc_TypeError , "expected a type for 'send'" );
2485
+ return NULL ;
2486
+ }
2487
+ if (!PyType_Check (recv )) {
2488
+ PyErr_SetString (PyExc_TypeError , "expected a type for 'recv'" );
2489
+ return NULL ;
2490
+ }
2491
+ PyTypeObject * cls_send = (PyTypeObject * )send ;
2492
+ PyTypeObject * cls_recv = (PyTypeObject * )recv ;
2493
+
2494
+ if (set_channel_end_types (self , cls_send , cls_recv ) < 0 ) {
2495
+ return NULL ;
2496
+ }
2497
+
2498
+ Py_RETURN_NONE ;
2356
2499
}
2357
2500
2358
2501
static PyMethodDef module_functions [] = {
@@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = {
2374
2517
METH_VARARGS | METH_KEYWORDS , channel_release_doc },
2375
2518
{"_channel_id" , _PyCFunction_CAST (channel__channel_id ),
2376
2519
METH_VARARGS | METH_KEYWORDS , NULL },
2520
+ {"_register_end_types" , _PyCFunction_CAST (channel__register_end_types ),
2521
+ METH_VARARGS | METH_KEYWORDS , NULL },
2377
2522
2378
2523
{NULL , NULL } /* sentinel */
2379
2524
};
0 commit comments