Skip to content

Commit 7b51cc7

Browse files
committed
Add function for comparing buffer_info formats to types
Allows equivalent integral types and numpy dtypes
1 parent 29544af commit 7b51cc7

File tree

5 files changed

+54
-1
lines changed

5 files changed

+54
-1
lines changed

include/pybind11/common.h

+18
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,24 @@ template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is
617617
template <typename T> constexpr const char format_descriptor<
618618
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
619619

620+
NAMESPACE_BEGIN(detail)
621+
622+
template <typename T, typename SFINAE = void> struct compare_buffer_info {
623+
static bool compare(const buffer_info& b) {
624+
return b.format == format_descriptor<T>::format() && b.itemsize == sizeof(T);
625+
}
626+
};
627+
628+
template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
629+
static bool compare(const buffer_info& b) {
630+
return b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
631+
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
632+
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
633+
}
634+
};
635+
636+
NAMESPACE_END(detail)
637+
620638
/// RAII wrapper that temporarily clears any Python error state
621639
struct error_scope {
622640
PyObject *type, *value, *trace;

include/pybind11/numpy.h

+7
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,13 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
694694
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
695695
};
696696

697+
template <typename T>
698+
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
699+
static bool compare(const buffer_info& b) {
700+
return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
701+
}
702+
};
703+
697704
template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
698705
private:
699706
// NB: the order here must match the one in common.h

tests/test_numpy_dtypes.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,22 @@ py::list test_dtype_methods() {
319319
return list;
320320
}
321321

322+
struct CompareStruct {
323+
bool x;
324+
uint32_t y;
325+
float z;
326+
};
327+
328+
py::list test_compare_buffer_info() {
329+
py::list list;
330+
list.append(py::bool_(py::detail::compare_buffer_info<float>::compare(py::buffer_info(nullptr, sizeof(float), "f", 1))));
331+
list.append(py::bool_(py::detail::compare_buffer_info<unsigned>::compare(py::buffer_info(nullptr, sizeof(int), "I", 1))));
332+
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), "l", 1))));
333+
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), sizeof(long) == sizeof(int) ? "i" : "q", 1))));
334+
list.append(py::bool_(py::detail::compare_buffer_info<CompareStruct>::compare(py::buffer_info(nullptr, sizeof(CompareStruct), "T{?:x:3xI:y:f:z:}", 1))));
335+
return list;
336+
}
337+
322338
test_initializer numpy_dtypes([](py::module &m) {
323339
try {
324340
py::module::import("numpy");
@@ -337,6 +353,7 @@ test_initializer numpy_dtypes([](py::module &m) {
337353
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
338354
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
339355
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
356+
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);
340357

341358
// ... or after
342359
py::class_<PackedStruct>(m, "PackedStruct");
@@ -366,6 +383,7 @@ test_initializer numpy_dtypes([](py::module &m) {
366383
m.def("test_array_ctors", &test_array_ctors);
367384
m.def("test_dtype_ctors", &test_dtype_ctors);
368385
m.def("test_dtype_methods", &test_dtype_methods);
386+
m.def("compare_buffer_info", &test_compare_buffer_info);
369387
m.def("trailing_padding_dtype", &trailing_padding_dtype);
370388
m.def("buffer_to_dtype", &buffer_to_dtype);
371389
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });

tests/test_numpy_dtypes.py

+6
Original file line numberDiff line numberDiff line change
@@ -263,3 +263,9 @@ def test_register_dtype():
263263
with pytest.raises(RuntimeError) as excinfo:
264264
register_dtype()
265265
assert 'dtype is already registered' in str(excinfo.value)
266+
267+
268+
@pytest.requires_numpy
269+
def test_compare_buffer_info():
270+
from pybind11_tests import compare_buffer_info
271+
assert all(compare_buffer_info())

tests/test_stl_binders.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_vector_buffer():
5555

5656
@pytest.requires_numpy
5757
def test_vector_buffer_numpy():
58-
from pybind11_tests import VectorInt, get_vectorstruct
58+
from pybind11_tests import VectorInt, VectorStruct, get_vectorstruct
5959

6060
a = np.array([1, 2, 3, 4], dtype=np.int32)
6161
with pytest.raises(TypeError):
@@ -79,6 +79,10 @@ def test_vector_buffer_numpy():
7979
m[1]['x'] = 99
8080
assert v[1].x == 99
8181

82+
v = VectorStruct(np.zeros(3, dtype=np.dtype([('w', 'bool'), ('x', 'I'),
83+
('y', 'float64'), ('z', 'bool')], align=True)))
84+
assert len(v) == 3
85+
8286

8387
def test_vector_custom():
8488
from pybind11_tests import El, VectorEl, VectorVectorEl

0 commit comments

Comments
 (0)