Skip to content

Commit 46ff174

Browse files
committed
safe/unsafe access policy
1 parent 7830e85 commit 46ff174

File tree

1 file changed

+108
-73
lines changed

1 file changed

+108
-73
lines changed

include/pybind11/numpy.h

Lines changed: 108 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -316,19 +316,72 @@ class dtype : public object {
316316
}
317317
};
318318

319-
class array : public buffer {
319+
NAMESPACE_BEGIN(detail)
320+
[[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
321+
throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim) + ")");
322+
}
323+
NAMESPACE_END(detail)
324+
325+
class safe_access_policy {
326+
public:
327+
void check_axis(size_t dim, size_t ndim) const {
328+
if(dim >= ndim)
329+
detail::fail_dim_check(dim, ndim, "invalid axis");
330+
}
331+
332+
template <typename... Ix>
333+
void check_indices(size_t ndim, Ix...) const {
334+
if(sizeof...(Ix) > ndim)
335+
detail::fail_dim_check(sizeof...(Ix), ndim, "too many indices for an array");
336+
}
337+
338+
template<typename... Ix>
339+
void check_dimensions(const size_t* shape, Ix... index) const {
340+
check_dimensions_impl(size_t(0), shape, size_t(index)...);
341+
}
342+
343+
private:
344+
void check_dimensions_impl(size_t, const size_t*) const { }
345+
346+
template<typename... Ix>
347+
void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
348+
if (i >= *shape) {
349+
throw index_error(std::string("index ") + std::to_string(i) +
350+
" is out of bounds for axis " + std::to_string(axis) +
351+
" with size " + std::to_string(*shape));
352+
}
353+
check_dimensions_impl(axis + 1, shape + 1, index...);
354+
}
355+
};
356+
357+
class unsafe_access_policy {
320358
public:
321-
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
359+
void check_axis(size_t, size_t) const {
360+
}
361+
362+
template <typename... Ix>
363+
void check_indices(size_t, Ix...) const {
364+
}
365+
366+
template <typename... Ix>
367+
void check_dimensions(const size_t*, Ix...) const {
368+
}
369+
};
370+
371+
template <class access_policy = safe_access_policy>
372+
class array_base : public buffer, private access_policy {
373+
public:
374+
PYBIND11_OBJECT_CVT(array_base, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322375

323376
enum {
324377
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
325378
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
326379
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
327380
};
328381

329-
array() : array(0, static_cast<const double *>(nullptr)) {}
382+
array_base() : array_base(0, static_cast<const double *>(nullptr)) {}
330383

331-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
384+
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
332385
const std::vector<size_t> &strides, const void *ptr = nullptr,
333386
handle base = handle()) {
334387
auto& api = detail::npy_api::get();
@@ -339,9 +392,9 @@ class array : public buffer {
339392

340393
int flags = 0;
341394
if (base && ptr) {
342-
if (isinstance<array>(base))
395+
if (isinstance<array_base>(base))
343396
/* Copy flags from base (except baseship bit) */
344-
flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
397+
flags = reinterpret_borrow<array_base>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
345398
else
346399
/* Writable by default, easy to downgrade later on if needed */
347400
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -362,30 +415,30 @@ class array : public buffer {
362415
m_ptr = tmp.release().ptr();
363416
}
364417

365-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
418+
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
366419
const void *ptr = nullptr, handle base = handle())
367-
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
420+
: array_base(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368421

369-
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
422+
array_base(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
370423
handle base = handle())
371-
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
424+
: array_base(dt, std::vector<size_t>{ count }, ptr, base) { }
372425

373-
template<typename T> array(const std::vector<size_t>& shape,
426+
template<typename T> array_base(const std::vector<size_t>& shape,
374427
const std::vector<size_t>& strides,
375428
const T* ptr, handle base = handle())
376-
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
429+
: array_base(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
377430

378431
template <typename T>
379-
array(const std::vector<size_t> &shape, const T *ptr,
432+
array_base(const std::vector<size_t> &shape, const T *ptr,
380433
handle base = handle())
381-
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
434+
: array_base(shape, default_strides(shape, sizeof(T)), ptr, base) { }
382435

383436
template <typename T>
384-
array(size_t count, const T *ptr, handle base = handle())
385-
: array(std::vector<size_t>{ count }, ptr, base) { }
437+
array_base(size_t count, const T *ptr, handle base = handle())
438+
: array_base(std::vector<size_t>{ count }, ptr, base) { }
386439

387-
explicit array(const buffer_info &info)
388-
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
440+
explicit array_base(const buffer_info &info)
441+
: array_base(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
389442

390443
/// Array descriptor (dtype)
391444
pybind11::dtype dtype() const {
@@ -424,8 +477,7 @@ class array : public buffer {
424477

425478
/// Dimension along a given axis
426479
size_t shape(size_t dim) const {
427-
if (dim >= ndim())
428-
fail_dim_check(dim, "invalid axis");
480+
access_policy::check_axis(dim, ndim());
429481
return shape()[dim];
430482
}
431483

@@ -436,8 +488,7 @@ class array : public buffer {
436488

437489
/// Stride along a given axis
438490
size_t strides(size_t dim) const {
439-
if (dim >= ndim())
440-
fail_dim_check(dim, "invalid axis");
491+
access_policy::check_axis(dim, ndim());
441492
return strides()[dim];
442493
}
443494

@@ -473,8 +524,7 @@ class array : public buffer {
473524
/// Byte offset from beginning of the array to a given index (full or partial).
474525
/// May throw if the index would lead to out of bounds access.
475526
template<typename... Ix> size_t offset_at(Ix... index) const {
476-
if (sizeof...(index) > ndim())
477-
fail_dim_check(sizeof...(index), "too many indices for an array");
527+
access_policy::check_indices(ndim(), index...);
478528
return byte_offset(size_t(index)...);
479529
}
480530

@@ -487,15 +537,15 @@ class array : public buffer {
487537
}
488538

489539
/// Return a new view with all of the dimensions of length 1 removed
490-
array squeeze() {
540+
array_base squeeze() {
491541
auto& api = detail::npy_api::get();
492-
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
542+
return reinterpret_steal<array_base>(api.PyArray_Squeeze_(m_ptr));
493543
}
494544

495545
/// Ensure that the argument is a NumPy array
496546
/// In case of an error, nullptr is returned and the Python error is cleared.
497-
static array ensure(handle h, int ExtraFlags = 0) {
498-
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
547+
static array_base ensure(handle h, int ExtraFlags = 0) {
548+
auto result = reinterpret_steal<array_base>(raw_array(h.ptr(), ExtraFlags));
499549
if (!result)
500550
PyErr_Clear();
501551
return result;
@@ -504,13 +554,8 @@ class array : public buffer {
504554
protected:
505555
template<typename, typename> friend struct detail::npy_format_descriptor;
506556

507-
void fail_dim_check(size_t dim, const std::string& msg) const {
508-
throw index_error(msg + ": " + std::to_string(dim) +
509-
" (ndim = " + std::to_string(ndim()) + ")");
510-
}
511-
512557
template<typename... Ix> size_t byte_offset(Ix... index) const {
513-
check_dimensions(index...);
558+
access_policy::check_dimensions(shape(), index...);
514559
return byte_offset_unsafe(index...);
515560
}
516561

@@ -537,21 +582,6 @@ class array : public buffer {
537582
return strides;
538583
}
539584

540-
template<typename... Ix> void check_dimensions(Ix... index) const {
541-
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
542-
}
543-
544-
void check_dimensions_impl(size_t, const size_t*) const { }
545-
546-
template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
547-
if (i >= *shape) {
548-
throw index_error(std::string("index ") + std::to_string(i) +
549-
" is out of bounds for axis " + std::to_string(axis) +
550-
" with size " + std::to_string(*shape));
551-
}
552-
check_dimensions_impl(axis + 1, shape + 1, index...);
553-
}
554-
555585
/// Create array from any object -- always returns a new reference
556586
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
557587
if (ptr == nullptr)
@@ -561,64 +591,69 @@ class array : public buffer {
561591
}
562592
};
563593

564-
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
594+
using array = array_base<safe_access_policy>;
595+
using array_unchecked = array_base<unsafe_access_policy>;
596+
597+
template <typename T, int ExtraFlags = array_base<>::forcecast, class access_policy = safe_access_policy>
598+
class array_t : public array_base<access_policy> {
565599
public:
566-
array_t() : array(0, static_cast<const T *>(nullptr)) {}
567-
array_t(handle h, borrowed_t) : array(h, borrowed) { }
568-
array_t(handle h, stolen_t) : array(h, stolen) { }
600+
using base_type = array_base<access_policy>;
601+
array_t() : base_type(0, static_cast<const T *>(nullptr)) {}
602+
array_t(handle h, object::borrowed_t) : base_type(h, object::borrowed) { }
603+
array_t(handle h, object::stolen_t) : base_type(h, object::stolen) { }
569604

570605
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
571-
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
572-
if (!m_ptr) PyErr_Clear();
606+
array_t(handle h, bool is_borrowed) : base_type(raw_array_t(h.ptr()), object::stolen) {
607+
if (!this->m_ptr) PyErr_Clear();
573608
if (!is_borrowed) Py_XDECREF(h.ptr());
574609
}
575610

576-
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
577-
if (!m_ptr) throw error_already_set();
611+
array_t(const object &o) : base_type(raw_array_t(o.ptr()), object::stolen) {
612+
if (!this->m_ptr) throw error_already_set();
578613
}
579614

580-
explicit array_t(const buffer_info& info) : array(info) { }
615+
explicit array_t(const buffer_info& info) : base_type(info) { }
581616

582617
array_t(const std::vector<size_t> &shape,
583618
const std::vector<size_t> &strides, const T *ptr = nullptr,
584619
handle base = handle())
585-
: array(shape, strides, ptr, base) { }
620+
: base_type(shape, strides, ptr, base) { }
586621

587622
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
588623
handle base = handle())
589-
: array(shape, ptr, base) { }
624+
: base_type(shape, ptr, base) { }
590625

591626
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
592-
: array(count, ptr, base) { }
627+
: base_type(count, ptr, base) { }
593628

594629
constexpr size_t itemsize() const {
595630
return sizeof(T);
596631
}
597632

598633
template<typename... Ix> size_t index_at(Ix... index) const {
599-
return offset_at(index...) / itemsize();
634+
return base_type::offset_at(index...) / itemsize();
600635
}
601636

602637
template<typename... Ix> const T* data(Ix... index) const {
603-
return static_cast<const T*>(array::data(index...));
638+
return static_cast<const T*>(base_type::data(index...));
604639
}
605640

606641
template<typename... Ix> T* mutable_data(Ix... index) {
607-
return static_cast<T*>(array::mutable_data(index...));
642+
return static_cast<T*>(base_type::mutable_data(index...));
608643
}
609644

610645
// Reference to element at a given index
611646
template<typename... Ix> const T& at(Ix... index) const {
612-
if (sizeof...(index) != ndim())
613-
fail_dim_check(sizeof...(index), "index dimension mismatch");
614-
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
647+
if (sizeof...(index) != base_type::ndim())
648+
detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch");
649+
return *(static_cast<const T*>(base_type::data()) + base_type::byte_offset(size_t(index)...) / itemsize());
615650
}
616651

617652
// Mutable reference to element at a given index
618653
template<typename... Ix> T& mutable_at(Ix... index) {
619-
if (sizeof...(index) != ndim())
620-
fail_dim_check(sizeof...(index), "index dimension mismatch");
621-
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
654+
if (sizeof...(index) != base_type::ndim())
655+
detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch");
656+
return *(static_cast<T*>(base_type::mutable_data()) + base_type::byte_offset(size_t(index)...) / itemsize());
622657
}
623658

624659
/// Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +846,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811846

812847
// Sanity check: verify that NumPy properly parses our buffer format string
813848
auto& api = npy_api::get();
814-
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
849+
auto arr = array_base<>(buffer_info(nullptr, itemsize, format_str, 1));
815850
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
816851
pybind11_fail("NumPy: invalid buffer descriptor!");
817852

@@ -1076,11 +1111,11 @@ struct vectorize_helper {
10761111
template <typename T>
10771112
explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
10781113

1079-
object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
1114+
object operator()(array_t<Args, array_base<>::c_style | array_base<>::forcecast>... args) {
10801115
return run(args..., make_index_sequence<sizeof...(Args)>());
10811116
}
10821117

1083-
template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
1118+
template <size_t ... Index> object run(array_t<Args, array_base<>::c_style | array_base<>::forcecast>&... args, index_sequence<Index...> index) {
10841119
/* Request buffers from all parameters */
10851120
const size_t N = sizeof...(Args);
10861121

0 commit comments

Comments
 (0)