Skip to content

Commit 762f489

Browse files
gh-116664: Ensure thread-safe dict access in _warnings (#116768)
Replace _PyDict_GetItemWithError() with PyDict_GetItemRef().
1 parent 4e45c6c commit 762f489

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

Python/_warnings.c

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "Python.h"
2-
#include "pycore_dict.h" // _PyDict_GetItemWithError()
32
#include "pycore_interp.h" // PyInterpreterState.warnings
43
#include "pycore_long.h" // _PyLong_GetZero()
54
#include "pycore_pyerrors.h" // _PyErr_Occurred()
@@ -8,6 +7,8 @@
87
#include "pycore_sysmodule.h" // _PySys_GetAttr()
98
#include "pycore_traceback.h" // _Py_DisplaySourceLine()
109

10+
#include <stdbool.h>
11+
1112
#include "clinic/_warnings.c.h"
1213

1314
#define MODULE_NAME "_warnings"
@@ -397,7 +398,7 @@ static int
397398
already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
398399
int should_set)
399400
{
400-
PyObject *version_obj, *already_warned;
401+
PyObject *already_warned;
401402

402403
if (key == NULL)
403404
return -1;
@@ -406,14 +407,17 @@ already_warned(PyInterpreterState *interp, PyObject *registry, PyObject *key,
406407
if (st == NULL) {
407408
return -1;
408409
}
409-
version_obj = _PyDict_GetItemWithError(registry, &_Py_ID(version));
410-
if (version_obj == NULL
410+
PyObject *version_obj;
411+
if (PyDict_GetItemRef(registry, &_Py_ID(version), &version_obj) < 0) {
412+
return -1;
413+
}
414+
bool should_update_version = (
415+
version_obj == NULL
411416
|| !PyLong_CheckExact(version_obj)
412-
|| PyLong_AsLong(version_obj) != st->filters_version)
413-
{
414-
if (PyErr_Occurred()) {
415-
return -1;
416-
}
417+
|| PyLong_AsLong(version_obj) != st->filters_version
418+
);
419+
Py_XDECREF(version_obj);
420+
if (should_update_version) {
417421
PyDict_Clear(registry);
418422
version_obj = PyLong_FromLong(st->filters_version);
419423
if (version_obj == NULL)
@@ -911,13 +915,12 @@ setup_context(Py_ssize_t stack_level,
911915
/* Setup registry. */
912916
assert(globals != NULL);
913917
assert(PyDict_Check(globals));
914-
*registry = _PyDict_GetItemWithError(globals, &_Py_ID(__warningregistry__));
918+
int rc = PyDict_GetItemRef(globals, &_Py_ID(__warningregistry__),
919+
registry);
920+
if (rc < 0) {
921+
goto handle_error;
922+
}
915923
if (*registry == NULL) {
916-
int rc;
917-
918-
if (_PyErr_Occurred(tstate)) {
919-
goto handle_error;
920-
}
921924
*registry = PyDict_New();
922925
if (*registry == NULL)
923926
goto handle_error;
@@ -926,21 +929,21 @@ setup_context(Py_ssize_t stack_level,
926929
if (rc < 0)
927930
goto handle_error;
928931
}
929-
else
930-
Py_INCREF(*registry);
931932

932933
/* Setup module. */
933-
*module = _PyDict_GetItemWithError(globals, &_Py_ID(__name__));
934-
if (*module == Py_None || (*module != NULL && PyUnicode_Check(*module))) {
935-
Py_INCREF(*module);
936-
}
937-
else if (_PyErr_Occurred(tstate)) {
934+
rc = PyDict_GetItemRef(globals, &_Py_ID(__name__), module);
935+
if (rc < 0) {
938936
goto handle_error;
939937
}
940-
else {
941-
*module = PyUnicode_FromString("<string>");
942-
if (*module == NULL)
943-
goto handle_error;
938+
if (rc > 0) {
939+
if (Py_IsNone(*module) || PyUnicode_Check(*module)) {
940+
return 1;
941+
}
942+
Py_DECREF(*module);
943+
}
944+
*module = PyUnicode_FromString("<string>");
945+
if (*module == NULL) {
946+
goto handle_error;
944947
}
945948

946949
return 1;
@@ -1063,12 +1066,12 @@ get_source_line(PyInterpreterState *interp, PyObject *module_globals, int lineno
10631066
return NULL;
10641067
}
10651068

1066-
module_name = _PyDict_GetItemWithError(module_globals, &_Py_ID(__name__));
1067-
if (!module_name) {
1069+
int rc = PyDict_GetItemRef(module_globals, &_Py_ID(__name__),
1070+
&module_name);
1071+
if (rc < 0 || rc == 0) {
10681072
Py_DECREF(loader);
10691073
return NULL;
10701074
}
1071-
Py_INCREF(module_name);
10721075

10731076
/* Make sure the loader implements the optional get_source() method. */
10741077
(void)PyObject_GetOptionalAttr(loader, &_Py_ID(get_source), &get_source);

0 commit comments

Comments
 (0)