From b943d9cb4845d7b56911439180595958a726070a Mon Sep 17 00:00:00 2001 From: Johan Mabille Date: Wed, 25 Jan 2017 00:57:53 +0100 Subject: [PATCH] safe/unsafe access policy --- include/pybind11/numpy.h | 181 +++++++++++++++++++++++---------------- 1 file changed, 108 insertions(+), 73 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 6fecf28531..364a29af85 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -316,9 +316,62 @@ class dtype : public object { } }; -class array : public buffer { +NAMESPACE_BEGIN(detail) +[[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) { + throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim) + ")"); +} +NAMESPACE_END(detail) + +class safe_access_policy { +public: + void check_axis(size_t dim, size_t ndim) const { + if (dim >= ndim) + detail::fail_dim_check(dim, ndim, "invalid axis"); + } + + template + void check_indices(size_t ndim, Ix...) const { + if (sizeof...(Ix) > ndim) + detail::fail_dim_check(sizeof...(Ix), ndim, "too many indices for an array"); + } + + template + void check_dimensions(const size_t* shape, Ix... index) const { + check_dimensions_impl(size_t(0), shape, size_t(index)...); + } + +private: + void check_dimensions_impl(size_t, const size_t*) const { } + + template + void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const { + if (i >= *shape) { + throw index_error(std::string("index ") + std::to_string(i) + + " is out of bounds for axis " + std::to_string(axis) + + " with size " + std::to_string(*shape)); + } + check_dimensions_impl(axis + 1, shape + 1, index...); + } +}; + +class unsafe_access_policy { public: - PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array) + void check_axis(size_t, size_t) const { + } + + template + void check_indices(size_t, Ix...) const { + } + + template + void check_dimensions(const size_t*, Ix...) const { + } +}; + +template +class array_base : public buffer, private access_policy { +public: + PYBIND11_OBJECT_CVT(array_base, buffer, detail::npy_api::get().PyArray_Check_, raw_array) enum { c_style = detail::npy_api::NPY_C_CONTIGUOUS_, @@ -326,9 +379,9 @@ class array : public buffer { forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ }; - array() : array(0, static_cast(nullptr)) {} + array_base() : array_base(0, static_cast(nullptr)) {} - array(const pybind11::dtype &dt, const std::vector &shape, + array_base(const pybind11::dtype &dt, const std::vector &shape, const std::vector &strides, const void *ptr = nullptr, handle base = handle()) { auto& api = detail::npy_api::get(); @@ -339,9 +392,9 @@ class array : public buffer { int flags = 0; if (base && ptr) { - if (isinstance(base)) + if (isinstance(base)) /* Copy flags from base (except baseship bit) */ - flags = reinterpret_borrow(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; + flags = reinterpret_borrow(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; else /* Writable by default, easy to downgrade later on if needed */ flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; @@ -362,30 +415,30 @@ class array : public buffer { m_ptr = tmp.release().ptr(); } - array(const pybind11::dtype &dt, const std::vector &shape, + array_base(const pybind11::dtype &dt, const std::vector &shape, const void *ptr = nullptr, handle base = handle()) - : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { } + : array_base(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { } - array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr, + array_base(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr, handle base = handle()) - : array(dt, std::vector{ count }, ptr, base) { } + : array_base(dt, std::vector{ count }, ptr, base) { } - template array(const std::vector& shape, + template array_base(const std::vector& shape, const std::vector& strides, const T* ptr, handle base = handle()) - : array(pybind11::dtype::of(), shape, strides, (void *) ptr, base) { } + : array_base(pybind11::dtype::of(), shape, strides, (void *) ptr, base) { } template - array(const std::vector &shape, const T *ptr, + array_base(const std::vector &shape, const T *ptr, handle base = handle()) - : array(shape, default_strides(shape, sizeof(T)), ptr, base) { } + : array_base(shape, default_strides(shape, sizeof(T)), ptr, base) { } template - array(size_t count, const T *ptr, handle base = handle()) - : array(std::vector{ count }, ptr, base) { } + array_base(size_t count, const T *ptr, handle base = handle()) + : array_base(std::vector{ count }, ptr, base) { } - explicit array(const buffer_info &info) - : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } + explicit array_base(const buffer_info &info) + : array_base(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } /// Array descriptor (dtype) pybind11::dtype dtype() const { @@ -424,8 +477,7 @@ class array : public buffer { /// Dimension along a given axis size_t shape(size_t dim) const { - if (dim >= ndim()) - fail_dim_check(dim, "invalid axis"); + access_policy::check_axis(dim, ndim()); return shape()[dim]; } @@ -436,8 +488,7 @@ class array : public buffer { /// Stride along a given axis size_t strides(size_t dim) const { - if (dim >= ndim()) - fail_dim_check(dim, "invalid axis"); + access_policy::check_axis(dim, ndim()); return strides()[dim]; } @@ -473,8 +524,7 @@ class array : public buffer { /// Byte offset from beginning of the array to a given index (full or partial). /// May throw if the index would lead to out of bounds access. template size_t offset_at(Ix... index) const { - if (sizeof...(index) > ndim()) - fail_dim_check(sizeof...(index), "too many indices for an array"); + access_policy::check_indices(ndim(), index...); return byte_offset(size_t(index)...); } @@ -487,15 +537,15 @@ class array : public buffer { } /// Return a new view with all of the dimensions of length 1 removed - array squeeze() { + array_base squeeze() { auto& api = detail::npy_api::get(); - return reinterpret_steal(api.PyArray_Squeeze_(m_ptr)); + return reinterpret_steal(api.PyArray_Squeeze_(m_ptr)); } /// Ensure that the argument is a NumPy array /// In case of an error, nullptr is returned and the Python error is cleared. - static array ensure(handle h, int ExtraFlags = 0) { - auto result = reinterpret_steal(raw_array(h.ptr(), ExtraFlags)); + static array_base ensure(handle h, int ExtraFlags = 0) { + auto result = reinterpret_steal(raw_array(h.ptr(), ExtraFlags)); if (!result) PyErr_Clear(); return result; @@ -504,13 +554,8 @@ class array : public buffer { protected: template friend struct detail::npy_format_descriptor; - void fail_dim_check(size_t dim, const std::string& msg) const { - throw index_error(msg + ": " + std::to_string(dim) + - " (ndim = " + std::to_string(ndim()) + ")"); - } - template size_t byte_offset(Ix... index) const { - check_dimensions(index...); + access_policy::check_dimensions(shape(), index...); return byte_offset_unsafe(index...); } @@ -537,21 +582,6 @@ class array : public buffer { return strides; } - template void check_dimensions(Ix... index) const { - check_dimensions_impl(size_t(0), shape(), size_t(index)...); - } - - void check_dimensions_impl(size_t, const size_t*) const { } - - template void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const { - if (i >= *shape) { - throw index_error(std::string("index ") + std::to_string(i) + - " is out of bounds for axis " + std::to_string(axis) + - " with size " + std::to_string(*shape)); - } - check_dimensions_impl(axis + 1, shape + 1, index...); - } - /// Create array from any object -- always returns a new reference static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) { if (ptr == nullptr) @@ -561,64 +591,69 @@ class array : public buffer { } }; -template class array_t : public array { +using array = array_base; +using array_unchecked = array_base; + +template ::forcecast, class access_policy = safe_access_policy> +class array_t : public array_base { public: - array_t() : array(0, static_cast(nullptr)) {} - array_t(handle h, borrowed_t) : array(h, borrowed) { } - array_t(handle h, stolen_t) : array(h, stolen) { } + using base_type = array_base; + array_t() : base_type(0, static_cast(nullptr)) {} + array_t(handle h, object::borrowed_t) : base_type(h, object::borrowed) { } + array_t(handle h, object::stolen_t) : base_type(h, object::stolen) { } PYBIND11_DEPRECATED("Use array_t::ensure() instead") - array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) { - if (!m_ptr) PyErr_Clear(); + array_t(handle h, bool is_borrowed) : base_type(raw_array_t(h.ptr()), object::stolen) { + if (!this->m_ptr) PyErr_Clear(); if (!is_borrowed) Py_XDECREF(h.ptr()); } - array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) { - if (!m_ptr) throw error_already_set(); + array_t(const object &o) : base_type(raw_array_t(o.ptr()), object::stolen) { + if (!this->m_ptr) throw error_already_set(); } - explicit array_t(const buffer_info& info) : array(info) { } + explicit array_t(const buffer_info& info) : base_type(info) { } array_t(const std::vector &shape, const std::vector &strides, const T *ptr = nullptr, handle base = handle()) - : array(shape, strides, ptr, base) { } + : base_type(shape, strides, ptr, base) { } explicit array_t(const std::vector &shape, const T *ptr = nullptr, handle base = handle()) - : array(shape, ptr, base) { } + : base_type(shape, ptr, base) { } explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle()) - : array(count, ptr, base) { } + : base_type(count, ptr, base) { } constexpr size_t itemsize() const { return sizeof(T); } template size_t index_at(Ix... index) const { - return offset_at(index...) / itemsize(); + return base_type::offset_at(index...) / itemsize(); } template const T* data(Ix... index) const { - return static_cast(array::data(index...)); + return static_cast(base_type::data(index...)); } template T* mutable_data(Ix... index) { - return static_cast(array::mutable_data(index...)); + return static_cast(base_type::mutable_data(index...)); } // Reference to element at a given index template const T& at(Ix... index) const { - if (sizeof...(index) != ndim()) - fail_dim_check(sizeof...(index), "index dimension mismatch"); - return *(static_cast(array::data()) + byte_offset(size_t(index)...) / itemsize()); + if (sizeof...(index) != base_type::ndim()) + detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch"); + return *(static_cast(base_type::data()) + base_type::byte_offset(size_t(index)...) / itemsize()); } // Mutable reference to element at a given index template T& mutable_at(Ix... index) { - if (sizeof...(index) != ndim()) - fail_dim_check(sizeof...(index), "index dimension mismatch"); - return *(static_cast(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize()); + if (sizeof...(index) != base_type::ndim()) + detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch"); + return *(static_cast(base_type::mutable_data()) + base_type::byte_offset(size_t(index)...) / itemsize()); } /// Ensure that the argument is a NumPy array of the correct dtype. @@ -811,7 +846,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype( // Sanity check: verify that NumPy properly parses our buffer format string auto& api = npy_api::get(); - auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); + auto arr = array_base<>(buffer_info(nullptr, itemsize, format_str, 1)); if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) pybind11_fail("NumPy: invalid buffer descriptor!"); @@ -1076,11 +1111,11 @@ struct vectorize_helper { template explicit vectorize_helper(T&&f) : f(std::forward(f)) { } - object operator()(array_t... args) { + object operator()(array_t::c_style | array_base<>::forcecast>... args) { return run(args..., make_index_sequence()); } - template object run(array_t&... args, index_sequence index) { + template object run(array_t::c_style | array_base<>::forcecast>&... args, index_sequence index) { /* Request buffers from all parameters */ const size_t N = sizeof...(Args);