Skip to content

Support for sharing dtypes across extensions + public shared data API #472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 3, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/advanced/misc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,25 @@ accessed by multiple extension modules:
...
};

Note also that it is possible (although would rarely be required) to share arbitrary
C++ objects between extension modules at runtime. Internal library data is shared
between modules using capsule machinery [#f6]_ which can be also utilized for
storing, modifying and accessing user-defined data. Note that an extension module
will "see" other extensions' data if and only if they were built with the same
pybind11 version. Consider the following example:

.. code-block:: cpp

auto data = (MyData *) py::get_shared_data("mydata");
if (!data)
data = (MyData *) py::set_shared_data("mydata", new MyData(42));

If the above snippet was used in several separately compiled extension modules,
the first one to be imported would create a ``MyData`` instance and associate
a ``"mydata"`` key with a pointer to it. Extensions that are imported later
would be then able to access the data behind the same pointer.

.. [#f6] https://docs.python.org/3/extending/extending.html#using-capsules


Generating documentation using Sphinx
Expand Down
30 changes: 30 additions & 0 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ struct internals {
std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
std::unordered_map<std::type_index, std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
#if defined(WITH_THREAD)
decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x
PyInterpreterState *istate = nullptr;
Expand Down Expand Up @@ -427,6 +428,35 @@ inline void ignore_unused(const int *) { }

NAMESPACE_END(detail)

/// Returns a named pointer that is shared among all extension modules (using the same
/// pybind11 version) running in the current interpreter. Names starting with underscores
/// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
inline PYBIND11_NOINLINE void* get_shared_data(const std::string& name) {
auto& internals = detail::get_internals();
auto it = internals.shared_data.find(name);
return it != internals.shared_data.end() ? it->second : nullptr;
}

/// Set the shared data that can be later recovered by `get_shared_data()`.
inline PYBIND11_NOINLINE void *set_shared_data(const std::string& name, void *data) {
detail::get_internals().shared_data[name] = data;
return data;
}

/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
/// such entry exists. Otherwise, a new object of default-constructible type `T` is
/// added to the shared data under the given name and a reference to it is returned.
template<typename T> T& get_or_create_shared_data(const std::string& name) {
auto& internals = detail::get_internals();
auto it = internals.shared_data.find(name);
T* ptr = (T*) (it != internals.shared_data.end() ? it->second : nullptr);
if (!ptr) {
ptr = new T();
internals.shared_data[name] = ptr;
}
return *ptr;
}

/// Fetch and hold an error which was already set in Python
class error_already_set : public std::runtime_error {
public:
Expand Down
165 changes: 100 additions & 65 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <initializer_list>
#include <functional>
#include <utility>
#include <typeindex>

#if defined(_MSC_VER)
# pragma warning(push)
Expand Down Expand Up @@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy {
PyObject *base;
};

struct numpy_type_info {
PyObject* dtype_ptr;
std::string format_str;
};

struct numpy_internals {
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;

numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
auto it = registered_dtypes.find(std::type_index(tinfo));
if (it != registered_dtypes.end())
return &(it->second);
if (throw_if_missing)
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
return nullptr;
}

template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
}
};

inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
}

inline numpy_internals& get_numpy_internals() {
static numpy_internals* ptr = nullptr;
if (!ptr)
load_numpy_internals(ptr);
return *ptr;
}

struct npy_api {
enum constants {
NPY_C_CONTIGUOUS_ = 0x0001,
Expand Down Expand Up @@ -656,99 +690,100 @@ struct field_descriptor {
dtype descr;
};

inline PYBIND11_NOINLINE void register_structured_dtype(
const std::initializer_list<field_descriptor>& fields,
const std::type_info& tinfo, size_t itemsize,
bool (*direct_converter)(PyObject *, void *&))
{
auto& numpy_internals = get_numpy_internals();
if (numpy_internals.get_type_info(tinfo, false))
pybind11_fail("NumPy: dtype is already registered");

list names, formats, offsets;
for (auto field : fields) {
if (!field.descr)
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
field.name + "` @ " + tinfo.name());
names.append(PYBIND11_STR_TYPE(field.name));
formats.append(field.descr);
offsets.append(pybind11::int_(field.offset));
}
auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();

// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// get fixed in v1.12; for further details, see these:
// - https://github.com/numpy/numpy/issues/7797
// - https://github.com/numpy/numpy/pull/7798
// Because of this, we won't use numpy's logic to generate buffer format
// strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
size_t offset = 0;
std::ostringstream oss;
oss << "T{";
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
// note that '=' is required to cover the case of unaligned fields
oss << '=' << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
if (itemsize > offset)
oss << (itemsize - offset) << 'x';
oss << '}';
auto format_str = oss.str();

// Sanity check: verify that NumPy properly parses our buffer format string
auto& api = npy_api::get();
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
pybind11_fail("NumPy: invalid buffer descriptor!");

auto tindex = std::type_index(tinfo);
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
get_internals().direct_conversions[tindex].push_back(direct_converter);
}

template <typename T>
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
static PYBIND11_DESCR name() { return _("struct"); }

static pybind11::dtype dtype() {
if (!dtype_ptr)
pybind11_fail("NumPy: unsupported buffer format!");
return object(dtype_ptr, true);
return object(dtype_ptr(), true);
}

static std::string format() {
if (!dtype_ptr)
pybind11_fail("NumPy: unsupported buffer format!");
static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
return format_str;
}

static void register_dtype(std::initializer_list<field_descriptor> fields) {
if (dtype_ptr)
pybind11_fail("NumPy: dtype is already registered");

list names, formats, offsets;
for (auto field : fields) {
if (!field.descr)
pybind11_fail("NumPy: unsupported field dtype");
names.append(PYBIND11_STR_TYPE(field.name));
formats.append(field.descr);
offsets.append(pybind11::int_(field.offset));
}
dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();

// There is an existing bug in NumPy (as of v1.11): trailing bytes are
// not encoded explicitly into the format string. This will supposedly
// get fixed in v1.12; for further details, see these:
// - https://github.com/numpy/numpy/issues/7797
// - https://github.com/numpy/numpy/pull/7798
// Because of this, we won't use numpy's logic to generate buffer format
// strings and will just do it ourselves.
std::vector<field_descriptor> ordered_fields(fields);
std::sort(ordered_fields.begin(), ordered_fields.end(),
[](const field_descriptor &a, const field_descriptor &b) {
return a.offset < b.offset;
});
size_t offset = 0;
std::ostringstream oss;
oss << "T{";
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
// note that '=' is required to cover the case of unaligned fields
oss << '=' << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
if (sizeof(T) > offset)
oss << (sizeof(T) - offset) << 'x';
oss << '}';
format_str = oss.str();

// Sanity check: verify that NumPy properly parses our buffer format string
auto& api = npy_api::get();
auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1));
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
pybind11_fail("NumPy: invalid buffer descriptor!");

register_direct_converter();
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
sizeof(T), &direct_converter);
}

private:
static std::string format_str;
static PyObject* dtype_ptr;
static PyObject* dtype_ptr() {
static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
return ptr;
}

static bool direct_converter(PyObject *obj, void*& value) {
auto& api = npy_api::get();
if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
return false;
if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) {
if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
value = ((PyVoidScalarObject_Proxy *) obj)->obval;
return true;
}
}
return false;
}

static void register_direct_converter() {
get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter);
}
};

template <typename T>
std::string npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::format_str;
template <typename T>
PyObject* npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::dtype_ptr = nullptr;

#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
::pybind11::detail::field_descriptor { \
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
Expand Down
23 changes: 16 additions & 7 deletions tests/test_numpy_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import re
import pytest

with pytest.suppress(ImportError):
import numpy as np

simple_dtype = np.dtype({'names': ['x', 'y', 'z'],
'formats': ['?', 'u4', 'f4'],
'offsets': [0, 4, 8]})
packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])

@pytest.fixture(scope='module')
def simple_dtype():
return np.dtype({'names': ['x', 'y', 'z'],
'formats': ['?', 'u4', 'f4'],
'offsets': [0, 4, 8]})


@pytest.fixture(scope='module')
def packed_dtype():
return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')])


def assert_equal(actual, expected_data, expected_dtype):
Expand All @@ -18,7 +27,7 @@ def test_format_descriptors():

with pytest.raises(RuntimeError) as excinfo:
get_format_unbound()
assert 'unsupported buffer format' in str(excinfo.value)
assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))

assert print_format_descriptors() == [
"T{=?:x:3x=I:y:=f:z:}",
Expand All @@ -32,7 +41,7 @@ def test_format_descriptors():


@pytest.requires_numpy
def test_dtype():
def test_dtype(simple_dtype):
from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods

assert print_dtypes() == [
Expand All @@ -57,7 +66,7 @@ def test_dtype():


@pytest.requires_numpy
def test_recarray():
def test_recarray(simple_dtype, packed_dtype):
from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested,
print_rec_simple, print_rec_packed, print_rec_nested,
create_rec_partial, create_rec_partial_nested)
Expand Down