-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Fix thread safety for pybind11 loader_life_support #3237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
53b3919
caa974f
9179d60
0c2bf55
4b7dc7a
fdfff88
d91a4a7
c6720ca
2c07c0d
8a1a59f
dfc94f3
4f25e31
366f40d
b5a0538
ffc52a3
dc7df66
dd8f264
c4c6acb
6ad3de6
d7e3067
638d091
a06f851
5c58953
5f66855
1237bbe
afbc066
fe49b37
5787104
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,14 +106,16 @@ struct internals { | |
std::unordered_map<const PyObject *, std::vector<PyObject *>> patients; | ||
std::forward_list<ExceptionTranslator> registered_exception_translators; | ||
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions | ||
std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support` | ||
std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str() | ||
PyTypeObject *static_property_type; | ||
PyTypeObject *default_metaclass; | ||
PyObject *instance_base; | ||
#if defined(WITH_THREAD) | ||
#if !defined(WITH_THREAD) | ||
void* loader_patient_ptr = nullptr; // Used by `loader_life_support` | ||
#else // defined(WITH_THREAD) | ||
PYBIND11_TLS_KEY_INIT(tstate); | ||
PyInterpreterState *istate = nullptr; | ||
PYBIND11_TLS_KEY_INIT(loader_patient_key); | ||
~internals() { | ||
// This destructor is called *after* Py_Finalize() in finalize_interpreter(). | ||
// That *SHOULD BE* fine. The following details what happens when PyThread_tss_free is called. | ||
|
@@ -123,6 +125,7 @@ struct internals { | |
// of those have anything to do with CPython internals. | ||
// PyMem_RawFree *requires* that the `tstate` be allocated with the CPython allocator. | ||
PYBIND11_TLS_FREE(tstate); | ||
PYBIND11_TLS_FREE(loader_patient_key); | ||
} | ||
#endif | ||
}; | ||
|
@@ -154,7 +157,7 @@ struct type_info { | |
}; | ||
|
||
/// Tracks the `internals` and `type_info` ABI version independent of the main library version | ||
#define PYBIND11_INTERNALS_VERSION 4 | ||
#define PYBIND11_INTERNALS_VERSION 5 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we are bumping the ABI, we should probably move through several of the PRs sitting round that bump the ABI and get them in too. Also, this is going to create a massive mess, as we'll have to start a conda-forge migration, and quite a few packages don't use the conda-forge's pybind11, so... It will be a mess. But I guess it's going to have to happen eventually. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There may be a non-abi bumping strategy available; if you like I can take that approach. |
||
|
||
/// On MSVC, debug and release builds are not ABI-compatible! | ||
#if defined(_MSC_VER) && defined(_DEBUG) | ||
|
@@ -298,15 +301,26 @@ PYBIND11_NOINLINE internals &get_internals() { | |
#if PY_VERSION_HEX >= 0x03070000 | ||
internals_ptr->tstate = PyThread_tss_alloc(); | ||
if (!internals_ptr->tstate || (PyThread_tss_create(internals_ptr->tstate) != 0)) | ||
pybind11_fail("get_internals: could not successfully initialize the TSS key!"); | ||
pybind11_fail("get_internals: could not successfully initialize the tstate TSS key!"); | ||
PyThread_tss_set(internals_ptr->tstate, tstate); | ||
#else | ||
internals_ptr->tstate = PyThread_create_key(); | ||
if (internals_ptr->tstate == -1) | ||
pybind11_fail("get_internals: could not successfully initialize the TLS key!"); | ||
pybind11_fail("get_internals: could not successfully initialize the tstate TLS key!"); | ||
PyThread_set_key_value(internals_ptr->tstate, tstate); | ||
#endif | ||
internals_ptr->istate = tstate->interp; | ||
|
||
#if PY_VERSION_HEX >= 0x03070000 | ||
internals_ptr->loader_patient_key = PyThread_tss_alloc(); | ||
if (!internals_ptr->loader_patient_key || (PyThread_tss_create(internals_ptr->loader_patient_key) != 0)) | ||
pybind11_fail("get_internals: could not successfully initialize the loader_patient TSS key!"); | ||
#else | ||
internals_ptr->loader_patient_key = PyThread_create_key(); | ||
if (internals_ptr->loader_patient_key == -1) | ||
pybind11_fail("get_internals: could not successfully initialize the loader_patient TLS key!"); | ||
#endif | ||
|
||
#endif | ||
builtins[id] = capsule(internals_pp); | ||
internals_ptr->registered_exception_translators.push_front(&translate_exception); | ||
|
@@ -335,6 +349,30 @@ inline local_internals &get_local_internals() { | |
return locals; | ||
} | ||
|
||
/// The patient pointer is used to store patient data for a call frame. | ||
/// See loader_life_support for use. | ||
inline void* get_loader_patient_pointer() { | ||
#if !defined(WITH_THREAD) | ||
return get_internals().loader_patient_ptr; | ||
#else | ||
auto &internals = get_internals(); | ||
return PYBIND11_TLS_GET_VALUE(internals.loader_patient_key); | ||
#endif | ||
} | ||
|
||
inline void set_loader_patient_pointer(void* ptr) { | ||
#if !defined(WITH_THREAD) | ||
get_internals().loader_patient_ptr = ptr; | ||
#else | ||
auto &internals = get_internals(); | ||
#if PY_VERSION_HEX >= 0x03070000 | ||
PyThread_tss_set(internals.loader_patient_key, ptr); | ||
#else | ||
PyThread_set_key_value(internals.loader_patient_key, ptr); | ||
#endif | ||
#endif | ||
} | ||
|
||
|
||
/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its | ||
/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,46 +32,37 @@ PYBIND11_NAMESPACE_BEGIN(detail) | |
/// Adding a patient will keep it alive up until the enclosing function returns. | ||
class loader_life_support { | ||
public: | ||
laramiel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
void* parent = nullptr; | ||
std::unordered_set<PyObject *> keep_alive; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably better to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: Previously the possibility of using PySet was discussed, but that cannot be used since if the type defines a custom hash and equality function then it won't work correctly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately py::object is not hashable, so that would require intrusive adaptors or similar. Since the control flow here is so limited, I think that the simplest answer is to just use PyObject* here and refcount it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, good point. |
||
|
||
/// A new patient frame is created when a function is entered | ||
loader_life_support() { | ||
get_internals().loader_patient_stack.push_back(nullptr); | ||
parent = get_loader_patient_pointer(); | ||
set_loader_patient_pointer(this); | ||
} | ||
|
||
/// ... and destroyed after it returns | ||
~loader_life_support() { | ||
auto &stack = get_internals().loader_patient_stack; | ||
if (stack.empty()) | ||
auto* frame = reinterpret_cast<loader_life_support*>(get_loader_patient_pointer()); | ||
if (frame != this) | ||
pybind11_fail("loader_life_support: internal error"); | ||
set_loader_patient_pointer(parent); | ||
|
||
auto ptr = stack.back(); | ||
stack.pop_back(); | ||
Py_CLEAR(ptr); | ||
|
||
// A heuristic to reduce the stack's capacity (e.g. after long recursive calls) | ||
if (stack.capacity() > 16 && !stack.empty() && stack.capacity() / stack.size() > 2) | ||
stack.shrink_to_fit(); | ||
for (auto* item : keep_alive) | ||
Py_DECREF(item); | ||
} | ||
|
||
/// This can only be used inside a pybind11-bound function, either by `argument_loader` | ||
/// at argument preparation time or by `py::cast()` at execution time. | ||
PYBIND11_NOINLINE static void add_patient(handle h) { | ||
auto &stack = get_internals().loader_patient_stack; | ||
if (stack.empty()) | ||
auto* frame = reinterpret_cast<loader_life_support*>(get_loader_patient_pointer()); | ||
if (!frame) | ||
throw cast_error("When called outside a bound function, py::cast() cannot " | ||
"do Python -> C++ conversions which require the creation " | ||
"of temporary values"); | ||
|
||
auto &list_ptr = stack.back(); | ||
if (list_ptr == nullptr) { | ||
list_ptr = PyList_New(1); | ||
if (!list_ptr) | ||
pybind11_fail("loader_life_support: error allocating list"); | ||
PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); | ||
} else { | ||
auto result = PyList_Append(list_ptr, h.ptr()); | ||
if (result == -1) | ||
pybind11_fail("loader_life_support: error adding patient"); | ||
} | ||
if (frame->keep_alive.insert(h.ptr()).second) | ||
Py_INCREF(h.ptr()); | ||
} | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
tests/test_thread.cpp -- call pybind11 bound methods in threads | ||
|
||
Copyright (c) 2017 Laramie Leavitt (Google LLC) <[email protected]> | ||
|
||
All rights reserved. Use of this source code is governed by a | ||
BSD-style license that can be found in the LICENSE file. | ||
*/ | ||
|
||
#include <string_view> | ||
#include <string> | ||
|
||
#define PYBIND11_HAS_STRING_VIEW 1 | ||
|
||
#include <pybind11/cast.h> | ||
laramiel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "pybind11_tests.h" | ||
|
||
|
||
TEST_SUBMODULE(thread, m) { | ||
|
||
// std::string_view uses loader_life_support to ensure that the string contents | ||
// remains alive for the life of the call. These methods are invoked concurrently | ||
m.def("method", [](std::string_view str) -> std::string { | ||
return std::string(str); | ||
}); | ||
|
||
m.def("method_no_gil", [](std::string_view str) -> std::string { | ||
return std::string(str); | ||
}, | ||
py::call_guard<py::gil_scoped_release>()); | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# -*- coding: utf-8 -*- | ||
import concurrent.futures | ||
|
||
import pytest | ||
laramiel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import env # noqa: F401 | ||
laramiel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from pybind11_tests import thread as m | ||
|
||
|
||
def method(s): | ||
return m.method(s) | ||
|
||
|
||
def method_no_gil(s): | ||
return m.method_no_gil(s) | ||
|
||
|
||
def test_message(): | ||
inputs = ["%d" % i for i in range(20, 30)] | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: | ||
results = list(executor.map(method, inputs)) | ||
results.sort() | ||
for i in range(len(results)): | ||
assert results[i] == ("%s" % (i + 20)) | ||
|
||
|
||
def test_message_no_gil(): | ||
inputs = ["%d" % i for i in range(20, 30)] | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: | ||
results = list(executor.map(method_no_gil, inputs)) | ||
results.sort() | ||
for i in range(len(results)): | ||
assert results[i] == ("%s" % (i + 20)) |
Uh oh!
There was an error while loading. Please reload this page.