@@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
198198/* module state *************************************************************/
199199
200200typedef struct {
201+ PyTypeObject * send_channel_type ;
202+ PyTypeObject * recv_channel_type ;
203+
201204 /* heap types */
202205 PyTypeObject * ChannelIDType ;
203206
@@ -218,6 +221,21 @@ get_module_state(PyObject *mod)
218221 return state ;
219222}
220223
224+ static module_state *
225+ _get_current_module_state (void )
226+ {
227+ PyObject * mod = _get_current_module ();
228+ if (mod == NULL ) {
229+ // XXX import it?
230+ PyErr_SetString (PyExc_RuntimeError ,
231+ MODULE_NAME " module not imported yet" );
232+ return NULL ;
233+ }
234+ module_state * state = get_module_state (mod );
235+ Py_DECREF (mod );
236+ return state ;
237+ }
238+
221239static int
222240traverse_module_state (module_state * state , visitproc visit , void * arg )
223241{
@@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
237255static int
238256clear_module_state (module_state * state )
239257{
258+ Py_CLEAR (state -> send_channel_type );
259+ Py_CLEAR (state -> recv_channel_type );
260+
240261 /* heap types */
241262 if (state -> ChannelIDType != NULL ) {
242263 (void )_PyCrossInterpreterData_UnregisterClass (state -> ChannelIDType );
@@ -1529,17 +1550,20 @@ typedef struct channelid {
15291550struct channel_id_converter_data {
15301551 PyObject * module ;
15311552 int64_t cid ;
1553+ int end ;
15321554};
15331555
15341556static int
15351557channel_id_converter (PyObject * arg , void * ptr )
15361558{
15371559 int64_t cid ;
1560+ int end = 0 ;
15381561 struct channel_id_converter_data * data = ptr ;
15391562 module_state * state = get_module_state (data -> module );
15401563 assert (state != NULL );
15411564 if (PyObject_TypeCheck (arg , state -> ChannelIDType )) {
15421565 cid = ((channelid * )arg )-> id ;
1566+ end = ((channelid * )arg )-> end ;
15431567 }
15441568 else if (PyIndex_Check (arg )) {
15451569 cid = PyLong_AsLongLong (arg );
@@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr)
15591583 return 0 ;
15601584 }
15611585 data -> cid = cid ;
1586+ data -> end = end ;
15621587 return 1 ;
15631588}
15641589
@@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
16001625{
16011626 static char * kwlist [] = {"id" , "send" , "recv" , "force" , "_resolve" , NULL };
16021627 int64_t cid ;
1628+ int end ;
16031629 struct channel_id_converter_data cid_data = {
16041630 .module = mod ,
16051631 };
@@ -1614,21 +1640,25 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
16141640 return NULL ;
16151641 }
16161642 cid = cid_data .cid ;
1643+ end = cid_data .end ;
16171644
16181645 // Handle "send" and "recv".
16191646 if (send == 0 && recv == 0 ) {
16201647 PyErr_SetString (PyExc_ValueError ,
16211648 "'send' and 'recv' cannot both be False" );
16221649 return NULL ;
16231650 }
1624-
1625- int end = 0 ;
1626- if (send == 1 ) {
1651+ else if (send == 1 ) {
16271652 if (recv == 0 || recv == -1 ) {
16281653 end = CHANNEL_SEND ;
16291654 }
1655+ else {
1656+ assert (recv == 1 );
1657+ end = 0 ;
1658+ }
16301659 }
16311660 else if (recv == 1 ) {
1661+ assert (send == 0 || send == -1 );
16321662 end = CHANNEL_RECV ;
16331663 }
16341664
@@ -1773,21 +1803,12 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
17731803 return res ;
17741804}
17751805
1806+ static PyTypeObject * _get_current_channel_end_type (int end );
1807+
17761808static PyObject *
17771809_channel_from_cid (PyObject * cid , int end )
17781810{
1779- PyObject * highlevel = PyImport_ImportModule ("interpreters" );
1780- if (highlevel == NULL ) {
1781- PyErr_Clear ();
1782- highlevel = PyImport_ImportModule ("test.support.interpreters" );
1783- if (highlevel == NULL ) {
1784- return NULL ;
1785- }
1786- }
1787- const char * clsname = (end == CHANNEL_RECV ) ? "RecvChannel" :
1788- "SendChannel" ;
1789- PyObject * cls = PyObject_GetAttrString (highlevel , clsname );
1790- Py_DECREF (highlevel );
1811+ PyObject * cls = (PyObject * )_get_current_channel_end_type (end );
17911812 if (cls == NULL ) {
17921813 return NULL ;
17931814 }
@@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = {
19431964};
19441965
19451966
1967+ /* SendChannel and RecvChannel classes */
1968+
1969+ // XXX Use a new __xid__ protocol instead?
1970+
1971+ static PyTypeObject *
1972+ _get_current_channel_end_type (int end )
1973+ {
1974+ module_state * state = _get_current_module_state ();
1975+ if (state == NULL ) {
1976+ return NULL ;
1977+ }
1978+ PyTypeObject * cls ;
1979+ if (end == CHANNEL_SEND ) {
1980+ cls = state -> send_channel_type ;
1981+ }
1982+ else {
1983+ assert (end == CHANNEL_RECV );
1984+ cls = state -> recv_channel_type ;
1985+ }
1986+ if (cls == NULL ) {
1987+ PyObject * highlevel = PyImport_ImportModule ("interpreters" );
1988+ if (highlevel == NULL ) {
1989+ PyErr_Clear ();
1990+ highlevel = PyImport_ImportModule ("test.support.interpreters" );
1991+ if (highlevel == NULL ) {
1992+ return NULL ;
1993+ }
1994+ }
1995+ if (end == CHANNEL_SEND ) {
1996+ cls = state -> send_channel_type ;
1997+ }
1998+ else {
1999+ cls = state -> recv_channel_type ;
2000+ }
2001+ assert (cls != NULL );
2002+ }
2003+ return cls ;
2004+ }
2005+
2006+ static PyObject *
2007+ _channel_end_from_xid (_PyCrossInterpreterData * data )
2008+ {
2009+ channelid * cid = (channelid * )_channelid_from_xid (data );
2010+ if (cid == NULL ) {
2011+ return NULL ;
2012+ }
2013+ PyTypeObject * cls = _get_current_channel_end_type (cid -> end );
2014+ if (cls == NULL ) {
2015+ return NULL ;
2016+ }
2017+ PyObject * obj = PyObject_CallOneArg ((PyObject * )cls , (PyObject * )cid );
2018+ Py_DECREF (cid );
2019+ return obj ;
2020+ }
2021+
2022+ static int
2023+ _channel_end_shared (PyThreadState * tstate , PyObject * obj ,
2024+ _PyCrossInterpreterData * data )
2025+ {
2026+ PyObject * cidobj = PyObject_GetAttrString (obj , "_id" );
2027+ if (cidobj == NULL ) {
2028+ return -1 ;
2029+ }
2030+ if (_channelid_shared (tstate , cidobj , data ) < 0 ) {
2031+ return -1 ;
2032+ }
2033+ data -> new_object = _channel_end_from_xid ;
2034+ return 0 ;
2035+ }
2036+
2037+ static int
2038+ set_channel_end_types (PyObject * mod , PyTypeObject * send , PyTypeObject * recv )
2039+ {
2040+ module_state * state = get_module_state (mod );
2041+ if (state == NULL ) {
2042+ return -1 ;
2043+ }
2044+
2045+ if (state -> send_channel_type != NULL
2046+ || state -> recv_channel_type != NULL )
2047+ {
2048+ PyErr_SetString (PyExc_TypeError , "already registered" );
2049+ return -1 ;
2050+ }
2051+ state -> send_channel_type = (PyTypeObject * )Py_NewRef (send );
2052+ state -> recv_channel_type = (PyTypeObject * )Py_NewRef (recv );
2053+
2054+ if (_PyCrossInterpreterData_RegisterClass (send , _channel_end_shared )) {
2055+ return -1 ;
2056+ }
2057+ if (_PyCrossInterpreterData_RegisterClass (recv , _channel_end_shared )) {
2058+ return -1 ;
2059+ }
2060+
2061+ return 0 ;
2062+ }
2063+
19462064/* module level code ********************************************************/
19472065
19482066/* globals is the process-global state for the module. It holds all
@@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
23462464 return NULL ;
23472465 }
23482466 PyTypeObject * cls = state -> ChannelIDType ;
2349- PyObject * mod = get_module_from_owned_type (cls );
2350- if (mod == NULL ) {
2467+ assert (get_module_from_owned_type (cls ) == self );
2468+
2469+ return _channelid_new (self , cls , args , kwds );
2470+ }
2471+
2472+ static PyObject *
2473+ channel__register_end_types (PyObject * self , PyObject * args , PyObject * kwds )
2474+ {
2475+ static char * kwlist [] = {"send" , "recv" , NULL };
2476+ PyObject * send ;
2477+ PyObject * recv ;
2478+ if (!PyArg_ParseTupleAndKeywords (args , kwds ,
2479+ "OO:_register_end_types" , kwlist ,
2480+ & send , & recv )) {
23512481 return NULL ;
23522482 }
2353- PyObject * cid = _channelid_new (mod , cls , args , kwds );
2354- Py_DECREF (mod );
2355- return cid ;
2483+ if (!PyType_Check (send )) {
2484+ PyErr_SetString (PyExc_TypeError , "expected a type for 'send'" );
2485+ return NULL ;
2486+ }
2487+ if (!PyType_Check (recv )) {
2488+ PyErr_SetString (PyExc_TypeError , "expected a type for 'recv'" );
2489+ return NULL ;
2490+ }
2491+ PyTypeObject * cls_send = (PyTypeObject * )send ;
2492+ PyTypeObject * cls_recv = (PyTypeObject * )recv ;
2493+
2494+ if (set_channel_end_types (self , cls_send , cls_recv ) < 0 ) {
2495+ return NULL ;
2496+ }
2497+
2498+ Py_RETURN_NONE ;
23562499}
23572500
23582501static PyMethodDef module_functions [] = {
@@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = {
23742517 METH_VARARGS | METH_KEYWORDS , channel_release_doc },
23752518 {"_channel_id" , _PyCFunction_CAST (channel__channel_id ),
23762519 METH_VARARGS | METH_KEYWORDS , NULL },
2520+ {"_register_end_types" , _PyCFunction_CAST (channel__register_end_types ),
2521+ METH_VARARGS | METH_KEYWORDS , NULL },
23772522
23782523 {NULL , NULL } /* sentinel */
23792524};
0 commit comments