Skip to content

Unsafe access #617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 108 additions & 73 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,19 +316,72 @@ class dtype : public object {
}
};

class array : public buffer {
NAMESPACE_BEGIN(detail)
[[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim) + ")");
}
NAMESPACE_END(detail)

class safe_access_policy {
public:
void check_axis(size_t dim, size_t ndim) const {
if (dim >= ndim)
detail::fail_dim_check(dim, ndim, "invalid axis");
}

template <typename... Ix>
void check_indices(size_t ndim, Ix...) const {
if (sizeof...(Ix) > ndim)
detail::fail_dim_check(sizeof...(Ix), ndim, "too many indices for an array");
}

template<typename... Ix>
void check_dimensions(const size_t* shape, Ix... index) const {
check_dimensions_impl(size_t(0), shape, size_t(index)...);
}

private:
void check_dimensions_impl(size_t, const size_t*) const { }

template<typename... Ix>
void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
if (i >= *shape) {
throw index_error(std::string("index ") + std::to_string(i) +
" is out of bounds for axis " + std::to_string(axis) +
" with size " + std::to_string(*shape));
}
check_dimensions_impl(axis + 1, shape + 1, index...);
}
};

class unsafe_access_policy {
public:
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
void check_axis(size_t, size_t) const {
}

template <typename... Ix>
void check_indices(size_t, Ix...) const {
}

template <typename... Ix>
void check_dimensions(const size_t*, Ix...) const {
}
};

template <class access_policy = safe_access_policy>
class array_base : public buffer, private access_policy {
public:
PYBIND11_OBJECT_CVT(array_base, buffer, detail::npy_api::get().PyArray_Check_, raw_array)

enum {
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
};

array() : array(0, static_cast<const double *>(nullptr)) {}
array_base() : array_base(0, static_cast<const double *>(nullptr)) {}

array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
const std::vector<size_t> &strides, const void *ptr = nullptr,
handle base = handle()) {
auto& api = detail::npy_api::get();
Expand All @@ -339,9 +392,9 @@ class array : public buffer {

int flags = 0;
if (base && ptr) {
if (isinstance<array>(base))
if (isinstance<array_base>(base))
/* Copy flags from base (except baseship bit) */
flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
flags = reinterpret_borrow<array_base>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
else
/* Writable by default, easy to downgrade later on if needed */
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
Expand All @@ -362,30 +415,30 @@ class array : public buffer {
m_ptr = tmp.release().ptr();
}

array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
const void *ptr = nullptr, handle base = handle())
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
: array_base(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }

array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
array_base(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
handle base = handle())
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
: array_base(dt, std::vector<size_t>{ count }, ptr, base) { }

template<typename T> array(const std::vector<size_t>& shape,
template<typename T> array_base(const std::vector<size_t>& shape,
const std::vector<size_t>& strides,
const T* ptr, handle base = handle())
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
: array_base(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }

template <typename T>
array(const std::vector<size_t> &shape, const T *ptr,
array_base(const std::vector<size_t> &shape, const T *ptr,
handle base = handle())
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
: array_base(shape, default_strides(shape, sizeof(T)), ptr, base) { }

template <typename T>
array(size_t count, const T *ptr, handle base = handle())
: array(std::vector<size_t>{ count }, ptr, base) { }
array_base(size_t count, const T *ptr, handle base = handle())
: array_base(std::vector<size_t>{ count }, ptr, base) { }

explicit array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
explicit array_base(const buffer_info &info)
: array_base(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }

/// Array descriptor (dtype)
pybind11::dtype dtype() const {
Expand Down Expand Up @@ -424,8 +477,7 @@ class array : public buffer {

/// Dimension along a given axis
size_t shape(size_t dim) const {
if (dim >= ndim())
fail_dim_check(dim, "invalid axis");
access_policy::check_axis(dim, ndim());
return shape()[dim];
}

Expand All @@ -436,8 +488,7 @@ class array : public buffer {

/// Stride along a given axis
size_t strides(size_t dim) const {
if (dim >= ndim())
fail_dim_check(dim, "invalid axis");
access_policy::check_axis(dim, ndim());
return strides()[dim];
}

Expand Down Expand Up @@ -473,8 +524,7 @@ class array : public buffer {
/// Byte offset from beginning of the array to a given index (full or partial).
/// May throw if the index would lead to out of bounds access.
template<typename... Ix> size_t offset_at(Ix... index) const {
if (sizeof...(index) > ndim())
fail_dim_check(sizeof...(index), "too many indices for an array");
access_policy::check_indices(ndim(), index...);
return byte_offset(size_t(index)...);
}

Expand All @@ -487,15 +537,15 @@ class array : public buffer {
}

/// Return a new view with all of the dimensions of length 1 removed
array squeeze() {
array_base squeeze() {
auto& api = detail::npy_api::get();
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
return reinterpret_steal<array_base>(api.PyArray_Squeeze_(m_ptr));
}

/// Ensure that the argument is a NumPy array
/// In case of an error, nullptr is returned and the Python error is cleared.
static array ensure(handle h, int ExtraFlags = 0) {
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
static array_base ensure(handle h, int ExtraFlags = 0) {
auto result = reinterpret_steal<array_base>(raw_array(h.ptr(), ExtraFlags));
if (!result)
PyErr_Clear();
return result;
Expand All @@ -504,13 +554,8 @@ class array : public buffer {
protected:
template<typename, typename> friend struct detail::npy_format_descriptor;

void fail_dim_check(size_t dim, const std::string& msg) const {
throw index_error(msg + ": " + std::to_string(dim) +
" (ndim = " + std::to_string(ndim()) + ")");
}

template<typename... Ix> size_t byte_offset(Ix... index) const {
check_dimensions(index...);
access_policy::check_dimensions(shape(), index...);
return byte_offset_unsafe(index...);
}

Expand All @@ -537,21 +582,6 @@ class array : public buffer {
return strides;
}

template<typename... Ix> void check_dimensions(Ix... index) const {
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
}

void check_dimensions_impl(size_t, const size_t*) const { }

template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
if (i >= *shape) {
throw index_error(std::string("index ") + std::to_string(i) +
" is out of bounds for axis " + std::to_string(axis) +
" with size " + std::to_string(*shape));
}
check_dimensions_impl(axis + 1, shape + 1, index...);
}

/// Create array from any object -- always returns a new reference
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
if (ptr == nullptr)
Expand All @@ -561,64 +591,69 @@ class array : public buffer {
}
};

template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
using array = array_base<safe_access_policy>;
using array_unchecked = array_base<unsafe_access_policy>;

template <typename T, int ExtraFlags = array_base<>::forcecast, class access_policy = safe_access_policy>
class array_t : public array_base<access_policy> {
public:
array_t() : array(0, static_cast<const T *>(nullptr)) {}
array_t(handle h, borrowed_t) : array(h, borrowed) { }
array_t(handle h, stolen_t) : array(h, stolen) { }
using base_type = array_base<access_policy>;
array_t() : base_type(0, static_cast<const T *>(nullptr)) {}
array_t(handle h, object::borrowed_t) : base_type(h, object::borrowed) { }
array_t(handle h, object::stolen_t) : base_type(h, object::stolen) { }

PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
if (!m_ptr) PyErr_Clear();
array_t(handle h, bool is_borrowed) : base_type(raw_array_t(h.ptr()), object::stolen) {
if (!this->m_ptr) PyErr_Clear();
if (!is_borrowed) Py_XDECREF(h.ptr());
}

array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
if (!m_ptr) throw error_already_set();
array_t(const object &o) : base_type(raw_array_t(o.ptr()), object::stolen) {
if (!this->m_ptr) throw error_already_set();
}

explicit array_t(const buffer_info& info) : array(info) { }
explicit array_t(const buffer_info& info) : base_type(info) { }

array_t(const std::vector<size_t> &shape,
const std::vector<size_t> &strides, const T *ptr = nullptr,
handle base = handle())
: array(shape, strides, ptr, base) { }
: base_type(shape, strides, ptr, base) { }

explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
handle base = handle())
: array(shape, ptr, base) { }
: base_type(shape, ptr, base) { }

explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
: array(count, ptr, base) { }
: base_type(count, ptr, base) { }

constexpr size_t itemsize() const {
return sizeof(T);
}

template<typename... Ix> size_t index_at(Ix... index) const {
return offset_at(index...) / itemsize();
return base_type::offset_at(index...) / itemsize();
}

template<typename... Ix> const T* data(Ix... index) const {
return static_cast<const T*>(array::data(index...));
return static_cast<const T*>(base_type::data(index...));
}

template<typename... Ix> T* mutable_data(Ix... index) {
return static_cast<T*>(array::mutable_data(index...));
return static_cast<T*>(base_type::mutable_data(index...));
}

// Reference to element at a given index
template<typename... Ix> const T& at(Ix... index) const {
if (sizeof...(index) != ndim())
fail_dim_check(sizeof...(index), "index dimension mismatch");
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
if (sizeof...(index) != base_type::ndim())
detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch");
return *(static_cast<const T*>(base_type::data()) + base_type::byte_offset(size_t(index)...) / itemsize());
}

// Mutable reference to element at a given index
template<typename... Ix> T& mutable_at(Ix... index) {
if (sizeof...(index) != ndim())
fail_dim_check(sizeof...(index), "index dimension mismatch");
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
if (sizeof...(index) != base_type::ndim())
detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch");
return *(static_cast<T*>(base_type::mutable_data()) + base_type::byte_offset(size_t(index)...) / itemsize());
}

/// Ensure that the argument is a NumPy array of the correct dtype.
Expand Down Expand Up @@ -811,7 +846,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(

// Sanity check: verify that NumPy properly parses our buffer format string
auto& api = npy_api::get();
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
auto arr = array_base<>(buffer_info(nullptr, itemsize, format_str, 1));
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
pybind11_fail("NumPy: invalid buffer descriptor!");

Expand Down Expand Up @@ -1076,11 +1111,11 @@ struct vectorize_helper {
template <typename T>
explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }

object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
object operator()(array_t<Args, array_base<>::c_style | array_base<>::forcecast>... args) {
return run(args..., make_index_sequence<sizeof...(Args)>());
}

template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
template <size_t ... Index> object run(array_t<Args, array_base<>::c_style | array_base<>::forcecast>&... args, index_sequence<Index...> index) {
/* Request buffers from all parameters */
const size_t N = sizeof...(Args);

Expand Down