Skip to content

Commit a483056

Browse files
committed
safe/unsafe access policy
1 parent 7830e85 commit a483056

File tree

1 file changed

+110
-73
lines changed

1 file changed

+110
-73
lines changed

include/pybind11/numpy.h

Lines changed: 110 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -316,19 +316,74 @@ class dtype : public object {
316316
}
317317
};
318318

319-
class array : public buffer {
319+
NAMESPACE_BEGIN(detail)
320+
[[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
321+
throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim) + ")");
322+
}
323+
NAMESPACE_END(detail)
324+
325+
class safe_access_policy {
320326
public:
321-
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
327+
void check_axis(size_t dim, size_t ndim) const {
328+
if(dim >= ndim) {
329+
detail::fail_dim_check(dim, ndim, "invalid axis");
330+
}
331+
}
332+
333+
template <typename... Ix>
334+
void check_indices(size_t ndim, Ix...) const {
335+
if(sizeof...(Ix) > ndim) {
336+
detail::fail_dim_check(sizeof...(Ix), ndim, "too many indices for an array");
337+
}
338+
}
339+
340+
template<typename... Ix>
341+
void check_dimensions(const size_t* shape, Ix... index) const {
342+
check_dimensions_impl(size_t(0), shape, size_t(index)...);
343+
}
344+
345+
private:
346+
void check_dimensions_impl(size_t, const size_t*) const { }
347+
348+
template<typename... Ix>
349+
void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
350+
if (i >= *shape) {
351+
throw index_error(std::string("index ") + std::to_string(i) +
352+
" is out of bounds for axis " + std::to_string(axis) +
353+
" with size " + std::to_string(*shape));
354+
}
355+
check_dimensions_impl(axis + 1, shape + 1, index...);
356+
}
357+
};
358+
359+
class unsafe_access_policy {
360+
public:
361+
void check_axis(size_t, size_t) const {
362+
}
363+
364+
template <typename... Ix>
365+
void check_indices(size_t, Ix...) const {
366+
}
367+
368+
template <typename... Ix>
369+
void check_dimensions(const size_t*, Ix...) const {
370+
}
371+
};
372+
373+
template <class access_policy = safe_access_policy>
374+
class array_base : public buffer, private access_policy {
375+
public:
376+
PYBIND11_OBJECT_CVT(array_base, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322377

323378
enum {
324379
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
325380
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
326381
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
327382
};
328383

329-
array() : array(0, static_cast<const double *>(nullptr)) {}
384+
array_base() : array_base(0, static_cast<const double *>(nullptr)) {}
330385

331-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
386+
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
332387
const std::vector<size_t> &strides, const void *ptr = nullptr,
333388
handle base = handle()) {
334389
auto& api = detail::npy_api::get();
@@ -339,9 +394,9 @@ class array : public buffer {
339394

340395
int flags = 0;
341396
if (base && ptr) {
342-
if (isinstance<array>(base))
397+
if (isinstance<array_base>(base))
343398
/* Copy flags from base (except baseship bit) */
344-
flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
399+
flags = reinterpret_borrow<array_base>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
345400
else
346401
/* Writable by default, easy to downgrade later on if needed */
347402
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -362,30 +417,30 @@ class array : public buffer {
362417
m_ptr = tmp.release().ptr();
363418
}
364419

365-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
420+
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
366421
const void *ptr = nullptr, handle base = handle())
367-
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
422+
: array_base(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368423

369-
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
424+
array_base(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
370425
handle base = handle())
371-
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
426+
: array_base(dt, std::vector<size_t>{ count }, ptr, base) { }
372427

373-
template<typename T> array(const std::vector<size_t>& shape,
428+
template<typename T> array_base(const std::vector<size_t>& shape,
374429
const std::vector<size_t>& strides,
375430
const T* ptr, handle base = handle())
376-
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
431+
: array_base(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
377432

378433
template <typename T>
379-
array(const std::vector<size_t> &shape, const T *ptr,
434+
array_base(const std::vector<size_t> &shape, const T *ptr,
380435
handle base = handle())
381-
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
436+
: array_base(shape, default_strides(shape, sizeof(T)), ptr, base) { }
382437

383438
template <typename T>
384-
array(size_t count, const T *ptr, handle base = handle())
385-
: array(std::vector<size_t>{ count }, ptr, base) { }
439+
array_base(size_t count, const T *ptr, handle base = handle())
440+
: array_base(std::vector<size_t>{ count }, ptr, base) { }
386441

387-
explicit array(const buffer_info &info)
388-
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
442+
explicit array_base(const buffer_info &info)
443+
: array_base(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
389444

390445
/// Array descriptor (dtype)
391446
pybind11::dtype dtype() const {
@@ -424,8 +479,7 @@ class array : public buffer {
424479

425480
/// Dimension along a given axis
426481
size_t shape(size_t dim) const {
427-
if (dim >= ndim())
428-
fail_dim_check(dim, "invalid axis");
482+
access_policy::check_axis(dim, ndim());
429483
return shape()[dim];
430484
}
431485

@@ -436,8 +490,7 @@ class array : public buffer {
436490

437491
/// Stride along a given axis
438492
size_t strides(size_t dim) const {
439-
if (dim >= ndim())
440-
fail_dim_check(dim, "invalid axis");
493+
access_policy::check_axis(dim, ndim());
441494
return strides()[dim];
442495
}
443496

@@ -473,8 +526,7 @@ class array : public buffer {
473526
/// Byte offset from beginning of the array to a given index (full or partial).
474527
/// May throw if the index would lead to out of bounds access.
475528
template<typename... Ix> size_t offset_at(Ix... index) const {
476-
if (sizeof...(index) > ndim())
477-
fail_dim_check(sizeof...(index), "too many indices for an array");
529+
access_policy::check_indices(ndim(), index...);
478530
return byte_offset(size_t(index)...);
479531
}
480532

@@ -487,15 +539,15 @@ class array : public buffer {
487539
}
488540

489541
/// Return a new view with all of the dimensions of length 1 removed
490-
array squeeze() {
542+
array_base squeeze() {
491543
auto& api = detail::npy_api::get();
492-
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
544+
return reinterpret_steal<array_base>(api.PyArray_Squeeze_(m_ptr));
493545
}
494546

495547
/// Ensure that the argument is a NumPy array
496548
/// In case of an error, nullptr is returned and the Python error is cleared.
497-
static array ensure(handle h, int ExtraFlags = 0) {
498-
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
549+
static array_base ensure(handle h, int ExtraFlags = 0) {
550+
auto result = reinterpret_steal<array_base>(raw_array(h.ptr(), ExtraFlags));
499551
if (!result)
500552
PyErr_Clear();
501553
return result;
@@ -504,13 +556,8 @@ class array : public buffer {
504556
protected:
505557
template<typename, typename> friend struct detail::npy_format_descriptor;
506558

507-
void fail_dim_check(size_t dim, const std::string& msg) const {
508-
throw index_error(msg + ": " + std::to_string(dim) +
509-
" (ndim = " + std::to_string(ndim()) + ")");
510-
}
511-
512559
template<typename... Ix> size_t byte_offset(Ix... index) const {
513-
check_dimensions(index...);
560+
access_policy::check_dimensions(shape(), index...);
514561
return byte_offset_unsafe(index...);
515562
}
516563

@@ -537,21 +584,6 @@ class array : public buffer {
537584
return strides;
538585
}
539586

540-
template<typename... Ix> void check_dimensions(Ix... index) const {
541-
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
542-
}
543-
544-
void check_dimensions_impl(size_t, const size_t*) const { }
545-
546-
template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
547-
if (i >= *shape) {
548-
throw index_error(std::string("index ") + std::to_string(i) +
549-
" is out of bounds for axis " + std::to_string(axis) +
550-
" with size " + std::to_string(*shape));
551-
}
552-
check_dimensions_impl(axis + 1, shape + 1, index...);
553-
}
554-
555587
/// Create array from any object -- always returns a new reference
556588
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
557589
if (ptr == nullptr)
@@ -561,64 +593,69 @@ class array : public buffer {
561593
}
562594
};
563595

564-
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
596+
using array = array_base<safe_access_policy>;
597+
using array_unchecked = array_base<unsafe_access_policy>;
598+
599+
template <typename T, int ExtraFlags = array_base<>::forcecast, class access_policy = safe_access_policy>
600+
class array_t : public array_base<access_policy> {
565601
public:
566-
array_t() : array(0, static_cast<const T *>(nullptr)) {}
567-
array_t(handle h, borrowed_t) : array(h, borrowed) { }
568-
array_t(handle h, stolen_t) : array(h, stolen) { }
602+
using base_type = array_base<access_policy>;
603+
array_t() : base_type(0, static_cast<const T *>(nullptr)) {}
604+
array_t(handle h, object::borrowed_t) : base_type(h, object::borrowed) { }
605+
array_t(handle h, object::stolen_t) : base_type(h, object::stolen) { }
569606

570607
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
571-
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
572-
if (!m_ptr) PyErr_Clear();
608+
array_t(handle h, bool is_borrowed) : base_type(raw_array_t(h.ptr()), object::stolen) {
609+
if (!this->m_ptr) PyErr_Clear();
573610
if (!is_borrowed) Py_XDECREF(h.ptr());
574611
}
575612

576-
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
577-
if (!m_ptr) throw error_already_set();
613+
array_t(const object &o) : base_type(raw_array_t(o.ptr()), object::stolen) {
614+
if (!this->m_ptr) throw error_already_set();
578615
}
579616

580-
explicit array_t(const buffer_info& info) : array(info) { }
617+
explicit array_t(const buffer_info& info) : base_type(info) { }
581618

582619
array_t(const std::vector<size_t> &shape,
583620
const std::vector<size_t> &strides, const T *ptr = nullptr,
584621
handle base = handle())
585-
: array(shape, strides, ptr, base) { }
622+
: base_type(shape, strides, ptr, base) { }
586623

587624
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
588625
handle base = handle())
589-
: array(shape, ptr, base) { }
626+
: base_type(shape, ptr, base) { }
590627

591628
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
592-
: array(count, ptr, base) { }
629+
: base_type(count, ptr, base) { }
593630

594631
constexpr size_t itemsize() const {
595632
return sizeof(T);
596633
}
597634

598635
template<typename... Ix> size_t index_at(Ix... index) const {
599-
return offset_at(index...) / itemsize();
636+
return base_type::offset_at(index...) / itemsize();
600637
}
601638

602639
template<typename... Ix> const T* data(Ix... index) const {
603-
return static_cast<const T*>(array::data(index...));
640+
return static_cast<const T*>(base_type::data(index...));
604641
}
605642

606643
template<typename... Ix> T* mutable_data(Ix... index) {
607-
return static_cast<T*>(array::mutable_data(index...));
644+
return static_cast<T*>(base_type::mutable_data(index...));
608645
}
609646

610647
// Reference to element at a given index
611648
template<typename... Ix> const T& at(Ix... index) const {
612-
if (sizeof...(index) != ndim())
613-
fail_dim_check(sizeof...(index), "index dimension mismatch");
614-
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
649+
if (sizeof...(index) != base_type::ndim())
650+
detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch");
651+
return *(static_cast<const T*>(base_type::data()) + base_type::byte_offset(size_t(index)...) / itemsize());
615652
}
616653

617654
// Mutable reference to element at a given index
618655
template<typename... Ix> T& mutable_at(Ix... index) {
619-
if (sizeof...(index) != ndim())
620-
fail_dim_check(sizeof...(index), "index dimension mismatch");
621-
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
656+
if (sizeof...(index) != base_type::ndim())
657+
detail::fail_dim_check(sizeof...(index), base_type::ndim(), "index dimension mismatch");
658+
return *(static_cast<T*>(base_type::mutable_data()) + base_type::byte_offset(size_t(index)...) / itemsize());
622659
}
623660

624661
/// Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +848,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811848

812849
// Sanity check: verify that NumPy properly parses our buffer format string
813850
auto& api = npy_api::get();
814-
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
851+
auto arr = array_base<>(buffer_info(nullptr, itemsize, format_str, 1));
815852
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
816853
pybind11_fail("NumPy: invalid buffer descriptor!");
817854

@@ -1076,11 +1113,11 @@ struct vectorize_helper {
10761113
template <typename T>
10771114
explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
10781115

1079-
object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
1116+
object operator()(array_t<Args, array_base<>::c_style | array_base<>::forcecast>... args) {
10801117
return run(args..., make_index_sequence<sizeof...(Args)>());
10811118
}
10821119

1083-
template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
1120+
template <size_t ... Index> object run(array_t<Args, array_base<>::c_style | array_base<>::forcecast>&... args, index_sequence<Index...> index) {
10841121
/* Request buffers from all parameters */
10851122
const size_t N = sizeof...(Args);
10861123

0 commit comments

Comments
 (0)