Skip to content

Commit 0a9ef9c

Browse files
authored
Merge pull request #472 from aldanor/feature/shared-dtypes
Support for sharing dtypes across extensions + public shared data API
2 parents a743ead + cc8ff16 commit 0a9ef9c

File tree

4 files changed

+165
-72
lines changed

4 files changed

+165
-72
lines changed

docs/advanced/misc.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,25 @@ accessed by multiple extension modules:
149149
...
150150
};
151151
152+
Note also that it is possible (although would rarely be required) to share arbitrary
153+
C++ objects between extension modules at runtime. Internal library data is shared
154+
between modules using capsule machinery [#f6]_ which can be also utilized for
155+
storing, modifying and accessing user-defined data. Note that an extension module
156+
will "see" other extensions' data if and only if they were built with the same
157+
pybind11 version. Consider the following example:
158+
159+
.. code-block:: cpp
160+
161+
auto data = (MyData *) py::get_shared_data("mydata");
162+
if (!data)
163+
data = (MyData *) py::set_shared_data("mydata", new MyData(42));
164+
165+
If the above snippet was used in several separately compiled extension modules,
166+
the first one to be imported would create a ``MyData`` instance and associate
167+
a ``"mydata"`` key with a pointer to it. Extensions that are imported later
168+
would be then able to access the data behind the same pointer.
169+
170+
.. [#f6] https://docs.python.org/3/extending/extending.html#using-capsules
152171
153172
154173
Generating documentation using Sphinx

include/pybind11/common.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ struct internals {
323323
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
324324
std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
325325
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
326+
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
326327
#if defined(WITH_THREAD)
327328
decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
328329
PyInterpreterState *istate = nullptr;
@@ -427,6 +428,35 @@ inline void ignore_unused(const int *) { }
427428

428429
NAMESPACE_END(detail)
429430

431+
/// Returns a named pointer that is shared among all extension modules (using the same
432+
/// pybind11 version) running in the current interpreter. Names starting with underscores
433+
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
434+
inline PYBIND11_NOINLINE void* get_shared_data(const std::string& name) {
435+
auto& internals = detail::get_internals();
436+
auto it = internals.shared_data.find(name);
437+
return it != internals.shared_data.end() ? it->second : nullptr;
438+
}
439+
440+
/// Set the shared data that can be later recovered by `get_shared_data()`.
441+
inline PYBIND11_NOINLINE void *set_shared_data(const std::string& name, void *data) {
442+
detail::get_internals().shared_data[name] = data;
443+
return data;
444+
}
445+
446+
/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
447+
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
448+
/// added to the shared data under the given name and a reference to it is returned.
449+
template<typename T> T& get_or_create_shared_data(const std::string& name) {
450+
auto& internals = detail::get_internals();
451+
auto it = internals.shared_data.find(name);
452+
T* ptr = (T*) (it != internals.shared_data.end() ? it->second : nullptr);
453+
if (!ptr) {
454+
ptr = new T();
455+
internals.shared_data[name] = ptr;
456+
}
457+
return *ptr;
458+
}
459+
430460
/// Fetch and hold an error which was already set in Python
431461
class error_already_set : public std::runtime_error {
432462
public:

include/pybind11/numpy.h

Lines changed: 100 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <initializer_list>
2222
#include <functional>
2323
#include <utility>
24+
#include <typeindex>
2425

2526
#if defined(_MSC_VER)
2627
# pragma warning(push)
@@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy {
7273
PyObject *base;
7374
};
7475

76+
struct numpy_type_info {
77+
PyObject* dtype_ptr;
78+
std::string format_str;
79+
};
80+
81+
struct numpy_internals {
82+
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
83+
84+
numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
85+
auto it = registered_dtypes.find(std::type_index(tinfo));
86+
if (it != registered_dtypes.end())
87+
return &(it->second);
88+
if (throw_if_missing)
89+
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
90+
return nullptr;
91+
}
92+
93+
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
94+
return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
95+
}
96+
};
97+
98+
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
99+
ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
100+
}
101+
102+
inline numpy_internals& get_numpy_internals() {
103+
static numpy_internals* ptr = nullptr;
104+
if (!ptr)
105+
load_numpy_internals(ptr);
106+
return *ptr;
107+
}
108+
75109
struct npy_api {
76110
enum constants {
77111
NPY_C_CONTIGUOUS_ = 0x0001,
@@ -656,99 +690,100 @@ struct field_descriptor {
656690
dtype descr;
657691
};
658692

693+
inline PYBIND11_NOINLINE void register_structured_dtype(
694+
const std::initializer_list<field_descriptor>& fields,
695+
const std::type_info& tinfo, size_t itemsize,
696+
bool (*direct_converter)(PyObject *, void *&))
697+
{
698+
auto& numpy_internals = get_numpy_internals();
699+
if (numpy_internals.get_type_info(tinfo, false))
700+
pybind11_fail("NumPy: dtype is already registered");
701+
702+
list names, formats, offsets;
703+
for (auto field : fields) {
704+
if (!field.descr)
705+
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
706+
field.name + "` @ " + tinfo.name());
707+
names.append(PYBIND11_STR_TYPE(field.name));
708+
formats.append(field.descr);
709+
offsets.append(pybind11::int_(field.offset));
710+
}
711+
auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
712+
713+
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
714+
// not encoded explicitly into the format string. This will supposedly
715+
// get fixed in v1.12; for further details, see these:
716+
// - https://github.com/numpy/numpy/issues/7797
717+
// - https://github.com/numpy/numpy/pull/7798
718+
// Because of this, we won't use numpy's logic to generate buffer format
719+
// strings and will just do it ourselves.
720+
std::vector<field_descriptor> ordered_fields(fields);
721+
std::sort(ordered_fields.begin(), ordered_fields.end(),
722+
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
723+
size_t offset = 0;
724+
std::ostringstream oss;
725+
oss << "T{";
726+
for (auto& field : ordered_fields) {
727+
if (field.offset > offset)
728+
oss << (field.offset - offset) << 'x';
729+
// note that '=' is required to cover the case of unaligned fields
730+
oss << '=' << field.format << ':' << field.name << ':';
731+
offset = field.offset + field.size;
732+
}
733+
if (itemsize > offset)
734+
oss << (itemsize - offset) << 'x';
735+
oss << '}';
736+
auto format_str = oss.str();
737+
738+
// Sanity check: verify that NumPy properly parses our buffer format string
739+
auto& api = npy_api::get();
740+
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
741+
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
742+
pybind11_fail("NumPy: invalid buffer descriptor!");
743+
744+
auto tindex = std::type_index(tinfo);
745+
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
746+
get_internals().direct_conversions[tindex].push_back(direct_converter);
747+
}
748+
659749
template <typename T>
660750
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
661751
static PYBIND11_DESCR name() { return _("struct"); }
662752

663753
static pybind11::dtype dtype() {
664-
if (!dtype_ptr)
665-
pybind11_fail("NumPy: unsupported buffer format!");
666-
return object(dtype_ptr, true);
754+
return object(dtype_ptr(), true);
667755
}
668756

669757
static std::string format() {
670-
if (!dtype_ptr)
671-
pybind11_fail("NumPy: unsupported buffer format!");
758+
static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
672759
return format_str;
673760
}
674761

675-
static void register_dtype(std::initializer_list<field_descriptor> fields) {
676-
if (dtype_ptr)
677-
pybind11_fail("NumPy: dtype is already registered");
678-
679-
list names, formats, offsets;
680-
for (auto field : fields) {
681-
if (!field.descr)
682-
pybind11_fail("NumPy: unsupported field dtype");
683-
names.append(PYBIND11_STR_TYPE(field.name));
684-
formats.append(field.descr);
685-
offsets.append(pybind11::int_(field.offset));
686-
}
687-
dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
688-
689-
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
690-
// not encoded explicitly into the format string. This will supposedly
691-
// get fixed in v1.12; for further details, see these:
692-
// - https://github.com/numpy/numpy/issues/7797
693-
// - https://github.com/numpy/numpy/pull/7798
694-
// Because of this, we won't use numpy's logic to generate buffer format
695-
// strings and will just do it ourselves.
696-
std::vector<field_descriptor> ordered_fields(fields);
697-
std::sort(ordered_fields.begin(), ordered_fields.end(),
698-
[](const field_descriptor &a, const field_descriptor &b) {
699-
return a.offset < b.offset;
700-
});
701-
size_t offset = 0;
702-
std::ostringstream oss;
703-
oss << "T{";
704-
for (auto& field : ordered_fields) {
705-
if (field.offset > offset)
706-
oss << (field.offset - offset) << 'x';
707-
// note that '=' is required to cover the case of unaligned fields
708-
oss << '=' << field.format << ':' << field.name << ':';
709-
offset = field.offset + field.size;
710-
}
711-
if (sizeof(T) > offset)
712-
oss << (sizeof(T) - offset) << 'x';
713-
oss << '}';
714-
format_str = oss.str();
715-
716-
// Sanity check: verify that NumPy properly parses our buffer format string
717-
auto& api = npy_api::get();
718-
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1));
719-
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
720-
pybind11_fail("NumPy: invalid buffer descriptor!");
721-
722-
register_direct_converter();
762+
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
763+
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
764+
sizeof(T), &direct_converter);
723765
}
724766

725767
private:
726-
static std::string format_str;
727-
static PyObject* dtype_ptr;
768+
static PyObject* dtype_ptr() {
769+
static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
770+
return ptr;
771+
}
728772

729773
static bool direct_converter(PyObject *obj, void*& value) {
730774
auto& api = npy_api::get();
731775
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
732776
return false;
733777
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
734-
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
778+
if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
735779
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
736780
return true;
737781
}
738782
}
739783
return false;
740784
}
741-
742-
static void register_direct_converter() {
743-
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);
744-
}
745785
};
746786

747-
template <typename T>
748-
std::string npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::format_str;
749-
template <typename T>
750-
PyObject* npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::dtype_ptr = nullptr;
751-
752787
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
753788
::pybind11::detail::field_descriptor { \
754789
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \

tests/test_numpy_dtypes.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1+
import re
12
import pytest
3+
24
with pytest.suppress(ImportError):
35
import numpy as np
46

5-
simple_dtype = np.dtype({'names': ['x', 'y', 'z'],
6-
'formats': ['?', 'u4', 'f4'],
7-
'offsets': [0, 4, 8]})
8-
packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
7+
8+
@pytest.fixture(scope='module')
9+
def simple_dtype():
10+
return np.dtype({'names': ['x', 'y', 'z'],
11+
'formats': ['?', 'u4', 'f4'],
12+
'offsets': [0, 4, 8]})
13+
14+
15+
@pytest.fixture(scope='module')
16+
def packed_dtype():
17+
return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])
918

1019

1120
def assert_equal(actual, expected_data, expected_dtype):
@@ -18,7 +27,7 @@ def test_format_descriptors():
1827

1928
with pytest.raises(RuntimeError) as excinfo:
2029
get_format_unbound()
21-
assert 'unsupported buffer format' in str(excinfo.value)
30+
assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))
2231

2332
assert print_format_descriptors() == [
2433
"T{=?:x:3x=I:y:=f:z:}",
@@ -32,7 +41,7 @@ def test_format_descriptors():
3241

3342

3443
@pytest.requires_numpy
35-
def test_dtype():
44+
def test_dtype(simple_dtype):
3645
from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods
3746

3847
assert print_dtypes() == [
@@ -57,7 +66,7 @@ def test_dtype():
5766

5867

5968
@pytest.requires_numpy
60-
def test_recarray():
69+
def test_recarray(simple_dtype, packed_dtype):
6170
from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
6271
print_rec_simple, print_rec_packed, print_rec_nested,
6372
create_rec_partial, create_rec_partial_nested)

0 commit comments

Comments
 (0)