Skip to content

Add the buffer interface for wrapped STL vectors #488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,24 @@ template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is
template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];

NAMESPACE_BEGIN(detail)

template <typename T, typename SFINAE = void> struct compare_buffer_info {
static bool compare(const buffer_info& b) {
return b.format == format_descriptor<T>::format() && b.itemsize == sizeof(T);
}
};

template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
static bool compare(const buffer_info& b) {
return b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
}
};

NAMESPACE_END(detail)

/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
PyObject *type, *value, *trace;
Expand Down
7 changes: 7 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,13 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};

template <typename T>
struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
static bool compare(const buffer_info& b) {
return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
}
};

template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
private:
// NB: the order here must match the one in common.h
Expand Down
46 changes: 46 additions & 0 deletions include/pybind11/stl_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,49 @@ template <typename Vector, typename Class_> auto vector_if_insertion_operator(Cl
);
}

// Provide the buffer interface for vectors if we have data() and we have a format for it
// GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
template <typename Vector, typename = void>
struct vector_has_data_and_format : std::false_type {};
template <typename Vector>
struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(py::format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};

// Add the buffer interface to a vector
template <typename Vector, typename Class_, typename... Args>
enable_if_t<detail::any_of<std::is_same<Args, py::buffer_protocol>...>::value>
vector_buffer(Class_& cl) {
using T = typename Vector::value_type;

static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");

// numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
py::format_descriptor<T>::format();

cl.def_buffer([](Vector& v) -> py::buffer_info {
return py::buffer_info(v.data(), sizeof(T), py::format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
});

cl.def("__init__", [](Vector& vec, py::buffer buf) {
auto info = buf.request();
if (info.ndim != 1 || info.strides[0] <= 0 || info.strides[0] % sizeof(T))
throw pybind11::type_error("Only valid 1D buffers can be copied to a vector");
if (!detail::compare_buffer_info<T>::compare(info) || sizeof(T) != info.itemsize)
throw pybind11::type_error("Format mismatch (Python: " + info.format + " C++: " + py::format_descriptor<T>::format() + ")");
new (&vec) Vector();
vec.reserve(info.shape[0]);
T *p = static_cast<T*>(info.ptr);
auto step = info.strides[0] / sizeof(T);
T *end = p + info.shape[0] * step;
for (; p < end; p += step)
vec.push_back(*p);
});

return;
}

template <typename Vector, typename Class_, typename... Args>
enable_if_t<!detail::any_of<std::is_same<Args, py::buffer_protocol>...>::value> vector_buffer(Class_&) {}

NAMESPACE_END(detail)

//
Expand All @@ -337,6 +380,9 @@ pybind11::class_<Vector, holder_type> bind_vector(pybind11::module &m, std::stri

Class_ cl(m, name.c_str(), std::forward<Args>(args)...);

// Declare the buffer interface if a py::buffer_protocol() is passed in
detail::vector_buffer<Vector, Class_, Args...>(cl);

cl.def(pybind11::init<>());

// Register copy constructor (if possible)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_numpy_dtypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,22 @@ py::list test_dtype_methods() {
return list;
}

struct CompareStruct {
bool x;
uint32_t y;
float z;
};

py::list test_compare_buffer_info() {
py::list list;
list.append(py::bool_(py::detail::compare_buffer_info<float>::compare(py::buffer_info(nullptr, sizeof(float), "f", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<unsigned>::compare(py::buffer_info(nullptr, sizeof(int), "I", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), "l", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<long>::compare(py::buffer_info(nullptr, sizeof(long), sizeof(long) == sizeof(int) ? "i" : "q", 1))));
list.append(py::bool_(py::detail::compare_buffer_info<CompareStruct>::compare(py::buffer_info(nullptr, sizeof(CompareStruct), "T{?:x:3xI:y:f:z:}", 1))));
return list;
}

test_initializer numpy_dtypes([](py::module &m) {
try {
py::module::import("numpy");
Expand All @@ -337,6 +353,7 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);

// ... or after
py::class_<PackedStruct>(m, "PackedStruct");
Expand Down Expand Up @@ -366,6 +383,7 @@ test_initializer numpy_dtypes([](py::module &m) {
m.def("test_array_ctors", &test_array_ctors);
m.def("test_dtype_ctors", &test_dtype_ctors);
m.def("test_dtype_methods", &test_dtype_methods);
m.def("compare_buffer_info", &test_compare_buffer_info);
m.def("trailing_padding_dtype", &trailing_padding_dtype);
m.def("buffer_to_dtype", &buffer_to_dtype);
m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; });
Expand Down
6 changes: 6 additions & 0 deletions tests/test_numpy_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,9 @@ def test_register_dtype():
with pytest.raises(RuntimeError) as excinfo:
register_dtype()
assert 'dtype is already registered' in str(excinfo.value)


@pytest.requires_numpy
def test_compare_buffer_info():
from pybind11_tests import compare_buffer_info
assert all(compare_buffer_info())
32 changes: 30 additions & 2 deletions tests/test_stl_binders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "pybind11_tests.h"

#include <pybind11/stl_bind.h>
#include <pybind11/numpy.h>
#include <map>
#include <deque>
#include <unordered_map>
Expand Down Expand Up @@ -58,17 +59,45 @@ template <class Map> Map *times_ten(int n) {
return m;
}

struct VStruct {
bool w;
uint32_t x;
double y;
bool z;
};

struct VUndeclStruct { //dtype not declared for this version
bool w;
uint32_t x;
double y;
bool z;
};

test_initializer stl_binder_vector([](py::module &m) {
py::class_<El>(m, "El")
.def(py::init<int>());

py::bind_vector<std::vector<unsigned int>>(m, "VectorInt");
py::bind_vector<std::vector<unsigned char>>(m, "VectorUChar", py::buffer_protocol());
py::bind_vector<std::vector<unsigned int>>(m, "VectorInt", py::buffer_protocol());
py::bind_vector<std::vector<bool>>(m, "VectorBool");

py::bind_vector<std::vector<El>>(m, "VectorEl");

py::bind_vector<std::vector<std::vector<El>>>(m, "VectorVectorEl");

m.def("create_undeclstruct", [m] () mutable {
py::bind_vector<std::vector<VUndeclStruct>>(m, "VectorUndeclStruct", py::buffer_protocol());
});

try {
py::module::import("numpy");
} catch (...) {
return;
}
PYBIND11_NUMPY_DTYPE(VStruct, w, x, y, z);
py::class_<VStruct>(m, "VStruct").def_readwrite("x", &VStruct::x);
py::bind_vector<std::vector<VStruct>>(m, "VectorStruct", py::buffer_protocol());
m.def("get_vectorstruct", [] {return std::vector<VStruct> {{0, 5, 3.0, 1}, {1, 30, -1e4, 0}};});
});

test_initializer stl_binder_map([](py::module &m) {
Expand Down Expand Up @@ -97,4 +126,3 @@ test_initializer stl_binder_noncopyable([](py::module &m) {
py::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC");
m.def("get_umnc", &times_ten<std::unordered_map<int, E_nc>>, py::return_value_policy::reference);
});

58 changes: 58 additions & 0 deletions tests/test_stl_binders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import pytest
import sys

with pytest.suppress(ImportError):
import numpy as np


def test_vector_int():
from pybind11_tests import VectorInt

Expand Down Expand Up @@ -26,6 +33,57 @@ def test_vector_int():
assert v_int2 == VectorInt([0, 99, 2, 3])


@pytest.unsupported_on_pypy
def test_vector_buffer():
from pybind11_tests import VectorUChar, create_undeclstruct
b = bytearray([1, 2, 3, 4])
v = VectorUChar(b)
assert v[1] == 2
v[2] = 5
m = memoryview(v) # We expose the buffer interface
if sys.version_info.major > 2:
assert m[2] == 5
m[2] = 6
else:
assert m[2] == '\x05'
m[2] = '\x06'
assert v[2] == 6

with pytest.raises(RuntimeError):
create_undeclstruct() # Undeclared struct contents, no buffer interface


@pytest.requires_numpy
def test_vector_buffer_numpy():
from pybind11_tests import VectorInt, VectorStruct, get_vectorstruct

a = np.array([1, 2, 3, 4], dtype=np.int32)
with pytest.raises(TypeError):
VectorInt(a)

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uintc)
v = VectorInt(a[0, :])
assert len(v) == 4
assert v[2] == 3
m = np.asarray(v)
m[2] = 5
assert v[2] == 5

v = VectorInt(a[:, 1])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice 👍

assert len(v) == 3
assert v[2] == 10

v = get_vectorstruct()
assert v[0].x == 5
m = np.asarray(v)
m[1]['x'] = 99
assert v[1].x == 99

v = VectorStruct(np.zeros(3, dtype=np.dtype([('w', 'bool'), ('x', 'I'),
('y', 'float64'), ('z', 'bool')], align=True)))
assert len(v) == 3


def test_vector_custom():
from pybind11_tests import El, VectorEl, VectorVectorEl

Expand Down