@@ -81,14 +81,18 @@ struct numpy_type_info {
81
81
struct numpy_internals {
82
82
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
83
83
84
- template < typename T> numpy_type_info *get_type_info (bool throw_if_missing = true ) {
85
- auto it = registered_dtypes.find (std::type_index (typeid (T) ));
84
+ numpy_type_info *get_type_info (const std::type_info& tinfo, bool throw_if_missing = true ) {
85
+ auto it = registered_dtypes.find (std::type_index (tinfo ));
86
86
if (it != registered_dtypes.end ())
87
87
return &(it->second );
88
88
if (throw_if_missing)
89
- pybind11_fail (std::string (" NumPy type info missing for " ) + typeid (T) .name ());
89
+ pybind11_fail (std::string (" NumPy type info missing for " ) + tinfo .name ());
90
90
return nullptr ;
91
91
}
92
+
93
+ template <typename T> numpy_type_info *get_type_info (bool throw_if_missing = true ) {
94
+ return get_type_info (typeid (typename std::remove_cv<T>::type), throw_if_missing);
95
+ }
92
96
};
93
97
94
98
inline PYBIND11_NOINLINE void load_numpy_internals (numpy_internals* &ptr) {
@@ -686,6 +690,62 @@ struct field_descriptor {
686
690
dtype descr;
687
691
};
688
692
693
+ inline PYBIND11_NOINLINE void register_structured_dtype (
694
+ const std::initializer_list<field_descriptor>& fields,
695
+ const std::type_info& tinfo, size_t itemsize,
696
+ bool (*direct_converter)(PyObject *, void *&))
697
+ {
698
+ auto & numpy_internals = get_numpy_internals ();
699
+ if (numpy_internals.get_type_info (tinfo, false ))
700
+ pybind11_fail (" NumPy: dtype is already registered" );
701
+
702
+ list names, formats, offsets;
703
+ for (auto field : fields) {
704
+ if (!field.descr )
705
+ pybind11_fail (std::string (" NumPy: unsupported field dtype: `" ) +
706
+ field.name + " ` @ " + tinfo.name ());
707
+ names.append (PYBIND11_STR_TYPE (field.name ));
708
+ formats.append (field.descr );
709
+ offsets.append (pybind11::int_ (field.offset ));
710
+ }
711
+ auto dtype_ptr = pybind11::dtype (names, formats, offsets, itemsize).release ().ptr ();
712
+
713
+ // There is an existing bug in NumPy (as of v1.11): trailing bytes are
714
+ // not encoded explicitly into the format string. This will supposedly
715
+ // get fixed in v1.12; for further details, see these:
716
+ // - https://github.com/numpy/numpy/issues/7797
717
+ // - https://github.com/numpy/numpy/pull/7798
718
+ // Because of this, we won't use numpy's logic to generate buffer format
719
+ // strings and will just do it ourselves.
720
+ std::vector<field_descriptor> ordered_fields (fields);
721
+ std::sort (ordered_fields.begin (), ordered_fields.end (),
722
+ [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset ; });
723
+ size_t offset = 0 ;
724
+ std::ostringstream oss;
725
+ oss << " T{" ;
726
+ for (auto & field : ordered_fields) {
727
+ if (field.offset > offset)
728
+ oss << (field.offset - offset) << ' x' ;
729
+ // note that '=' is required to cover the case of unaligned fields
730
+ oss << ' =' << field.format << ' :' << field.name << ' :' ;
731
+ offset = field.offset + field.size ;
732
+ }
733
+ if (itemsize > offset)
734
+ oss << (itemsize - offset) << ' x' ;
735
+ oss << ' }' ;
736
+ auto format_str = oss.str ();
737
+
738
+ // Sanity check: verify that NumPy properly parses our buffer format string
739
+ auto & api = npy_api::get ();
740
+ auto arr = array (buffer_info (nullptr , itemsize, format_str, 1 ));
741
+ if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
742
+ pybind11_fail (" NumPy: invalid buffer descriptor!" );
743
+
744
+ auto tindex = std::type_index (tinfo);
745
+ numpy_internals.registered_dtypes [tindex] = { dtype_ptr, format_str };
746
+ get_internals ().direct_conversions [tindex].push_back (direct_converter);
747
+ }
748
+
689
749
template <typename T>
690
750
struct npy_format_descriptor <T, enable_if_t <is_pod_struct<T>::value>> {
691
751
static PYBIND11_DESCR name () { return _ (" struct" ); }
@@ -699,56 +759,9 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
699
759
return format_str;
700
760
}
701
761
702
- static void register_dtype (std::initializer_list<field_descriptor> fields) {
703
- auto & numpy_internals = get_numpy_internals ();
704
- if (numpy_internals.get_type_info <T>(false ))
705
- pybind11_fail (" NumPy: dtype is already registered" );
706
-
707
- list names, formats, offsets;
708
- for (auto field : fields) {
709
- if (!field.descr )
710
- pybind11_fail (std::string (" NumPy: unsupported field dtype: `" ) +
711
- field.name + " ` @ " + typeid (T).name ());
712
- names.append (PYBIND11_STR_TYPE (field.name ));
713
- formats.append (field.descr );
714
- offsets.append (pybind11::int_ (field.offset ));
715
- }
716
- auto dtype_ptr = pybind11::dtype (names, formats, offsets, sizeof (T)).release ().ptr ();
717
-
718
- // There is an existing bug in NumPy (as of v1.11): trailing bytes are
719
- // not encoded explicitly into the format string. This will supposedly
720
- // get fixed in v1.12; for further details, see these:
721
- // - https://github.com/numpy/numpy/issues/7797
722
- // - https://github.com/numpy/numpy/pull/7798
723
- // Because of this, we won't use numpy's logic to generate buffer format
724
- // strings and will just do it ourselves.
725
- std::vector<field_descriptor> ordered_fields (fields);
726
- std::sort (ordered_fields.begin (), ordered_fields.end (),
727
- [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset ; });
728
- size_t offset = 0 ;
729
- std::ostringstream oss;
730
- oss << " T{" ;
731
- for (auto & field : ordered_fields) {
732
- if (field.offset > offset)
733
- oss << (field.offset - offset) << ' x' ;
734
- // note that '=' is required to cover the case of unaligned fields
735
- oss << ' =' << field.format << ' :' << field.name << ' :' ;
736
- offset = field.offset + field.size ;
737
- }
738
- if (sizeof (T) > offset)
739
- oss << (sizeof (T) - offset) << ' x' ;
740
- oss << ' }' ;
741
- auto format_str = oss.str ();
742
-
743
- // Sanity check: verify that NumPy properly parses our buffer format string
744
- auto & api = npy_api::get ();
745
- auto arr = array (buffer_info (nullptr , sizeof (T), format_str, 1 ));
746
- if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
747
- pybind11_fail (" NumPy: invalid buffer descriptor!" );
748
-
749
- auto tindex = std::type_index (typeid (T));
750
- numpy_internals.registered_dtypes [tindex] = { dtype_ptr, format_str };
751
- get_internals ().direct_conversions [tindex].push_back (direct_converter);
762
+ static void register_dtype (const std::initializer_list<field_descriptor>& fields) {
763
+ register_structured_dtype (fields, typeid (typename std::remove_cv<T>::type),
764
+ sizeof (T), &direct_converter);
752
765
}
753
766
754
767
private:
0 commit comments