diff --git a/include/pybind11/common.h b/include/pybind11/common.h index ef94f3854e..44b7008e45 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -623,6 +623,24 @@ template struct format_descriptor constexpr const char format_descriptor< T, detail::enable_if_t::value>>::value[2]; +NAMESPACE_BEGIN(detail) + +template struct compare_buffer_info { + static bool compare(const buffer_info& b) { + return b.format == format_descriptor::format() && b.itemsize == sizeof(T); + } +}; + +template struct compare_buffer_info::value>> { + static bool compare(const buffer_info& b) { + return b.itemsize == sizeof(T) && (b.format == format_descriptor::value || + ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || + ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); + } +}; + +NAMESPACE_END(detail) + /// RAII wrapper that temporarily clears any Python error state struct error_scope { PyObject *type, *value, *trace; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index ea9914a48a..7e2d2916e7 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -699,6 +699,13 @@ struct pyobject_caster> { PYBIND11_TYPE_CASTER(type, handle_type_name::name()); }; +template +struct compare_buffer_info::value>> { + static bool compare(const buffer_info& b) { + return npy_api::get().PyArray_EquivTypes_(dtype::of().ptr(), dtype(b).ptr()); + } +}; + template struct npy_format_descriptor::value>> { private: // NB: the order here must match the one in common.h diff --git a/include/pybind11/stl_bind.h b/include/pybind11/stl_bind.h index d1d45e2c0b..300e8af9aa 100644 --- a/include/pybind11/stl_bind.h +++ b/include/pybind11/stl_bind.h @@ -326,6 +326,49 @@ template auto vector_if_insertion_operator(Cl ); } +// Provide the buffer interface for vectors if we have data() and we have a format for it +// GCC seems to have "void std::vector::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer +template +struct vector_has_data_and_format : std::false_type {}; +template +struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; + +// Add the buffer interface to a vector +template +enable_if_t...>::value> +vector_buffer(Class_& cl) { + using T = typename Vector::value_type; + + static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); + + // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here + py::format_descriptor::format(); + + cl.def_buffer([](Vector& v) -> py::buffer_info { + return py::buffer_info(v.data(), sizeof(T), py::format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); + }); + + cl.def("__init__", [](Vector& vec, py::buffer buf) { + auto info = buf.request(); + if (info.ndim != 1 || info.strides[0] <= 0 || info.strides[0] % sizeof(T)) + throw pybind11::type_error("Only valid 1D buffers can be copied to a vector"); + if (!detail::compare_buffer_info::compare(info) || sizeof(T) != info.itemsize) + throw pybind11::type_error("Format mismatch (Python: " + info.format + " C++: " + py::format_descriptor::format() + ")"); + new (&vec) Vector(); + vec.reserve(info.shape[0]); + T *p = static_cast(info.ptr); + auto step = info.strides[0] / sizeof(T); + T *end = p + info.shape[0] * step; + for (; p < end; p += step) + vec.push_back(*p); + }); + + return; +} + +template +enable_if_t...>::value> vector_buffer(Class_&) {} + NAMESPACE_END(detail) // @@ -337,6 +380,9 @@ pybind11::class_ bind_vector(pybind11::module &m, std::stri Class_ cl(m, name.c_str(), std::forward(args)...); + // Declare the buffer interface if a py::buffer_protocol() is passed in + detail::vector_buffer(cl); + cl.def(pybind11::init<>()); // Register copy constructor (if possible) diff --git a/tests/test_numpy_dtypes.cpp b/tests/test_numpy_dtypes.cpp index d74ecc59ed..1f6c857043 100644 --- a/tests/test_numpy_dtypes.cpp +++ b/tests/test_numpy_dtypes.cpp @@ -319,6 +319,22 @@ py::list test_dtype_methods() { return list; } +struct CompareStruct { + bool x; + uint32_t y; + float z; +}; + +py::list test_compare_buffer_info() { + py::list list; + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(float), "f", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(int), "I", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(long), "l", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(long), sizeof(long) == sizeof(int) ? "i" : "q", 1)))); + list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(CompareStruct), "T{?:x:3xI:y:f:z:}", 1)))); + return list; +} + test_initializer numpy_dtypes([](py::module &m) { try { py::module::import("numpy"); @@ -337,6 +353,7 @@ test_initializer numpy_dtypes([](py::module &m) { PYBIND11_NUMPY_DTYPE(StringStruct, a, b); PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b); + PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z); // ... or after py::class_(m, "PackedStruct"); @@ -366,6 +383,7 @@ test_initializer numpy_dtypes([](py::module &m) { m.def("test_array_ctors", &test_array_ctors); m.def("test_dtype_ctors", &test_dtype_ctors); m.def("test_dtype_methods", &test_dtype_methods); + m.def("compare_buffer_info", &test_compare_buffer_info); m.def("trailing_padding_dtype", &trailing_padding_dtype); m.def("buffer_to_dtype", &buffer_to_dtype); m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; }); diff --git a/tests/test_numpy_dtypes.py b/tests/test_numpy_dtypes.py index 0ef4e939a2..f63814f9da 100644 --- a/tests/test_numpy_dtypes.py +++ b/tests/test_numpy_dtypes.py @@ -264,3 +264,9 @@ def test_register_dtype(): with pytest.raises(RuntimeError) as excinfo: register_dtype() assert 'dtype is already registered' in str(excinfo.value) + + +@pytest.requires_numpy +def test_compare_buffer_info(): + from pybind11_tests import compare_buffer_info + assert all(compare_buffer_info()) diff --git a/tests/test_stl_binders.cpp b/tests/test_stl_binders.cpp index ce0b33257d..f636c0b55e 100644 --- a/tests/test_stl_binders.cpp +++ b/tests/test_stl_binders.cpp @@ -10,6 +10,7 @@ #include "pybind11_tests.h" #include +#include #include #include #include @@ -58,17 +59,45 @@ template Map *times_ten(int n) { return m; } +struct VStruct { + bool w; + uint32_t x; + double y; + bool z; +}; + +struct VUndeclStruct { //dtype not declared for this version + bool w; + uint32_t x; + double y; + bool z; +}; + test_initializer stl_binder_vector([](py::module &m) { py::class_(m, "El") .def(py::init()); - py::bind_vector>(m, "VectorInt"); + py::bind_vector>(m, "VectorUChar", py::buffer_protocol()); + py::bind_vector>(m, "VectorInt", py::buffer_protocol()); py::bind_vector>(m, "VectorBool"); py::bind_vector>(m, "VectorEl"); py::bind_vector>>(m, "VectorVectorEl"); + m.def("create_undeclstruct", [m] () mutable { + py::bind_vector>(m, "VectorUndeclStruct", py::buffer_protocol()); + }); + + try { + py::module::import("numpy"); + } catch (...) { + return; + } + PYBIND11_NUMPY_DTYPE(VStruct, w, x, y, z); + py::class_(m, "VStruct").def_readwrite("x", &VStruct::x); + py::bind_vector>(m, "VectorStruct", py::buffer_protocol()); + m.def("get_vectorstruct", [] {return std::vector {{0, 5, 3.0, 1}, {1, 30, -1e4, 0}};}); }); test_initializer stl_binder_map([](py::module &m) { @@ -97,4 +126,3 @@ test_initializer stl_binder_noncopyable([](py::module &m) { py::bind_map>(m, "UmapENC"); m.def("get_umnc", ×_ten>, py::return_value_policy::reference); }); - diff --git a/tests/test_stl_binders.py b/tests/test_stl_binders.py index c9bcc79352..0edf9e26ee 100644 --- a/tests/test_stl_binders.py +++ b/tests/test_stl_binders.py @@ -1,3 +1,10 @@ +import pytest +import sys + +with pytest.suppress(ImportError): + import numpy as np + + def test_vector_int(): from pybind11_tests import VectorInt @@ -26,6 +33,57 @@ def test_vector_int(): assert v_int2 == VectorInt([0, 99, 2, 3]) +@pytest.unsupported_on_pypy +def test_vector_buffer(): + from pybind11_tests import VectorUChar, create_undeclstruct + b = bytearray([1, 2, 3, 4]) + v = VectorUChar(b) + assert v[1] == 2 + v[2] = 5 + m = memoryview(v) # We expose the buffer interface + if sys.version_info.major > 2: + assert m[2] == 5 + m[2] = 6 + else: + assert m[2] == '\x05' + m[2] = '\x06' + assert v[2] == 6 + + with pytest.raises(RuntimeError): + create_undeclstruct() # Undeclared struct contents, no buffer interface + + +@pytest.requires_numpy +def test_vector_buffer_numpy(): + from pybind11_tests import VectorInt, VectorStruct, get_vectorstruct + + a = np.array([1, 2, 3, 4], dtype=np.int32) + with pytest.raises(TypeError): + VectorInt(a) + + a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uintc) + v = VectorInt(a[0, :]) + assert len(v) == 4 + assert v[2] == 3 + m = np.asarray(v) + m[2] = 5 + assert v[2] == 5 + + v = VectorInt(a[:, 1]) + assert len(v) == 3 + assert v[2] == 10 + + v = get_vectorstruct() + assert v[0].x == 5 + m = np.asarray(v) + m[1]['x'] = 99 + assert v[1].x == 99 + + v = VectorStruct(np.zeros(3, dtype=np.dtype([('w', 'bool'), ('x', 'I'), + ('y', 'float64'), ('z', 'bool')], align=True))) + assert len(v) == 3 + + def test_vector_custom(): from pybind11_tests import El, VectorEl, VectorVectorEl