From 2184f6d4d64b4631c943b4c92d35bcd849da51ad Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 31 Oct 2016 13:52:32 +0000 Subject: [PATCH 1/5] NumPy dtypes are now shared across extensions --- include/pybind11/common.h | 1 + include/pybind11/numpy.h | 78 +++++++++++++++++++++++++------------- tests/test_numpy_dtypes.py | 2 +- 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/include/pybind11/common.h b/include/pybind11/common.h index b5434d04ab..27cd47beff 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -323,6 +323,7 @@ struct internals { std::unordered_set, overload_hash> inactive_overload_cache; std::unordered_map> direct_conversions; std::forward_list registered_exception_translators; + std::unordered_map shared_data; #if defined(WITH_THREAD) decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x PyInterpreterState *istate = nullptr; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index da04c62a8b..19bff63595 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -21,6 +21,7 @@ #include #include #include +#include #if defined(_MSC_VER) # pragma warning(push) @@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy { PyObject *base; }; +struct numpy_type_info { + PyObject* dtype_ptr; + std::string format_str; +}; + +struct numpy_internals { + std::unordered_map registered_dtypes; + + template numpy_type_info *get_type_info(bool throw_if_missing = true) { + auto it = registered_dtypes.find(std::type_index(typeid(T))); + if (it != registered_dtypes.end()) + return &(it->second); + if (throw_if_missing) + pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name()); + return nullptr; + } +}; + +inline PYBIND11_NOINLINE numpy_internals* load_numpy_internals() { + auto& shared_data = detail::get_internals().shared_data; + auto it = shared_data.find("numpy_internals"); + if (it != shared_data.end()) + return (numpy_internals 
*)it->second; + auto ptr = new numpy_internals(); + shared_data["numpy_internals"] = ptr; + return ptr; +} + +inline numpy_internals& get_numpy_internals() { + static numpy_internals* ptr = load_numpy_internals(); + return *ptr; +} + struct npy_api { enum constants { NPY_C_CONTIGUOUS_ = 0x0001, @@ -661,30 +695,29 @@ struct npy_format_descriptor::value>> { static PYBIND11_DESCR name() { return _("struct"); } static pybind11::dtype dtype() { - if (!dtype_ptr) - pybind11_fail("NumPy: unsupported buffer format!"); - return object(dtype_ptr, true); + return object(dtype_ptr(), true); } static std::string format() { - if (!dtype_ptr) - pybind11_fail("NumPy: unsupported buffer format!"); + static auto format_str = get_numpy_internals().get_type_info(true)->format_str; return format_str; } static void register_dtype(std::initializer_list fields) { - if (dtype_ptr) + auto& numpy_internals = get_numpy_internals(); + if (numpy_internals.get_type_info(false)) pybind11_fail("NumPy: dtype is already registered"); list names, formats, offsets; for (auto field : fields) { if (!field.descr) - pybind11_fail("NumPy: unsupported field dtype"); + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + + field.name + "` @ " + typeid(T).name()); names.append(PYBIND11_STR_TYPE(field.name)); formats.append(field.descr); offsets.append(pybind11::int_(field.offset)); } - dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); + auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); // There is an existing bug in NumPy (as of v1.11): trailing bytes are // not encoded explicitly into the format string. This will supposedly @@ -695,9 +728,7 @@ struct npy_format_descriptor::value>> { // strings and will just do it ourselves. 
std::vector ordered_fields(fields); std::sort(ordered_fields.begin(), ordered_fields.end(), - [](const field_descriptor &a, const field_descriptor &b) { - return a.offset < b.offset; - }); + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); size_t offset = 0; std::ostringstream oss; oss << "T{"; @@ -711,44 +742,39 @@ struct npy_format_descriptor::value>> { if (sizeof(T) > offset) oss << (sizeof(T) - offset) << 'x'; oss << '}'; - format_str = oss.str(); + auto format_str = oss.str(); // Sanity check: verify that NumPy properly parses our buffer format string auto& api = npy_api::get(); - auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1)); + auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1)); if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) pybind11_fail("NumPy: invalid buffer descriptor!"); - register_direct_converter(); + auto tindex = std::type_index(typeid(T)); + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; + get_internals().direct_conversions[tindex].push_back(direct_converter); } private: - static std::string format_str; - static PyObject* dtype_ptr; + static PyObject* dtype_ptr() { + static PyObject* ptr = get_numpy_internals().get_type_info(true)->dtype_ptr; + return ptr; + } static bool direct_converter(PyObject *obj, void*& value) { auto& api = npy_api::get(); if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) return false; if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) { - if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { + if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { value = ((PyVoidScalarObject_Proxy *) obj)->obval; return true; } } return false; } - - static void register_direct_converter() { - get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter); - } }; -template -std::string npy_format_descriptor::value>>::format_str; -template -PyObject* 
npy_format_descriptor::value>>::dtype_ptr = nullptr; - #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ ::pybind11::detail::field_descriptor { \ Name, offsetof(T, Field), sizeof(decltype(std::declval().Field)), \ diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index b4e6d71f29..c0d6ec2920 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -18,7 +18,7 @@ def test_format_descriptors(): with pytest.raises(RuntimeError) as excinfo: get_format_unbound() - assert 'unsupported buffer format' in str(excinfo.value) + assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value)) assert print_format_descriptors() == [ "T{=?:x:3x=I:y:=f:z:}", From c546655dc29086895538be92377a82f64a834c1f Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 31 Oct 2016 13:54:43 +0000 Subject: [PATCH 2/5] Use pytest fixtures in numpy dtypes test module --- tests/test_numpy_dtypes.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index c0d6ec2920..2ef6f4d0b6 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -1,11 +1,20 @@ +import re import pytest + with pytest.suppress(ImportError): import numpy as np - simple_dtype = np.dtype({'names': ['x', 'y', 'z'], - 'formats': ['?', 'u4', 'f4'], - 'offsets': [0, 4, 8]}) - packed_dtype = np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')]) + +@pytest.fixture(scope='module') +def simple_dtype(): + return np.dtype({'names': ['x', 'y', 'z'], + 'formats': ['?', 'u4', 'f4'], + 'offsets': [0, 4, 8]}) + + +@pytest.fixture(scope='module') +def packed_dtype(): + return np.dtype([('x', '?'), ('y', 'u4'), ('z', 'f4')]) def assert_equal(actual, expected_data, expected_dtype): @@ -32,7 +41,7 @@ def test_format_descriptors(): @pytest.requires_numpy -def test_dtype(): +def test_dtype(simple_dtype): from pybind11_tests import print_dtypes, test_dtype_ctors, test_dtype_methods 
assert print_dtypes() == [ @@ -57,7 +66,7 @@ def test_dtype(): @pytest.requires_numpy -def test_recarray(): +def test_recarray(simple_dtype, packed_dtype): from pybind11_tests import (create_rec_simple, create_rec_packed, create_rec_nested, print_rec_simple, print_rec_packed, print_rec_nested, create_rec_partial, create_rec_partial_nested) From 2dbf0297050533e697d065f4b92283e0ecd37581 Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 31 Oct 2016 14:11:10 +0000 Subject: [PATCH 3/5] Add public shared_data API NumPy internals are stored under "_numpy_internals" key. --- include/pybind11/common.h | 31 ++++++++++++++++++++++++++++++- include/pybind11/numpy.h | 14 +++++--------- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/include/pybind11/common.h b/include/pybind11/common.h index 27cd47beff..62198c3414 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -323,7 +323,7 @@ struct internals { std::unordered_set, overload_hash> inactive_overload_cache; std::unordered_map> direct_conversions; std::forward_list registered_exception_translators; - std::unordered_map shared_data; + std::unordered_map shared_data; // Custom data to be shared across extensions #if defined(WITH_THREAD) decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x PyInterpreterState *istate = nullptr; @@ -428,6 +428,35 @@ inline void ignore_unused(const int *) { } NAMESPACE_END(detail) +/// Returns a named pointer that is shared among all extension modules (using the same +/// pybind11 version) running in the current interpreter. Names starting with underscores +/// are reserved for internal usage. Returns `nullptr` if no matching entry was found. +inline PYBIND11_NOINLINE void* get_shared_data(const std::string& name) { + auto& internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + return it != internals.shared_data.end() ? 
it->second : nullptr; +} + +/// Set the shared data that can be later recovered by `get_shared_data()`. +inline PYBIND11_NOINLINE void *set_shared_data(const std::string& name, void *data) { + detail::get_internals().shared_data[name] = data; + return data; +} + +/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if +/// such entry exists. Otherwise, a new object of default-constructible type `T` is +/// added to the shared data under the given name and a reference to it is returned. +template T& get_or_create_shared_data(const std::string& name) { + auto& internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + T* ptr = (T*) (it != internals.shared_data.end() ? it->second : nullptr); + if (!ptr) { + ptr = new T(); + internals.shared_data[name] = ptr; + } + return *ptr; +} + /// Fetch and hold an error which was already set in Python class error_already_set : public std::runtime_error { public: diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 19bff63595..b180cb296b 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -91,18 +91,14 @@ struct numpy_internals { } }; -inline PYBIND11_NOINLINE numpy_internals* load_numpy_internals() { - auto& shared_data = detail::get_internals().shared_data; - auto it = shared_data.find("numpy_internals"); - if (it != shared_data.end()) - return (numpy_internals *)it->second; - auto ptr = new numpy_internals(); - shared_data["numpy_internals"] = ptr; - return ptr; +inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { + ptr = &get_or_create_shared_data("_numpy_internals"); } inline numpy_internals& get_numpy_internals() { - static numpy_internals* ptr = load_numpy_internals(); + static numpy_internals* ptr = nullptr; + if (!ptr) + load_numpy_internals(ptr); return *ptr; } From f95fda0eb20dc55b92c45249b1e004c34ab06f3f Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 31 Oct 2016 21:40:11 +0000 Subject: [PATCH 
4/5] Add docs re: shared data API --- docs/advanced/misc.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/advanced/misc.rst b/docs/advanced/misc.rst index 2968f8ac12..b0719065f1 100644 --- a/docs/advanced/misc.rst +++ b/docs/advanced/misc.rst @@ -149,6 +149,25 @@ accessed by multiple extension modules: ... }; +Note also that it is possible (although it would rarely be required) to share arbitrary +C++ objects between extension modules at runtime. Internal library data is shared +between modules using capsule machinery [#f6]_ which can also be utilized for +storing, modifying and accessing user-defined data. Note that an extension module +will "see" other extensions' data if and only if they were built with the same +pybind11 version. Consider the following example: + +.. code-block:: cpp + + auto data = (MyData *) py::get_shared_data("mydata"); + if (!data) + data = (MyData *) py::set_shared_data("mydata", new MyData(42)); + +If the above snippet was used in several separately compiled extension modules, +the first one to be imported would create a ``MyData`` instance and associate +a ``"mydata"`` key with a pointer to it. Extensions that are imported later +would then be able to access the data behind the same pointer. + +.. 
[#f6] https://docs.python.org/3/extending/extending.html#using-capsules Generating documentation using Sphinx From cc8ff16547fc841d9bac4e8a3624386d70c566e3 Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 31 Oct 2016 16:16:47 +0000 Subject: [PATCH 5/5] Move register_dtype() outside of the template (avoid code bloat if possible) --- include/pybind11/numpy.h | 119 ++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 53 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index b180cb296b..af465a17d9 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -81,14 +81,18 @@ struct numpy_type_info { struct numpy_internals { std::unordered_map registered_dtypes; - template numpy_type_info *get_type_info(bool throw_if_missing = true) { - auto it = registered_dtypes.find(std::type_index(typeid(T))); + numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { + auto it = registered_dtypes.find(std::type_index(tinfo)); if (it != registered_dtypes.end()) return &(it->second); if (throw_if_missing) - pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name()); + pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); return nullptr; } + + template numpy_type_info *get_type_info(bool throw_if_missing = true) { + return get_type_info(typeid(typename std::remove_cv::type), throw_if_missing); + } }; inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { @@ -686,6 +690,62 @@ struct field_descriptor { dtype descr; }; +inline PYBIND11_NOINLINE void register_structured_dtype( + const std::initializer_list& fields, + const std::type_info& tinfo, size_t itemsize, + bool (*direct_converter)(PyObject *, void *&)) +{ + auto& numpy_internals = get_numpy_internals(); + if (numpy_internals.get_type_info(tinfo, false)) + pybind11_fail("NumPy: dtype is already registered"); + + list names, formats, offsets; + for (auto field : fields) { + 
if (!field.descr) + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + + field.name + "` @ " + tinfo.name()); + names.append(PYBIND11_STR_TYPE(field.name)); + formats.append(field.descr); + offsets.append(pybind11::int_(field.offset)); + } + auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); + + // There is an existing bug in NumPy (as of v1.11): trailing bytes are + // not encoded explicitly into the format string. This will supposedly + // get fixed in v1.12; for further details, see these: + // - https://github.com/numpy/numpy/issues/7797 + // - https://github.com/numpy/numpy/pull/7798 + // Because of this, we won't use numpy's logic to generate buffer format + // strings and will just do it ourselves. + std::vector ordered_fields(fields); + std::sort(ordered_fields.begin(), ordered_fields.end(), + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); + size_t offset = 0; + std::ostringstream oss; + oss << "T{"; + for (auto& field : ordered_fields) { + if (field.offset > offset) + oss << (field.offset - offset) << 'x'; + // note that '=' is required to cover the case of unaligned fields + oss << '=' << field.format << ':' << field.name << ':'; + offset = field.offset + field.size; + } + if (itemsize > offset) + oss << (itemsize - offset) << 'x'; + oss << '}'; + auto format_str = oss.str(); + + // Sanity check: verify that NumPy properly parses our buffer format string + auto& api = npy_api::get(); + auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); + if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) + pybind11_fail("NumPy: invalid buffer descriptor!"); + + auto tindex = std::type_index(tinfo); + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; + get_internals().direct_conversions[tindex].push_back(direct_converter); +} + template struct npy_format_descriptor::value>> { static PYBIND11_DESCR name() { return _("struct"); } @@ -699,56 
+759,9 @@ struct npy_format_descriptor::value>> { return format_str; } - static void register_dtype(std::initializer_list fields) { - auto& numpy_internals = get_numpy_internals(); - if (numpy_internals.get_type_info(false)) - pybind11_fail("NumPy: dtype is already registered"); - - list names, formats, offsets; - for (auto field : fields) { - if (!field.descr) - pybind11_fail(std::string("NumPy: unsupported field dtype: `") + - field.name + "` @ " + typeid(T).name()); - names.append(PYBIND11_STR_TYPE(field.name)); - formats.append(field.descr); - offsets.append(pybind11::int_(field.offset)); - } - auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); - - // There is an existing bug in NumPy (as of v1.11): trailing bytes are - // not encoded explicitly into the format string. This will supposedly - // get fixed in v1.12; for further details, see these: - // - https://github.com/numpy/numpy/issues/7797 - // - https://github.com/numpy/numpy/pull/7798 - // Because of this, we won't use numpy's logic to generate buffer format - // strings and will just do it ourselves. 
- std::vector ordered_fields(fields); - std::sort(ordered_fields.begin(), ordered_fields.end(), - [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); - size_t offset = 0; - std::ostringstream oss; - oss << "T{"; - for (auto& field : ordered_fields) { - if (field.offset > offset) - oss << (field.offset - offset) << 'x'; - // note that '=' is required to cover the case of unaligned fields - oss << '=' << field.format << ':' << field.name << ':'; - offset = field.offset + field.size; - } - if (sizeof(T) > offset) - oss << (sizeof(T) - offset) << 'x'; - oss << '}'; - auto format_str = oss.str(); - - // Sanity check: verify that NumPy properly parses our buffer format string - auto& api = npy_api::get(); - auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1)); - if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) - pybind11_fail("NumPy: invalid buffer descriptor!"); - - auto tindex = std::type_index(typeid(T)); - numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; - get_internals().direct_conversions[tindex].push_back(direct_converter); + static void register_dtype(const std::initializer_list& fields) { + register_structured_dtype(fields, typeid(typename std::remove_cv::type), + sizeof(T), &direct_converter); } private: