|
21 | 21 | #include <initializer_list>
|
22 | 22 | #include <functional>
|
23 | 23 | #include <utility>
|
| 24 | +#include <typeindex> |
24 | 25 |
|
25 | 26 | #if defined(_MSC_VER)
|
26 | 27 | # pragma warning(push)
|
@@ -72,6 +73,39 @@ struct PyVoidScalarObject_Proxy {
|
72 | 73 | PyObject *base;
|
73 | 74 | };
|
74 | 75 |
|
| 76 | +struct numpy_type_info { |
| 77 | + PyObject* dtype_ptr; |
| 78 | + std::string format_str; |
| 79 | +}; |
| 80 | + |
| 81 | +struct numpy_internals { |
| 82 | + std::unordered_map<std::type_index, numpy_type_info> registered_dtypes; |
| 83 | + |
| 84 | + numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { |
| 85 | + auto it = registered_dtypes.find(std::type_index(tinfo)); |
| 86 | + if (it != registered_dtypes.end()) |
| 87 | + return &(it->second); |
| 88 | + if (throw_if_missing) |
| 89 | + pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); |
| 90 | + return nullptr; |
| 91 | + } |
| 92 | + |
| 93 | + template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) { |
| 94 | + return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing); |
| 95 | + } |
| 96 | +}; |
| 97 | + |
| 98 | +inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { |
| 99 | + ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals"); |
| 100 | +} |
| 101 | + |
| 102 | +inline numpy_internals& get_numpy_internals() { |
| 103 | + static numpy_internals* ptr = nullptr; |
| 104 | + if (!ptr) |
| 105 | + load_numpy_internals(ptr); |
| 106 | + return *ptr; |
| 107 | +} |
| 108 | + |
75 | 109 | struct npy_api {
|
76 | 110 | enum constants {
|
77 | 111 | NPY_C_CONTIGUOUS_ = 0x0001,
|
@@ -656,99 +690,100 @@ struct field_descriptor {
|
656 | 690 | dtype descr;
|
657 | 691 | };
|
658 | 692 |
|
| 693 | +inline PYBIND11_NOINLINE void register_structured_dtype( |
| 694 | + const std::initializer_list<field_descriptor>& fields, |
| 695 | + const std::type_info& tinfo, size_t itemsize, |
| 696 | + bool (*direct_converter)(PyObject *, void *&)) |
| 697 | +{ |
| 698 | + auto& numpy_internals = get_numpy_internals(); |
| 699 | + if (numpy_internals.get_type_info(tinfo, false)) |
| 700 | + pybind11_fail("NumPy: dtype is already registered"); |
| 701 | + |
| 702 | + list names, formats, offsets; |
| 703 | + for (auto field : fields) { |
| 704 | + if (!field.descr) |
| 705 | + pybind11_fail(std::string("NumPy: unsupported field dtype: `") + |
| 706 | + field.name + "` @ " + tinfo.name()); |
| 707 | + names.append(PYBIND11_STR_TYPE(field.name)); |
| 708 | + formats.append(field.descr); |
| 709 | + offsets.append(pybind11::int_(field.offset)); |
| 710 | + } |
| 711 | + auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); |
| 712 | + |
| 713 | + // There is an existing bug in NumPy (as of v1.11): trailing bytes are |
| 714 | + // not encoded explicitly into the format string. This will supposedly |
| 715 | + // get fixed in v1.12; for further details, see these: |
| 716 | + // - https://github.com/numpy/numpy/issues/7797 |
| 717 | + // - https://github.com/numpy/numpy/pull/7798 |
| 718 | + // Because of this, we won't use numpy's logic to generate buffer format |
| 719 | + // strings and will just do it ourselves. |
| 720 | + std::vector<field_descriptor> ordered_fields(fields); |
| 721 | + std::sort(ordered_fields.begin(), ordered_fields.end(), |
| 722 | + [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); |
| 723 | + size_t offset = 0; |
| 724 | + std::ostringstream oss; |
| 725 | + oss << "T{"; |
| 726 | + for (auto& field : ordered_fields) { |
| 727 | + if (field.offset > offset) |
| 728 | + oss << (field.offset - offset) << 'x'; |
| 729 | + // note that '=' is required to cover the case of unaligned fields |
| 730 | + oss << '=' << field.format << ':' << field.name << ':'; |
| 731 | + offset = field.offset + field.size; |
| 732 | + } |
| 733 | + if (itemsize > offset) |
| 734 | + oss << (itemsize - offset) << 'x'; |
| 735 | + oss << '}'; |
| 736 | + auto format_str = oss.str(); |
| 737 | + |
| 738 | + // Sanity check: verify that NumPy properly parses our buffer format string |
| 739 | + auto& api = npy_api::get(); |
| 740 | + auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); |
| 741 | + if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) |
| 742 | + pybind11_fail("NumPy: invalid buffer descriptor!"); |
| 743 | + |
| 744 | + auto tindex = std::type_index(tinfo); |
| 745 | + numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; |
| 746 | + get_internals().direct_conversions[tindex].push_back(direct_converter); |
| 747 | +} |
| 748 | + |
659 | 749 | template <typename T>
|
660 | 750 | struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
|
661 | 751 | static PYBIND11_DESCR name() { return _("struct"); }
|
662 | 752 |
|
663 | 753 | static pybind11::dtype dtype() {
|
664 |
| - if (!dtype_ptr) |
665 |
| - pybind11_fail("NumPy: unsupported buffer format!"); |
666 |
| - return object(dtype_ptr, true); |
| 754 | + return object(dtype_ptr(), true); |
667 | 755 | }
|
668 | 756 |
|
669 | 757 | static std::string format() {
|
670 |
| - if (!dtype_ptr) |
671 |
| - pybind11_fail("NumPy: unsupported buffer format!"); |
| 758 | + static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str; |
672 | 759 | return format_str;
|
673 | 760 | }
|
674 | 761 |
|
675 |
| - static void register_dtype(std::initializer_list<field_descriptor> fields) { |
676 |
| - if (dtype_ptr) |
677 |
| - pybind11_fail("NumPy: dtype is already registered"); |
678 |
| - |
679 |
| - list names, formats, offsets; |
680 |
| - for (auto field : fields) { |
681 |
| - if (!field.descr) |
682 |
| - pybind11_fail("NumPy: unsupported field dtype"); |
683 |
| - names.append(PYBIND11_STR_TYPE(field.name)); |
684 |
| - formats.append(field.descr); |
685 |
| - offsets.append(pybind11::int_(field.offset)); |
686 |
| - } |
687 |
| - dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr(); |
688 |
| - |
689 |
| - // There is an existing bug in NumPy (as of v1.11): trailing bytes are |
690 |
| - // not encoded explicitly into the format string. This will supposedly |
691 |
| - // get fixed in v1.12; for further details, see these: |
692 |
| - // - https://github.com/numpy/numpy/issues/7797 |
693 |
| - // - https://github.com/numpy/numpy/pull/7798 |
694 |
| - // Because of this, we won't use numpy's logic to generate buffer format |
695 |
| - // strings and will just do it ourselves. |
696 |
| - std::vector<field_descriptor> ordered_fields(fields); |
697 |
| - std::sort(ordered_fields.begin(), ordered_fields.end(), |
698 |
| - [](const field_descriptor &a, const field_descriptor &b) { |
699 |
| - return a.offset < b.offset; |
700 |
| - }); |
701 |
| - size_t offset = 0; |
702 |
| - std::ostringstream oss; |
703 |
| - oss << "T{"; |
704 |
| - for (auto& field : ordered_fields) { |
705 |
| - if (field.offset > offset) |
706 |
| - oss << (field.offset - offset) << 'x'; |
707 |
| - // note that '=' is required to cover the case of unaligned fields |
708 |
| - oss << '=' << field.format << ':' << field.name << ':'; |
709 |
| - offset = field.offset + field.size; |
710 |
| - } |
711 |
| - if (sizeof(T) > offset) |
712 |
| - oss << (sizeof(T) - offset) << 'x'; |
713 |
| - oss << '}'; |
714 |
| - format_str = oss.str(); |
715 |
| - |
716 |
| - // Sanity check: verify that NumPy properly parses our buffer format string |
717 |
| - auto& api = npy_api::get(); |
718 |
| - auto arr = array(buffer_info(nullptr, sizeof(T), format(), 1)); |
719 |
| - if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) |
720 |
| - pybind11_fail("NumPy: invalid buffer descriptor!"); |
721 |
| - |
722 |
| - register_direct_converter(); |
| 762 | + static void register_dtype(const std::initializer_list<field_descriptor>& fields) { |
| 763 | + register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type), |
| 764 | + sizeof(T), &direct_converter); |
723 | 765 | }
|
724 | 766 |
|
725 | 767 | private:
|
726 |
| - static std::string format_str; |
727 |
| - static PyObject* dtype_ptr; |
| 768 | + static PyObject* dtype_ptr() { |
| 769 | + static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr; |
| 770 | + return ptr; |
| 771 | + } |
728 | 772 |
|
729 | 773 | static bool direct_converter(PyObject *obj, void*& value) {
|
730 | 774 | auto& api = npy_api::get();
|
731 | 775 | if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
|
732 | 776 | return false;
|
733 | 777 | if (auto descr = object(api.PyArray_DescrFromScalar_(obj), false)) {
|
734 |
| - if (api.PyArray_EquivTypes_(dtype_ptr, descr.ptr())) { |
| 778 | + if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { |
735 | 779 | value = ((PyVoidScalarObject_Proxy *) obj)->obval;
|
736 | 780 | return true;
|
737 | 781 | }
|
738 | 782 | }
|
739 | 783 | return false;
|
740 | 784 | }
|
741 |
| - |
742 |
| - static void register_direct_converter() { |
743 |
| - get_internals().direct_conversions[std::type_index(typeid(T))].push_back(direct_converter); |
744 |
| - } |
745 | 785 | };
|
746 | 786 |
|
747 |
| -template <typename T> |
748 |
| -std::string npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::format_str; |
749 |
| -template <typename T> |
750 |
| -PyObject* npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>>::dtype_ptr = nullptr; |
751 |
| - |
752 | 787 | #define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
|
753 | 788 | ::pybind11::detail::field_descriptor { \
|
754 | 789 | Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
|
|
0 commit comments