Skip to content

Commit 0a53644

Browse files
committed
Move register_dtype() outside of the template
(avoid code bloat if possible)
1 parent 2231760 commit 0a53644

File tree

1 file changed

+65
-53
lines changed

1 file changed

+65
-53
lines changed

include/pybind11/numpy.h

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,18 @@ struct numpy_type_info {
8080
struct numpy_internals {
8181
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
8282

83-
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
84-
auto it = registered_dtypes.find(std::type_index(typeid(T)));
83+
numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
84+
auto it = registered_dtypes.find(std::type_index(tinfo));
8585
if (it != registered_dtypes.end())
8686
return &(it->second);
8787
if (throw_if_missing)
88-
pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name());
88+
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
8989
return nullptr;
9090
}
91+
92+
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
93+
return get_type_info(typeid(T), throw_if_missing);
94+
}
9195
};
9296

9397
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
@@ -685,6 +689,62 @@ struct field_descriptor {
685689
dtype descr;
686690
};
687691

692+
template<typename F>
693+
static PYBIND11_NOINLINE void register_structured_dtype(
694+
const F& fields, const std::type_info& tinfo, size_t itemsize,
695+
bool (*direct_converter)(PyObject *, void *&))
696+
{
697+
auto& numpy_internals = get_numpy_internals();
698+
if (numpy_internals.get_type_info(tinfo, false))
699+
pybind11_fail("NumPy: dtype is already registered");
700+
701+
list names, formats, offsets;
702+
for (auto field : fields) {
703+
if (!field.descr)
704+
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
705+
field.name + "` @ " + tinfo.name());
706+
names.append(PYBIND11_STR_TYPE(field.name));
707+
formats.append(field.descr);
708+
offsets.append(pybind11::int_(field.offset));
709+
}
710+
auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
711+
712+
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
713+
// not encoded explicitly into the format string. This will supposedly
714+
// get fixed in v1.12; for further details, see these:
715+
// - https://github.com/numpy/numpy/issues/7797
716+
// - https://github.com/numpy/numpy/pull/7798
717+
// Because of this, we won't use numpy's logic to generate buffer format
718+
// strings and will just do it ourselves.
719+
std::vector<field_descriptor> ordered_fields(fields);
720+
std::sort(ordered_fields.begin(), ordered_fields.end(),
721+
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
722+
size_t offset = 0;
723+
std::ostringstream oss;
724+
oss << "T{";
725+
for (auto& field : ordered_fields) {
726+
if (field.offset > offset)
727+
oss << (field.offset - offset) << 'x';
728+
// note that '=' is required to cover the case of unaligned fields
729+
oss << '=' << field.format << ':' << field.name << ':';
730+
offset = field.offset + field.size;
731+
}
732+
if (itemsize > offset)
733+
oss << (itemsize - offset) << 'x';
734+
oss << '}';
735+
auto format_str = oss.str();
736+
737+
// Sanity check: verify that NumPy properly parses our buffer format string
738+
auto& api = npy_api::get();
739+
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
740+
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
741+
pybind11_fail("NumPy: invalid buffer descriptor!");
742+
743+
auto tindex = std::type_index(tinfo);
744+
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
745+
get_internals().direct_conversions[tindex].push_back(direct_converter);
746+
}
747+
688748
template <typename T>
689749
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
690750
static PYBIND11_DESCR name() { return _("struct"); }
@@ -698,56 +758,8 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
698758
return format_str;
699759
}
700760

701-
static void register_dtype(std::initializer_list<field_descriptor> fields) {
702-
auto& numpy_internals = get_numpy_internals();
703-
if (numpy_internals.get_type_info<T>(false))
704-
pybind11_fail("NumPy: dtype is already registered");
705-
706-
list names, formats, offsets;
707-
for (auto field : fields) {
708-
if (!field.descr)
709-
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
710-
field.name + "` @ " + typeid(T).name());
711-
names.append(PYBIND11_STR_TYPE(field.name));
712-
formats.append(field.descr);
713-
offsets.append(pybind11::int_(field.offset));
714-
}
715-
auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
716-
717-
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
718-
// not encoded explicitly into the format string. This will supposedly
719-
// get fixed in v1.12; for further details, see these:
720-
// - https://github.com/numpy/numpy/issues/7797
721-
// - https://github.com/numpy/numpy/pull/7798
722-
// Because of this, we won't use numpy's logic to generate buffer format
723-
// strings and will just do it ourselves.
724-
std::vector<field_descriptor> ordered_fields(fields);
725-
std::sort(ordered_fields.begin(), ordered_fields.end(),
726-
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
727-
size_t offset = 0;
728-
std::ostringstream oss;
729-
oss << "T{";
730-
for (auto& field : ordered_fields) {
731-
if (field.offset > offset)
732-
oss << (field.offset - offset) << 'x';
733-
// note that '=' is required to cover the case of unaligned fields
734-
oss << '=' << field.format << ':' << field.name << ':';
735-
offset = field.offset + field.size;
736-
}
737-
if (sizeof(T) > offset)
738-
oss << (sizeof(T) - offset) << 'x';
739-
oss << '}';
740-
auto format_str = oss.str();
741-
742-
// Sanity check: verify that NumPy properly parses our buffer format string
743-
auto& api = npy_api::get();
744-
auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1));
745-
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
746-
pybind11_fail("NumPy: invalid buffer descriptor!");
747-
748-
auto tindex = std::type_index(typeid(T));
749-
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
750-
get_internals().direct_conversions[tindex].push_back(direct_converter);
761+
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
762+
register_structured_dtype(fields, typeid(T), sizeof(T), &direct_converter);
751763
}
752764

753765
private:

0 commit comments

Comments
 (0)