Skip to content

Commit cc8ff16

Browse files
committed
Move register_dtype() outside of the template
(avoid code bloat if possible)
1 parent f95fda0 commit cc8ff16

File tree

1 file changed

+66
-53
lines changed

1 file changed

+66
-53
lines changed

include/pybind11/numpy.h

Lines changed: 66 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,18 @@ struct numpy_type_info {
8181
struct numpy_internals {
8282
std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
8383

84-
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
85-
auto it = registered_dtypes.find(std::type_index(typeid(T)));
84+
numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
85+
auto it = registered_dtypes.find(std::type_index(tinfo));
8686
if (it != registered_dtypes.end())
8787
return &(it->second);
8888
if (throw_if_missing)
89-
pybind11_fail(std::string("NumPy type info missing for ") + typeid(T).name());
89+
pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
9090
return nullptr;
9191
}
92+
93+
template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
94+
return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
95+
}
9296
};
9397

9498
inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
@@ -686,6 +690,62 @@ struct field_descriptor {
686690
dtype descr;
687691
};
688692

693+
inline PYBIND11_NOINLINE void register_structured_dtype(
694+
const std::initializer_list<field_descriptor>& fields,
695+
const std::type_info& tinfo, size_t itemsize,
696+
bool (*direct_converter)(PyObject *, void *&))
697+
{
698+
auto& numpy_internals = get_numpy_internals();
699+
if (numpy_internals.get_type_info(tinfo, false))
700+
pybind11_fail("NumPy: dtype is already registered");
701+
702+
list names, formats, offsets;
703+
for (auto field : fields) {
704+
if (!field.descr)
705+
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
706+
field.name + "` @ " + tinfo.name());
707+
names.append(PYBIND11_STR_TYPE(field.name));
708+
formats.append(field.descr);
709+
offsets.append(pybind11::int_(field.offset));
710+
}
711+
auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
712+
713+
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
714+
// not encoded explicitly into the format string. This will supposedly
715+
// get fixed in v1.12; for further details, see these:
716+
// - https://github.com/numpy/numpy/issues/7797
717+
// - https://github.com/numpy/numpy/pull/7798
718+
// Because of this, we won't use numpy's logic to generate buffer format
719+
// strings and will just do it ourselves.
720+
std::vector<field_descriptor> ordered_fields(fields);
721+
std::sort(ordered_fields.begin(), ordered_fields.end(),
722+
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
723+
size_t offset = 0;
724+
std::ostringstream oss;
725+
oss << "T{";
726+
for (auto& field : ordered_fields) {
727+
if (field.offset > offset)
728+
oss << (field.offset - offset) << 'x';
729+
// note that '=' is required to cover the case of unaligned fields
730+
oss << '=' << field.format << ':' << field.name << ':';
731+
offset = field.offset + field.size;
732+
}
733+
if (itemsize > offset)
734+
oss << (itemsize - offset) << 'x';
735+
oss << '}';
736+
auto format_str = oss.str();
737+
738+
// Sanity check: verify that NumPy properly parses our buffer format string
739+
auto& api = npy_api::get();
740+
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
741+
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
742+
pybind11_fail("NumPy: invalid buffer descriptor!");
743+
744+
auto tindex = std::type_index(tinfo);
745+
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
746+
get_internals().direct_conversions[tindex].push_back(direct_converter);
747+
}
748+
689749
template <typename T>
690750
struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
691751
static PYBIND11_DESCR name() { return _("struct"); }
@@ -699,56 +759,9 @@ struct npy_format_descriptor<T, enable_if_t<is_pod_struct<T>::value>> {
699759
return format_str;
700760
}
701761

702-
static void register_dtype(std::initializer_list<field_descriptor> fields) {
703-
auto& numpy_internals = get_numpy_internals();
704-
if (numpy_internals.get_type_info<T>(false))
705-
pybind11_fail("NumPy: dtype is already registered");
706-
707-
list names, formats, offsets;
708-
for (auto field : fields) {
709-
if (!field.descr)
710-
pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
711-
field.name + "` @ " + typeid(T).name());
712-
names.append(PYBIND11_STR_TYPE(field.name));
713-
formats.append(field.descr);
714-
offsets.append(pybind11::int_(field.offset));
715-
}
716-
auto dtype_ptr = pybind11::dtype(names, formats, offsets, sizeof(T)).release().ptr();
717-
718-
// There is an existing bug in NumPy (as of v1.11): trailing bytes are
719-
// not encoded explicitly into the format string. This will supposedly
720-
// get fixed in v1.12; for further details, see these:
721-
// - https://github.com/numpy/numpy/issues/7797
722-
// - https://github.com/numpy/numpy/pull/7798
723-
// Because of this, we won't use numpy's logic to generate buffer format
724-
// strings and will just do it ourselves.
725-
std::vector<field_descriptor> ordered_fields(fields);
726-
std::sort(ordered_fields.begin(), ordered_fields.end(),
727-
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
728-
size_t offset = 0;
729-
std::ostringstream oss;
730-
oss << "T{";
731-
for (auto& field : ordered_fields) {
732-
if (field.offset > offset)
733-
oss << (field.offset - offset) << 'x';
734-
// note that '=' is required to cover the case of unaligned fields
735-
oss << '=' << field.format << ':' << field.name << ':';
736-
offset = field.offset + field.size;
737-
}
738-
if (sizeof(T) > offset)
739-
oss << (sizeof(T) - offset) << 'x';
740-
oss << '}';
741-
auto format_str = oss.str();
742-
743-
// Sanity check: verify that NumPy properly parses our buffer format string
744-
auto& api = npy_api::get();
745-
auto arr = array(buffer_info(nullptr, sizeof(T), format_str, 1));
746-
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
747-
pybind11_fail("NumPy: invalid buffer descriptor!");
748-
749-
auto tindex = std::type_index(typeid(T));
750-
numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
751-
get_internals().direct_conversions[tindex].push_back(direct_converter);
762+
static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
763+
register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
764+
sizeof(T), &direct_converter);
752765
}
753766

754767
private:

0 commit comments

Comments
 (0)