Skip to content

Commit c1712ef

Browse files
gh-116664: Make module state Py_SETREF's in _warnings thread-safe (#116959)
Mark the swap operations as critical sections. Add an internal Py_BEGIN_CRITICAL_SECTION_MUT API that takes a PyMutex pointer instead of a PyObject pointer.
1 parent 9a388b9 commit c1712ef

File tree

3 files changed

+44
-23
lines changed

3 files changed

+44
-23
lines changed

Include/internal/pycore_critical_section.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,13 @@ extern "C" {
8787
#define _Py_CRITICAL_SECTION_MASK 0x3
8888

8989
#ifdef Py_GIL_DISABLED
90-
# define Py_BEGIN_CRITICAL_SECTION(op) \
90+
# define Py_BEGIN_CRITICAL_SECTION_MUT(mutex) \
9191
{ \
9292
_PyCriticalSection _cs; \
93-
_PyCriticalSection_Begin(&_cs, &_PyObject_CAST(op)->ob_mutex)
93+
_PyCriticalSection_Begin(&_cs, mutex)
94+
95+
# define Py_BEGIN_CRITICAL_SECTION(op) \
96+
Py_BEGIN_CRITICAL_SECTION_MUT(&_PyObject_CAST(op)->ob_mutex)
9497

9598
# define Py_END_CRITICAL_SECTION() \
9699
_PyCriticalSection_End(&_cs); \
@@ -138,6 +141,7 @@ extern "C" {
138141

139142
#else /* !Py_GIL_DISABLED */
140143
// The critical section APIs are no-ops with the GIL.
144+
# define Py_BEGIN_CRITICAL_SECTION_MUT(mut)
141145
# define Py_BEGIN_CRITICAL_SECTION(op)
142146
# define Py_END_CRITICAL_SECTION()
143147
# define Py_XBEGIN_CRITICAL_SECTION(op)

Include/internal/pycore_warnings.h

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ struct _warnings_runtime_state {
1414
PyObject *filters; /* List */
1515
PyObject *once_registry; /* Dict */
1616
PyObject *default_action; /* String */
17+
struct _PyMutex mutex;
1718
long filters_version;
1819
};
1920

Python/_warnings.c

+37-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "Python.h"
2+
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION_MUT()
23
#include "pycore_interp.h" // PyInterpreterState.warnings
34
#include "pycore_long.h" // _PyLong_GetZero()
45
#include "pycore_pyerrors.h" // _PyErr_Occurred()
@@ -235,14 +236,12 @@ get_warnings_attr(PyInterpreterState *interp, PyObject *attr, int try_import)
235236
static PyObject *
236237
get_once_registry(PyInterpreterState *interp)
237238
{
238-
PyObject *registry;
239-
240239
WarningsState *st = warnings_get_state(interp);
241-
if (st == NULL) {
242-
return NULL;
243-
}
240+
assert(st != NULL);
241+
242+
_Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex);
244243

245-
registry = GET_WARNINGS_ATTR(interp, onceregistry, 0);
244+
PyObject *registry = GET_WARNINGS_ATTR(interp, onceregistry, 0);
246245
if (registry == NULL) {
247246
if (PyErr_Occurred())
248247
return NULL;
@@ -265,14 +264,12 @@ get_once_registry(PyInterpreterState *interp)
265264
static PyObject *
266265
get_default_action(PyInterpreterState *interp)
267266
{
268-
PyObject *default_action;
269-
270267
WarningsState *st = warnings_get_state(interp);
271-
if (st == NULL) {
272-
return NULL;
273-
}
268+
assert(st != NULL);
274269

275-
default_action = GET_WARNINGS_ATTR(interp, defaultaction, 0);
270+
_Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex);
271+
272+
PyObject *default_action = GET_WARNINGS_ATTR(interp, defaultaction, 0);
276273
if (default_action == NULL) {
277274
if (PyErr_Occurred()) {
278275
return NULL;
@@ -299,15 +296,12 @@ get_filter(PyInterpreterState *interp, PyObject *category,
299296
PyObject *text, Py_ssize_t lineno,
300297
PyObject *module, PyObject **item)
301298
{
302-
PyObject *action;
303-
Py_ssize_t i;
304-
PyObject *warnings_filters;
305299
WarningsState *st = warnings_get_state(interp);
306-
if (st == NULL) {
307-
return NULL;
308-
}
300+
assert(st != NULL);
309301

310-
warnings_filters = GET_WARNINGS_ATTR(interp, filters, 0);
302+
_Py_CRITICAL_SECTION_ASSERT_MUTEX_LOCKED(&st->mutex);
303+
304+
PyObject *warnings_filters = GET_WARNINGS_ATTR(interp, filters, 0);
311305
if (warnings_filters == NULL) {
312306
if (PyErr_Occurred())
313307
return NULL;
@@ -324,7 +318,7 @@ get_filter(PyInterpreterState *interp, PyObject *category,
324318
}
325319

326320
/* WarningsState.filters could change while we are iterating over it. */
327-
for (i = 0; i < PyList_GET_SIZE(filters); i++) {
321+
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(filters); i++) {
328322
PyObject *tmp_item, *action, *msg, *cat, *mod, *ln_obj;
329323
Py_ssize_t ln;
330324
int is_subclass, good_msg, good_mod;
@@ -384,7 +378,7 @@ get_filter(PyInterpreterState *interp, PyObject *category,
384378
Py_DECREF(tmp_item);
385379
}
386380

387-
action = get_default_action(interp);
381+
PyObject *action = get_default_action(interp);
388382
if (action != NULL) {
389383
*item = Py_NewRef(Py_None);
390384
return action;
@@ -1000,8 +994,13 @@ do_warn(PyObject *message, PyObject *category, Py_ssize_t stack_level,
1000994
&filename, &lineno, &module, &registry))
1001995
return NULL;
1002996

997+
WarningsState *st = warnings_get_state(tstate->interp);
998+
assert(st != NULL);
999+
1000+
Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
10031001
res = warn_explicit(tstate, category, message, filename, lineno, module, registry,
10041002
NULL, source);
1003+
Py_END_CRITICAL_SECTION();
10051004
Py_DECREF(filename);
10061005
Py_DECREF(registry);
10071006
Py_DECREF(module);
@@ -1149,8 +1148,14 @@ warnings_warn_explicit_impl(PyObject *module, PyObject *message,
11491148
return NULL;
11501149
}
11511150
}
1151+
1152+
WarningsState *st = warnings_get_state(tstate->interp);
1153+
assert(st != NULL);
1154+
1155+
Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
11521156
returned = warn_explicit(tstate, category, message, filename, lineno,
11531157
mod, registry, source_line, sourceobj);
1158+
Py_END_CRITICAL_SECTION();
11541159
Py_XDECREF(source_line);
11551160
return returned;
11561161
}
@@ -1290,8 +1295,14 @@ PyErr_WarnExplicitObject(PyObject *category, PyObject *message,
12901295
if (tstate == NULL) {
12911296
return -1;
12921297
}
1298+
1299+
WarningsState *st = warnings_get_state(tstate->interp);
1300+
assert(st != NULL);
1301+
1302+
Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
12931303
res = warn_explicit(tstate, category, message, filename, lineno,
12941304
module, registry, NULL, NULL);
1305+
Py_END_CRITICAL_SECTION();
12951306
if (res == NULL)
12961307
return -1;
12971308
Py_DECREF(res);
@@ -1356,8 +1367,13 @@ PyErr_WarnExplicitFormat(PyObject *category,
13561367
PyObject *res;
13571368
PyThreadState *tstate = get_current_tstate();
13581369
if (tstate != NULL) {
1370+
WarningsState *st = warnings_get_state(tstate->interp);
1371+
assert(st != NULL);
1372+
1373+
Py_BEGIN_CRITICAL_SECTION_MUT(&st->mutex);
13591374
res = warn_explicit(tstate, category, message, filename, lineno,
13601375
module, registry, NULL, NULL);
1376+
Py_END_CRITICAL_SECTION();
13611377
Py_DECREF(message);
13621378
if (res != NULL) {
13631379
Py_DECREF(res);

0 commit comments

Comments
 (0)