Skip to content

Commit b4edf64

Browse files
Multi-phase init
1 parent 13d14da commit b4edf64

File tree

1 file changed

+99
-50
lines changed

1 file changed

+99
-50
lines changed

Modules/socketmodule.c

+99-50
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ Local naming conventions:
108108
#define PY_SSIZE_T_CLEAN
109109
#include "Python.h"
110110
#include "pycore_fileutils.h" // _Py_set_inheritable()
111+
#include "pycore_moduleobject.h" // _PyModule_GetState
111112
#include "structmember.h" // PyMemberDef
112113

113114
#ifdef _Py_MEMORY_SANITIZER
@@ -569,11 +570,25 @@ typedef struct _socket_state {
569570
#endif
570571
} socket_state;
571572

572-
static socket_state global_state;
573+
static inline socket_state *
574+
get_module_state(PyObject *mod)
575+
{
576+
void *state = _PyModule_GetState(mod);
577+
assert(state != NULL);
578+
return (socket_state *)state;
579+
}
580+
581+
static struct PyModuleDef socketmodule;
573582

574-
#define GLOBAL_STATE() (&global_state)
583+
static inline socket_state *
584+
find_module_state_by_def(PyTypeObject *type)
585+
{
586+
PyObject *mod = PyType_GetModuleByDef(type, &socketmodule);
587+
assert(mod != NULL);
588+
return get_module_state(mod);
589+
}
575590

576-
#define clinic_state() GLOBAL_STATE()
591+
#define clinic_state() (find_module_state_by_def(type))
577592
#include "clinic/socketmodule.c.h"
578593
#undef clinic_state
579594

@@ -5334,7 +5349,7 @@ sock_initobj_impl(PySocketSockObject *self, int family, int type, int proto,
53345349
{
53355350

53365351
SOCKET_T fd = INVALID_SOCKET;
5337-
socket_state *state = GLOBAL_STATE();
5352+
socket_state *state = find_module_state_by_def(Py_TYPE(self));
53385353

53395354
#ifndef MS_WINDOWS
53405355
#ifdef SOCK_CLOEXEC
@@ -5683,7 +5698,7 @@ socket_gethostbyname(PyObject *self, PyObject *args)
56835698
if (PySys_Audit("socket.gethostbyname", "O", args) < 0) {
56845699
goto finally;
56855700
}
5686-
socket_state *state = GLOBAL_STATE();
5701+
socket_state *state = get_module_state(self);
56875702
int rc = setipaddr(state, name, (struct sockaddr *)&addrbuf,
56885703
sizeof(addrbuf), AF_INET);
56895704
if (rc < 0) {
@@ -5878,7 +5893,7 @@ socket_gethostbyname_ex(PyObject *self, PyObject *args)
58785893
if (PySys_Audit("socket.gethostbyname", "O", args) < 0) {
58795894
goto finally;
58805895
}
5881-
socket_state *state = GLOBAL_STATE();
5896+
socket_state *state = get_module_state(self);
58825897
if (setipaddr(state, name, SAS2SA(&addr), sizeof(addr), AF_INET) < 0) {
58835898
goto finally;
58845899
}
@@ -5963,7 +5978,7 @@ socket_gethostbyaddr(PyObject *self, PyObject *args)
59635978
goto finally;
59645979
}
59655980
af = AF_UNSPEC;
5966-
socket_state *state = GLOBAL_STATE();
5981+
socket_state *state = get_module_state(self);
59675982
if (setipaddr(state, ip_num, sa, sizeof(addr), af) < 0) {
59685983
goto finally;
59695984
}
@@ -6226,7 +6241,7 @@ socket_socketpair(PyObject *self, PyObject *args)
62266241
SOCKET_T sv[2];
62276242
int family, type = SOCK_STREAM, proto = 0;
62286243
PyObject *res = NULL;
6229-
socket_state *state = GLOBAL_STATE();
6244+
socket_state *state = get_module_state(self);
62306245
#ifdef SOCK_CLOEXEC
62316246
int *atomic_flag_works = &(state->sock_cloexec_works);
62326247
#else
@@ -6732,7 +6747,7 @@ socket_getaddrinfo(PyObject *self, PyObject *args, PyObject* kwargs)
67326747
Py_END_ALLOW_THREADS
67336748
if (error) {
67346749
res0 = NULL; // gh-100795
6735-
socket_state *state = GLOBAL_STATE();
6750+
socket_state *state = get_module_state(self);
67366751
set_gaierror(state, error);
67376752
goto err;
67386753
}
@@ -6832,7 +6847,7 @@ socket_getnameinfo(PyObject *self, PyObject *args)
68326847
Py_END_ALLOW_THREADS
68336848
if (error) {
68346849
res = NULL; // gh-100795
6835-
socket_state *state = GLOBAL_STATE();
6850+
socket_state *state = get_module_state(self);
68366851
set_gaierror(state, error);
68376852
goto fail;
68386853
}
@@ -6865,7 +6880,7 @@ socket_getnameinfo(PyObject *self, PyObject *args)
68656880
error = getnameinfo(res->ai_addr, (socklen_t) res->ai_addrlen,
68666881
hbuf, sizeof(hbuf), pbuf, sizeof(pbuf), flags);
68676882
if (error) {
6868-
socket_state *state = GLOBAL_STATE();
6883+
socket_state *state = get_module_state(self);
68696884
set_gaierror(state, error);
68706885
goto fail;
68716886
}
@@ -6892,7 +6907,7 @@ Get host and port for a sockaddr.");
68926907
static PyObject *
68936908
socket_getdefaulttimeout(PyObject *self, PyObject *Py_UNUSED(ignored))
68946909
{
6895-
socket_state *state = GLOBAL_STATE();
6910+
socket_state *state = get_module_state(self);
68966911
if (state->defaulttimeout < 0) {
68976912
Py_RETURN_NONE;
68986913
}
@@ -6917,7 +6932,7 @@ socket_setdefaulttimeout(PyObject *self, PyObject *arg)
69176932
if (socket_parse_timeout(&timeout, arg) < 0)
69186933
return NULL;
69196934

6920-
socket_state *state = GLOBAL_STATE();
6935+
socket_state *state = get_module_state(self);
69216936
state->defaulttimeout = timeout;
69226937

69236938
Py_RETURN_NONE;
@@ -7334,27 +7349,16 @@ PyDoc_STRVAR(socket_doc,
73347349
\n\
73357350
See the socket module for documentation.");
73367351

7337-
static struct PyModuleDef socketmodule = {
7338-
.m_base = PyModuleDef_HEAD_INIT,
7339-
.m_name = PySocket_MODULE_NAME,
7340-
.m_doc = socket_doc,
7341-
.m_size = sizeof(socket_state),
7342-
.m_methods = socket_methods,
7343-
};
7344-
7345-
PyMODINIT_FUNC
7346-
PyInit__socket(void)
7352+
static int
7353+
socket_exec(PyObject *m)
73477354
{
7348-
PyObject *m, *has_ipv6;
7349-
7350-
if (!os_init())
7351-
return NULL;
7355+
PyObject *has_ipv6;
73527356

7353-
m = PyModule_Create(&socketmodule);
7354-
if (m == NULL)
7355-
return NULL;
7357+
if (!os_init()) {
7358+
return -1;
7359+
}
73567360

7357-
socket_state *state = GLOBAL_STATE();
7361+
socket_state *state = get_module_state(m);
73587362
state->defaulttimeout = _PYTIME_FROMSECONDS(-1);
73597363

73607364
#if defined(HAVE_ACCEPT) || defined(HAVE_ACCEPT4)
@@ -7371,36 +7375,36 @@ PyInit__socket(void)
73717375
state->socket_herror = PyErr_NewException("socket.herror",
73727376
PyExc_OSError, NULL);
73737377
if (state->socket_herror == NULL) {
7374-
return NULL;
7378+
return -1;
73757379
}
73767380
if (PyModule_AddObjectRef(m, "error", PyExc_OSError) < 0) {
7377-
return NULL;
7381+
return -1;
73787382
}
73797383
if (PyModule_AddObjectRef(m, "herror", state->socket_herror) < 0) {
7380-
return NULL;
7384+
return -1;
73817385
}
73827386
state->socket_gaierror = PyErr_NewException("socket.gaierror",
73837387
PyExc_OSError, NULL);
73847388
if (state->socket_gaierror == NULL) {
7385-
return NULL;
7389+
return -1;
73867390
}
73877391
if (PyModule_AddObjectRef(m, "gaierror", state->socket_gaierror) < 0) {
7388-
return NULL;
7392+
return -1;
73897393
}
73907394
if (PyModule_AddObjectRef(m, "timeout", PyExc_TimeoutError) < 0) {
7391-
return NULL;
7395+
return -1;
73927396
}
73937397

73947398
PyObject *sock_type = PyType_FromMetaclass(NULL, m, &sock_spec, NULL);
73957399
if (sock_type == NULL) {
7396-
return NULL;
7400+
return -1;
73977401
}
73987402
state->sock_type = (PyTypeObject *)sock_type;
73997403
if (PyModule_AddObjectRef(m, "SocketType", sock_type) < 0) {
7400-
return NULL;
7404+
return -1;
74017405
}
74027406
if (PyModule_AddObjectRef(m, "socket", sock_type) < 0) {
7403-
return NULL;
7407+
return -1;
74047408
}
74057409

74067410
#ifdef ENABLE_IPV6
@@ -7413,21 +7417,18 @@ PyInit__socket(void)
74137417
/* Export C API */
74147418
PySocketModule_APIObject *capi = sock_get_api(state);
74157419
if (capi == NULL) {
7416-
Py_DECREF(m);
7417-
return NULL;
7420+
return -1;
74187421
}
74197422
PyObject *capsule = PyCapsule_New(capi,
74207423
PySocket_CAPSULE_NAME,
74217424
sock_destroy_api);
74227425
if (capsule == NULL) {
74237426
sock_free_api(capi);
7424-
Py_DECREF(m);
7425-
return NULL;
7427+
return -1;
74267428
}
74277429
if (PyModule_AddObject(m, PySocket_CAPI_NAME, capsule) < 0) {
74287430
Py_DECREF(capsule);
7429-
Py_DECREF(m);
7430-
return NULL;
7431+
return -1;
74317432
}
74327433

74337434
/* Address families (we only support AF_INET and AF_UNIX) */
@@ -8782,7 +8783,7 @@ PyInit__socket(void)
87828783
PyObject *tmp;
87838784
tmp = PyLong_FromUnsignedLong(codes[i]);
87848785
if (tmp == NULL)
8785-
return NULL;
8786+
return -1;
87868787
PyModule_AddObject(m, names[i], tmp);
87878788
}
87888789
}
@@ -8805,10 +8806,58 @@ PyInit__socket(void)
88058806
#ifdef MS_WINDOWS
88068807
/* remove some flags on older version Windows during run-time */
88078808
if (remove_unusable_flags(m) < 0) {
8808-
Py_DECREF(m);
8809-
return NULL;
8809+
return -1;
88108810
}
88118811
#endif
88128812

8813-
return m;
8813+
return 0;
8814+
}
8815+
8816+
static struct PyModuleDef_Slot socket_slots[] = {
8817+
{Py_mod_exec, socket_exec},
8818+
{0, NULL},
8819+
};
8820+
8821+
static int
8822+
socket_traverse(PyObject *mod, visitproc visit, void *arg)
8823+
{
8824+
socket_state *state = get_module_state(mod);
8825+
Py_VISIT(state->sock_type);
8826+
Py_VISIT(state->socket_herror);
8827+
Py_VISIT(state->socket_gaierror);
8828+
return 0;
8829+
}
8830+
8831+
static int
8832+
socket_clear(PyObject *mod)
8833+
{
8834+
socket_state *state = get_module_state(mod);
8835+
Py_CLEAR(state->sock_type);
8836+
Py_CLEAR(state->socket_herror);
8837+
Py_CLEAR(state->socket_gaierror);
8838+
return 0;
8839+
}
8840+
8841+
static void
8842+
socket_free(void *mod)
8843+
{
8844+
(void)socket_clear((PyObject *)mod);
8845+
}
8846+
8847+
static struct PyModuleDef socketmodule = {
8848+
.m_base = PyModuleDef_HEAD_INIT,
8849+
.m_name = PySocket_MODULE_NAME,
8850+
.m_doc = socket_doc,
8851+
.m_size = sizeof(socket_state),
8852+
.m_methods = socket_methods,
8853+
.m_slots = socket_slots,
8854+
.m_traverse = socket_traverse,
8855+
.m_clear = socket_clear,
8856+
.m_free = socket_free,
8857+
};
8858+
8859+
PyMODINIT_FUNC
8860+
PyInit__socket(void)
8861+
{
8862+
return PyModuleDef_Init(&socketmodule);
88148863
}

0 commit comments

Comments
 (0)