@@ -80,14 +80,18 @@ struct numpy_type_info {
80
80
struct numpy_internals {
81
81
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
82
82
83
- template < typename T> numpy_type_info *get_type_info (bool throw_if_missing = true ) {
84
- auto it = registered_dtypes.find (std::type_index (typeid (T) ));
83
+ numpy_type_info *get_type_info (const std::type_info& tinfo, bool throw_if_missing = true ) {
84
+ auto it = registered_dtypes.find (std::type_index (tinfo ));
85
85
if (it != registered_dtypes.end ())
86
86
return &(it->second );
87
87
if (throw_if_missing)
88
- pybind11_fail (std::string (" NumPy type info missing for " ) + typeid (T) .name ());
88
+ pybind11_fail (std::string (" NumPy type info missing for " ) + tinfo .name ());
89
89
return nullptr ;
90
90
}
91
+
92
+ template <typename T> numpy_type_info *get_type_info (bool throw_if_missing = true ) {
93
+ return get_type_info (typeid (T), throw_if_missing);
94
+ }
91
95
};
92
96
93
97
inline PYBIND11_NOINLINE numpy_internals& load_numpy_internals () {
@@ -683,6 +687,62 @@ struct field_descriptor {
683
687
dtype descr;
684
688
};
685
689
690
+ template <typename F>
691
+ static PYBIND11_NOINLINE void register_structured_dtype (
692
+ const F& fields, const std::type_info& tinfo, size_t itemsize,
693
+ bool (*direct_converter)(PyObject *, void *&))
694
+ {
695
+ auto & numpy_internals = get_numpy_internals ();
696
+ if (numpy_internals.get_type_info (tinfo, false ))
697
+ pybind11_fail (" NumPy: dtype is already registered" );
698
+
699
+ list names, formats, offsets;
700
+ for (auto field : fields) {
701
+ if (!field.descr )
702
+ pybind11_fail (std::string (" NumPy: unsupported field dtype: `" ) +
703
+ field.name + " ` @ " + tinfo.name ());
704
+ names.append (PYBIND11_STR_TYPE (field.name ));
705
+ formats.append (field.descr );
706
+ offsets.append (pybind11::int_ (field.offset ));
707
+ }
708
+ auto dtype_ptr = pybind11::dtype (names, formats, offsets, itemsize).release ().ptr ();
709
+
710
+ // There is an existing bug in NumPy (as of v1.11): trailing bytes are
711
+ // not encoded explicitly into the format string. This will supposedly
712
+ // get fixed in v1.12; for further details, see these:
713
+ // - https://github.com/numpy/numpy/issues/7797
714
+ // - https://github.com/numpy/numpy/pull/7798
715
+ // Because of this, we won't use numpy's logic to generate buffer format
716
+ // strings and will just do it ourselves.
717
+ std::vector<field_descriptor> ordered_fields (fields);
718
+ std::sort (ordered_fields.begin (), ordered_fields.end (),
719
+ [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset ; });
720
+ size_t offset = 0 ;
721
+ std::ostringstream oss;
722
+ oss << " T{" ;
723
+ for (auto & field : ordered_fields) {
724
+ if (field.offset > offset)
725
+ oss << (field.offset - offset) << ' x' ;
726
+ // note that '=' is required to cover the case of unaligned fields
727
+ oss << ' =' << field.format << ' :' << field.name << ' :' ;
728
+ offset = field.offset + field.size ;
729
+ }
730
+ if (itemsize > offset)
731
+ oss << (itemsize - offset) << ' x' ;
732
+ oss << ' }' ;
733
+ auto format_str = oss.str ();
734
+
735
+ // Sanity check: verify that NumPy properly parses our buffer format string
736
+ auto & api = npy_api::get ();
737
+ auto arr = array (buffer_info (nullptr , itemsize, format_str, 1 ));
738
+ if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
739
+ pybind11_fail (" NumPy: invalid buffer descriptor!" );
740
+
741
+ auto tindex = std::type_index (tinfo);
742
+ numpy_internals.registered_dtypes [tindex] = { dtype_ptr, format_str };
743
+ get_internals ().direct_conversions [tindex].push_back (direct_converter);
744
+ }
745
+
686
746
template <typename T>
687
747
struct npy_format_descriptor <T, enable_if_t <is_pod_struct<T>::value>> {
688
748
static PYBIND11_DESCR name () { return _ (" struct" ); }
@@ -696,56 +756,8 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
696
756
return format_str;
697
757
}
698
758
699
- static void register_dtype (std::initializer_list<field_descriptor> fields) {
700
- auto & numpy_internals = get_numpy_internals ();
701
- if (numpy_internals.get_type_info <T>(false ))
702
- pybind11_fail (" NumPy: dtype is already registered" );
703
-
704
- list names, formats, offsets;
705
- for (auto field : fields) {
706
- if (!field.descr )
707
- pybind11_fail (std::string (" NumPy: unsupported field dtype: `" ) +
708
- field.name + " ` @ " + typeid (T).name ());
709
- names.append (PYBIND11_STR_TYPE (field.name ));
710
- formats.append (field.descr );
711
- offsets.append (pybind11::int_ (field.offset ));
712
- }
713
- auto dtype_ptr = pybind11::dtype (names, formats, offsets, sizeof (T)).release ().ptr ();
714
-
715
- // There is an existing bug in NumPy (as of v1.11): trailing bytes are
716
- // not encoded explicitly into the format string. This will supposedly
717
- // get fixed in v1.12; for further details, see these:
718
- // - https://github.com/numpy/numpy/issues/7797
719
- // - https://github.com/numpy/numpy/pull/7798
720
- // Because of this, we won't use numpy's logic to generate buffer format
721
- // strings and will just do it ourselves.
722
- std::vector<field_descriptor> ordered_fields (fields);
723
- std::sort (ordered_fields.begin (), ordered_fields.end (),
724
- [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset ; });
725
- size_t offset = 0 ;
726
- std::ostringstream oss;
727
- oss << " T{" ;
728
- for (auto & field : ordered_fields) {
729
- if (field.offset > offset)
730
- oss << (field.offset - offset) << ' x' ;
731
- // note that '=' is required to cover the case of unaligned fields
732
- oss << ' =' << field.format << ' :' << field.name << ' :' ;
733
- offset = field.offset + field.size ;
734
- }
735
- if (sizeof (T) > offset)
736
- oss << (sizeof (T) - offset) << ' x' ;
737
- oss << ' }' ;
738
- auto format_str = oss.str ();
739
-
740
- // Sanity check: verify that NumPy properly parses our buffer format string
741
- auto & api = npy_api::get ();
742
- auto arr = array (buffer_info (nullptr , sizeof (T), format_str, 1 ));
743
- if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
744
- pybind11_fail (" NumPy: invalid buffer descriptor!" );
745
-
746
- auto tindex = std::type_index (typeid (T));
747
- numpy_internals.registered_dtypes [tindex] = { dtype_ptr, format_str };
748
- get_internals ().direct_conversions [tindex].push_back (direct_converter);
759
+ static void register_dtype (const std::initializer_list<field_descriptor>& fields) {
760
+ register_structured_dtype (fields, typeid (T), sizeof (T), &direct_converter);
749
761
}
750
762
751
763
private:
0 commit comments