@@ -80,14 +80,18 @@ struct numpy_type_info {
80
80
struct numpy_internals {
81
81
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
82
82
83
- template < typename T> numpy_type_info *get_type_info (bool throw_if_missing = true ) {
84
- auto it = registered_dtypes.find (std::type_index (typeid (T) ));
83
+ numpy_type_info *get_type_info (const std::type_info& tinfo, bool throw_if_missing = true ) {
84
+ auto it = registered_dtypes.find (std::type_index (tinfo ));
85
85
if (it != registered_dtypes.end ())
86
86
return &(it->second );
87
87
if (throw_if_missing)
88
- pybind11_fail (std::string (" NumPy type info missing for " ) + typeid (T) .name ());
88
+ pybind11_fail (std::string (" NumPy type info missing for " ) + tinfo .name ());
89
89
return nullptr ;
90
90
}
91
+
92
+ template <typename T> numpy_type_info *get_type_info (bool throw_if_missing = true ) {
93
+ return get_type_info (typeid (T), throw_if_missing);
94
+ }
91
95
};
92
96
93
97
inline PYBIND11_NOINLINE void load_numpy_internals (numpy_internals* &ptr) {
@@ -685,6 +689,62 @@ struct field_descriptor {
685
689
dtype descr;
686
690
};
687
691
692
+ template <typename F>
693
+ static PYBIND11_NOINLINE void register_structured_dtype (
694
+ const F& fields, const std::type_info& tinfo, size_t itemsize,
695
+ bool (*direct_converter)(PyObject *, void *&))
696
+ {
697
+ auto & numpy_internals = get_numpy_internals ();
698
+ if (numpy_internals.get_type_info (tinfo, false ))
699
+ pybind11_fail (" NumPy: dtype is already registered" );
700
+
701
+ list names, formats, offsets;
702
+ for (auto field : fields) {
703
+ if (!field.descr )
704
+ pybind11_fail (std::string (" NumPy: unsupported field dtype: `" ) +
705
+ field.name + " ` @ " + tinfo.name ());
706
+ names.append (PYBIND11_STR_TYPE (field.name ));
707
+ formats.append (field.descr );
708
+ offsets.append (pybind11::int_ (field.offset ));
709
+ }
710
+ auto dtype_ptr = pybind11::dtype (names, formats, offsets, itemsize).release ().ptr ();
711
+
712
+ // There is an existing bug in NumPy (as of v1.11): trailing bytes are
713
+ // not encoded explicitly into the format string. This will supposedly
714
+ // get fixed in v1.12; for further details, see these:
715
+ // - https://github.com/numpy/numpy/issues/7797
716
+ // - https://github.com/numpy/numpy/pull/7798
717
+ // Because of this, we won't use numpy's logic to generate buffer format
718
+ // strings and will just do it ourselves.
719
+ std::vector<field_descriptor> ordered_fields (fields);
720
+ std::sort (ordered_fields.begin (), ordered_fields.end (),
721
+ [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset ; });
722
+ size_t offset = 0 ;
723
+ std::ostringstream oss;
724
+ oss << " T{" ;
725
+ for (auto & field : ordered_fields) {
726
+ if (field.offset > offset)
727
+ oss << (field.offset - offset) << ' x' ;
728
+ // note that '=' is required to cover the case of unaligned fields
729
+ oss << ' =' << field.format << ' :' << field.name << ' :' ;
730
+ offset = field.offset + field.size ;
731
+ }
732
+ if (itemsize > offset)
733
+ oss << (itemsize - offset) << ' x' ;
734
+ oss << ' }' ;
735
+ auto format_str = oss.str ();
736
+
737
+ // Sanity check: verify that NumPy properly parses our buffer format string
738
+ auto & api = npy_api::get ();
739
+ auto arr = array (buffer_info (nullptr , itemsize, format_str, 1 ));
740
+ if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
741
+ pybind11_fail (" NumPy: invalid buffer descriptor!" );
742
+
743
+ auto tindex = std::type_index (tinfo);
744
+ numpy_internals.registered_dtypes [tindex] = { dtype_ptr, format_str };
745
+ get_internals ().direct_conversions [tindex].push_back (direct_converter);
746
+ }
747
+
688
748
template <typename T>
689
749
struct npy_format_descriptor <T, enable_if_t <is_pod_struct<T>::value>> {
690
750
static PYBIND11_DESCR name () { return _ (" struct" ); }
@@ -698,56 +758,8 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
698
758
return format_str;
699
759
}
700
760
701
- static void register_dtype (std::initializer_list<field_descriptor> fields) {
702
- auto & numpy_internals = get_numpy_internals ();
703
- if (numpy_internals.get_type_info <T>(false ))
704
- pybind11_fail (" NumPy: dtype is already registered" );
705
-
706
- list names, formats, offsets;
707
- for (auto field : fields) {
708
- if (!field.descr )
709
- pybind11_fail (std::string (" NumPy: unsupported field dtype: `" ) +
710
- field.name + " ` @ " + typeid (T).name ());
711
- names.append (PYBIND11_STR_TYPE (field.name ));
712
- formats.append (field.descr );
713
- offsets.append (pybind11::int_ (field.offset ));
714
- }
715
- auto dtype_ptr = pybind11::dtype (names, formats, offsets, sizeof (T)).release ().ptr ();
716
-
717
- // There is an existing bug in NumPy (as of v1.11): trailing bytes are
718
- // not encoded explicitly into the format string. This will supposedly
719
- // get fixed in v1.12; for further details, see these:
720
- // - https://github.com/numpy/numpy/issues/7797
721
- // - https://github.com/numpy/numpy/pull/7798
722
- // Because of this, we won't use numpy's logic to generate buffer format
723
- // strings and will just do it ourselves.
724
- std::vector<field_descriptor> ordered_fields (fields);
725
- std::sort (ordered_fields.begin (), ordered_fields.end (),
726
- [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset ; });
727
- size_t offset = 0 ;
728
- std::ostringstream oss;
729
- oss << " T{" ;
730
- for (auto & field : ordered_fields) {
731
- if (field.offset > offset)
732
- oss << (field.offset - offset) << ' x' ;
733
- // note that '=' is required to cover the case of unaligned fields
734
- oss << ' =' << field.format << ' :' << field.name << ' :' ;
735
- offset = field.offset + field.size ;
736
- }
737
- if (sizeof (T) > offset)
738
- oss << (sizeof (T) - offset) << ' x' ;
739
- oss << ' }' ;
740
- auto format_str = oss.str ();
741
-
742
- // Sanity check: verify that NumPy properly parses our buffer format string
743
- auto & api = npy_api::get ();
744
- auto arr = array (buffer_info (nullptr , sizeof (T), format_str, 1 ));
745
- if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
746
- pybind11_fail (" NumPy: invalid buffer descriptor!" );
747
-
748
- auto tindex = std::type_index (typeid (T));
749
- numpy_internals.registered_dtypes [tindex] = { dtype_ptr, format_str };
750
- get_internals ().direct_conversions [tindex].push_back (direct_converter);
761
+ static void register_dtype (const std::initializer_list<field_descriptor>& fields) {
762
+ register_structured_dtype (fields, typeid (T), sizeof (T), &direct_converter);
751
763
}
752
764
753
765
private:
0 commit comments