Skip to content

Commit b4f065b

Browse files
committed
Add the buffer interface for wrapped STL vectors
Allows use of vectors as python buffers, so for example they can be adopted without a copy by numpy.asarray Allows faster conversion of numeric buffers to vectors with memcpy instead of individually casting the elements
1 parent 425b497 commit b4f065b

File tree

3 files changed

+107
-2
lines changed

3 files changed

+107
-2
lines changed

include/pybind11/stl_bind.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ template <typename, typename, typename... Args> void vector_if_copy_constructibl
6464
template <typename, typename, typename... Args> void vector_if_equal_operator(const Args &...) { }
6565
template <typename, typename, typename... Args> void vector_if_insertion_operator(const Args &...) { }
6666
template <typename, typename, typename... Args> void vector_modifiers(const Args &...) { }
67+
template <typename, typename, typename... Args> void vector_buffer(const Args&...) { }
6768

6869
template<typename Vector, typename Class_>
6970
void vector_if_copy_constructible(enable_if_t<
@@ -326,6 +327,36 @@ template <typename Vector, typename Class_> auto vector_if_insertion_operator(Cl
326327
);
327328
}
328329

330+
// Provide the buffer interface for vectors if we have data() and we have a format for it
331+
// GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it doesn't return void
332+
template <typename Vector, typename Class_>
333+
enable_if_t<!std::is_same<decltype(py::format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), void>::value>
334+
vector_buffer(Class_& cl) {
335+
using T = typename Vector::value_type;
336+
337+
try {
338+
//numpy.h declares this for arbitrary types, but it may raise an exception if PYBIND11_NUMPY_DTYPE hasn't been called
339+
py::format_descriptor<T>::format();
340+
} catch (std::runtime_error&) {
341+
return;
342+
}
343+
344+
cl.def_buffer([](Vector& v) -> py::buffer_info {
345+
return py::buffer_info(v.data(), sizeof(T), py::format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
346+
});
347+
348+
cl.def("__init__", [](Vector& vec, py::buffer buf) {
349+
auto info = buf.request();
350+
if (info.ndim != 1)
351+
throw pybind11::type_error("Only 1D buffers can be copied to a vector");
352+
if (info.strides[0] != sizeof(T))
353+
throw pybind11::type_error("Item size mismatch (Python: " + std::to_string(info.strides[0]) + " C++: " + std::to_string(sizeof(T)) + ")");
354+
if (info.format != py::format_descriptor<T>::format())
355+
throw pybind11::type_error("Format mismatch (Python: " + info.format + " C++: " + py::format_descriptor<T>::format() + ")");
356+
new (&vec) Vector(static_cast<T*>(info.ptr), static_cast<T*>(info.ptr) + info.shape[0]);
357+
});
358+
}
359+
329360
NAMESPACE_END(detail)
330361

331362
//
@@ -348,6 +379,9 @@ pybind11::class_<Vector, holder_type> bind_vector(pybind11::module &m, std::stri
348379
// Register stream insertion operator (if possible)
349380
detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
350381

382+
// Register the buffer interface (if possible)
383+
detail::vector_buffer<Vector, Class_>(cl);
384+
351385
// Modifiers require copyable vector value type
352386
detail::vector_modifiers<Vector, Class_>(cl);
353387

tests/test_stl_binders.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "pybind11_tests.h"
1111

1212
#include <pybind11/stl_bind.h>
13+
#include <pybind11/numpy.h>
1314
#include <map>
1415
#include <deque>
1516
#include <unordered_map>
@@ -53,17 +54,44 @@ template <class Map> Map *times_ten(int n) {
5354
return m;
5455
}
5556

57+
struct VStruct {
58+
bool w;
59+
uint32_t x;
60+
double y;
61+
bool z;
62+
};
63+
64+
struct VUndeclStruct { //dtype not declared for this version
65+
bool w;
66+
uint32_t x;
67+
double y;
68+
bool z;
69+
};
70+
5671
test_initializer stl_binder_vector([](py::module &m) {
5772
py::class_<El>(m, "El")
5873
.def(py::init<int>());
5974

75+
py::bind_vector<std::vector<unsigned char>>(m, "VectorUChar");
6076
py::bind_vector<std::vector<unsigned int>>(m, "VectorInt");
6177
py::bind_vector<std::vector<bool>>(m, "VectorBool");
6278

6379
py::bind_vector<std::vector<El>>(m, "VectorEl");
6480

6581
py::bind_vector<std::vector<std::vector<El>>>(m, "VectorVectorEl");
6682

83+
py::bind_vector<std::vector<VUndeclStruct>>(m, "VectorUndeclStruct");
84+
m.def("get_undeclstruct", [] {return std::vector<VUndeclStruct> {{0, 5, 3.0, 1}, {1, 30, -1e4, 0}};});
85+
86+
try {
87+
py::module::import("numpy");
88+
} catch (...) {
89+
return;
90+
}
91+
PYBIND11_NUMPY_DTYPE(VStruct, w, x, y, z);
92+
py::class_<VStruct>(m, "VStruct").def_readwrite("x", &VStruct::x);
93+
py::bind_vector<std::vector<VStruct>>(m, "VectorStruct");
94+
m.def("get_vectorstruct", [] {return std::vector<VStruct> {{0, 5, 3.0, 1}, {1, 30, -1e4, 0}};});
6795
});
6896

6997
test_initializer stl_binder_map([](py::module &m) {
@@ -92,4 +120,3 @@ test_initializer stl_binder_noncopyable([](py::module &m) {
92120
py::bind_map<std::unordered_map<int, E_nc>>(m, "UmapENC");
93121
m.def("get_umnc", &times_ten<std::unordered_map<int, E_nc>>, py::return_value_policy::reference);
94122
});
95-

tests/test_stl_binders.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import pytest
2+
import sys
3+
4+
with pytest.suppress(ImportError):
5+
import numpy as np
6+
17
def test_vector_int():
28
from pybind11_tests import VectorInt
39

@@ -25,6 +31,45 @@ def test_vector_int():
2531
del v_int2[0]
2632
assert v_int2 == VectorInt([0, 99, 2, 3])
2733

34+
def test_vector_buffer():
35+
from pybind11_tests import VectorUChar, get_undeclstruct
36+
b = bytearray([1,2,3,4])
37+
v = VectorUChar(b);
38+
assert v[1] == 2
39+
v[2] = 5
40+
m = memoryview(v) #We expose the buffer interface
41+
if sys.version_info.major > 2:
42+
assert m[2] == 5
43+
m[2] = 6
44+
else:
45+
assert m[2] == '\x05'
46+
m[2] = '\x06'
47+
assert v[2] == 6
48+
49+
v = get_undeclstruct()
50+
with pytest.raises(TypeError):
51+
memoryview(v) #Undeclared struct contents, no buffer interface
52+
53+
@pytest.requires_numpy
54+
def test_vector_buffer_numpy():
55+
from pybind11_tests import VectorInt, VectorStruct, get_vectorstruct
56+
57+
a = np.array([1,2,3,4], dtype=np.int32)
58+
with pytest.raises(TypeError):
59+
VectorInt(a)
60+
61+
a = np.array([1,2,3,4], dtype=np.uintc)
62+
v = VectorInt(a)
63+
assert v[2] == 3
64+
m = np.asarray(v)
65+
m[2] = 5
66+
assert v[2] == 5
67+
68+
v = get_vectorstruct()
69+
assert v[0].x == 5
70+
m = np.asarray(v)
71+
m[1]['x'] = 99
72+
assert v[1].x == 99
2873

2974
def test_vector_custom():
3075
from pybind11_tests import El, VectorEl, VectorVectorEl
@@ -150,4 +195,3 @@ def test_noncopyable_unordered_map():
150195
vsum += v.value
151196

152197
assert(vsum == 150)
153-

0 commit comments

Comments
 (0)