@@ -316,19 +316,72 @@ class dtype : public object {
316
316
}
317
317
};
318
318
319
- class array : public buffer {
319
+ NAMESPACE_BEGIN (detail)
320
+ [[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
321
+ throw index_error (msg + " : " + std::to_string (dim) + " (ndim = " + std::to_string (ndim) + " )" );
322
+ }
323
+ NAMESPACE_END (detail)
324
+
325
+ class safe_access_policy {
326
+ public:
327
+ void check_axis (size_t dim, size_t ndim) const {
328
+ if (dim >= ndim)
329
+ detail::fail_dim_check (dim, ndim, " invalid axis" );
330
+ }
331
+
332
+ template <typename ... Ix>
333
+ void check_indices (size_t ndim, Ix...) const {
334
+ if (sizeof ...(Ix) > ndim)
335
+ detail::fail_dim_check (sizeof ...(Ix), ndim, " too many indices for an array" );
336
+ }
337
+
338
+ template <typename ... Ix>
339
+ void check_dimensions (const size_t * shape, Ix... index ) const {
340
+ check_dimensions_impl (size_t (0 ), shape, size_t (index )...);
341
+ }
342
+
343
+ private:
344
+ void check_dimensions_impl (size_t , const size_t *) const { }
345
+
346
+ template <typename ... Ix>
347
+ void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index ) const {
348
+ if (i >= *shape) {
349
+ throw index_error (std::string (" index " ) + std::to_string (i) +
350
+ " is out of bounds for axis " + std::to_string (axis) +
351
+ " with size " + std::to_string (*shape));
352
+ }
353
+ check_dimensions_impl (axis + 1 , shape + 1 , index ...);
354
+ }
355
+ };
356
+
357
+ class unsafe_access_policy {
320
358
public:
321
- PYBIND11_OBJECT_CVT (array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
359
+ void check_axis (size_t , size_t ) const {
360
+ }
361
+
362
+ template <typename ... Ix>
363
+ void check_indices (size_t , Ix...) const {
364
+ }
365
+
366
+ template <typename ... Ix>
367
+ void check_dimensions (const size_t *, Ix...) const {
368
+ }
369
+ };
370
+
371
+ template <class access_policy = safe_access_policy>
372
+ class array_base : public buffer , private access_policy {
373
+ public:
374
+ PYBIND11_OBJECT_CVT (array_base, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322
375
323
376
enum {
324
377
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
325
378
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
326
379
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
327
380
};
328
381
329
- array () : array (0 , static_cast <const double *>(nullptr )) {}
382
+ array_base () : array_base (0 , static_cast <const double *>(nullptr )) {}
330
383
331
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
384
+ array_base (const pybind11::dtype &dt, const std::vector<size_t > &shape,
332
385
const std::vector<size_t > &strides, const void *ptr = nullptr ,
333
386
handle base = handle()) {
334
387
auto & api = detail::npy_api::get ();
@@ -339,9 +392,9 @@ class array : public buffer {
339
392
340
393
int flags = 0 ;
341
394
if (base && ptr) {
342
- if (isinstance<array >(base))
395
+ if (isinstance<array_base >(base))
343
396
/* Copy flags from base (except baseship bit) */
344
- flags = reinterpret_borrow<array >(base).flags () & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
397
+ flags = reinterpret_borrow<array_base >(base).flags () & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
345
398
else
346
399
/* Writable by default, easy to downgrade later on if needed */
347
400
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -362,30 +415,30 @@ class array : public buffer {
362
415
m_ptr = tmp.release ().ptr ();
363
416
}
364
417
365
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
418
+ array_base (const pybind11::dtype &dt, const std::vector<size_t > &shape,
366
419
const void *ptr = nullptr , handle base = handle())
367
- : array (dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
420
+ : array_base (dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368
421
369
- array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
422
+ array_base (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
370
423
handle base = handle())
371
- : array (dt, std::vector<size_t >{ count }, ptr, base) { }
424
+ : array_base (dt, std::vector<size_t >{ count }, ptr, base) { }
372
425
373
- template <typename T> array (const std::vector<size_t >& shape,
426
+ template <typename T> array_base (const std::vector<size_t >& shape,
374
427
const std::vector<size_t >& strides,
375
428
const T* ptr, handle base = handle())
376
- : array (pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
429
+ : array_base (pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
377
430
378
431
template <typename T>
379
- array (const std::vector<size_t > &shape, const T *ptr,
432
+ array_base (const std::vector<size_t > &shape, const T *ptr,
380
433
handle base = handle())
381
- : array (shape, default_strides(shape, sizeof (T)), ptr, base) { }
434
+ : array_base (shape, default_strides(shape, sizeof (T)), ptr, base) { }
382
435
383
436
template <typename T>
384
- array (size_t count, const T *ptr, handle base = handle())
385
- : array (std::vector<size_t >{ count }, ptr, base) { }
437
+ array_base (size_t count, const T *ptr, handle base = handle())
438
+ : array_base (std::vector<size_t >{ count }, ptr, base) { }
386
439
387
- explicit array (const buffer_info &info)
388
- : array (pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
440
+ explicit array_base (const buffer_info &info)
441
+ : array_base (pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
389
442
390
443
// / Array descriptor (dtype)
391
444
pybind11::dtype dtype () const {
@@ -424,8 +477,7 @@ class array : public buffer {
424
477
425
478
// / Dimension along a given axis
426
479
size_t shape (size_t dim) const {
427
- if (dim >= ndim ())
428
- fail_dim_check (dim, " invalid axis" );
480
+ access_policy::check_axis (dim, ndim ());
429
481
return shape ()[dim];
430
482
}
431
483
@@ -436,8 +488,7 @@ class array : public buffer {
436
488
437
489
// / Stride along a given axis
438
490
size_t strides (size_t dim) const {
439
- if (dim >= ndim ())
440
- fail_dim_check (dim, " invalid axis" );
491
+ access_policy::check_axis (dim, ndim ());
441
492
return strides ()[dim];
442
493
}
443
494
@@ -473,8 +524,7 @@ class array : public buffer {
473
524
// / Byte offset from beginning of the array to a given index (full or partial).
474
525
// / May throw if the index would lead to out of bounds access.
475
526
template <typename ... Ix> size_t offset_at (Ix... index) const {
476
- if (sizeof ...(index ) > ndim ())
477
- fail_dim_check (sizeof ...(index ), " too many indices for an array" );
527
+ access_policy::check_indices (ndim (), index ...);
478
528
return byte_offset (size_t (index )...);
479
529
}
480
530
@@ -487,15 +537,15 @@ class array : public buffer {
487
537
}
488
538
489
539
// / Return a new view with all of the dimensions of length 1 removed
490
- array squeeze () {
540
+ array_base squeeze () {
491
541
auto & api = detail::npy_api::get ();
492
- return reinterpret_steal<array >(api.PyArray_Squeeze_ (m_ptr));
542
+ return reinterpret_steal<array_base >(api.PyArray_Squeeze_ (m_ptr));
493
543
}
494
544
495
545
// / Ensure that the argument is a NumPy array
496
546
// / In case of an error, nullptr is returned and the Python error is cleared.
497
- static array ensure (handle h, int ExtraFlags = 0 ) {
498
- auto result = reinterpret_steal<array >(raw_array (h.ptr (), ExtraFlags));
547
+ static array_base ensure (handle h, int ExtraFlags = 0 ) {
548
+ auto result = reinterpret_steal<array_base >(raw_array (h.ptr (), ExtraFlags));
499
549
if (!result)
500
550
PyErr_Clear ();
501
551
return result;
@@ -504,13 +554,8 @@ class array : public buffer {
504
554
protected:
505
555
template <typename , typename > friend struct detail ::npy_format_descriptor;
506
556
507
- void fail_dim_check (size_t dim, const std::string& msg) const {
508
- throw index_error (msg + " : " + std::to_string (dim) +
509
- " (ndim = " + std::to_string (ndim ()) + " )" );
510
- }
511
-
512
557
template <typename ... Ix> size_t byte_offset (Ix... index) const {
513
- check_dimensions (index ...);
558
+ access_policy:: check_dimensions (shape (), index ...);
514
559
return byte_offset_unsafe (index ...);
515
560
}
516
561
@@ -537,21 +582,6 @@ class array : public buffer {
537
582
return strides;
538
583
}
539
584
540
- template <typename ... Ix> void check_dimensions (Ix... index) const {
541
- check_dimensions_impl (size_t (0 ), shape (), size_t (index )...);
542
- }
543
-
544
- void check_dimensions_impl (size_t , const size_t *) const { }
545
-
546
- template <typename ... Ix> void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index) const {
547
- if (i >= *shape) {
548
- throw index_error (std::string (" index " ) + std::to_string (i) +
549
- " is out of bounds for axis " + std::to_string (axis) +
550
- " with size " + std::to_string (*shape));
551
- }
552
- check_dimensions_impl (axis + 1 , shape + 1 , index ...);
553
- }
554
-
555
585
// / Create array from any object -- always returns a new reference
556
586
static PyObject *raw_array (PyObject *ptr, int ExtraFlags = 0 ) {
557
587
if (ptr == nullptr )
@@ -561,64 +591,69 @@ class array : public buffer {
561
591
}
562
592
};
563
593
564
- template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
594
+ using array = array_base<safe_access_policy>;
595
+ using array_unchecked = array_base<unsafe_access_policy>;
596
+
597
+ template <typename T, int ExtraFlags = array_base<>::forcecast, class access_policy = safe_access_policy>
598
+ class array_t : public array_base <access_policy> {
565
599
public:
566
- array_t () : array(0 , static_cast <const T *>(nullptr )) {}
567
- array_t (handle h, borrowed_t ) : array(h, borrowed) { }
568
- array_t (handle h, stolen_t ) : array(h, stolen) { }
600
+ using base_type = array_base<access_policy>;
601
+ array_t () : base_type(0 , static_cast <const T *>(nullptr )) {}
602
+ array_t (handle h, object::borrowed_t ) : base_type(h, object::borrowed) { }
603
+ array_t (handle h, object::stolen_t ) : base_type(h, object::stolen) { }
569
604
570
605
PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
571
- array_t (handle h, bool is_borrowed) : array (raw_array_t (h.ptr()), stolen) {
572
- if (!m_ptr) PyErr_Clear ();
606
+ array_t (handle h, bool is_borrowed) : base_type (raw_array_t (h.ptr()), object:: stolen) {
607
+ if (!this -> m_ptr ) PyErr_Clear ();
573
608
if (!is_borrowed) Py_XDECREF (h.ptr ());
574
609
}
575
610
576
- array_t (const object &o) : array (raw_array_t (o.ptr()), stolen) {
577
- if (!m_ptr) throw error_already_set ();
611
+ array_t (const object &o) : base_type (raw_array_t (o.ptr()), object:: stolen) {
612
+ if (!this -> m_ptr ) throw error_already_set ();
578
613
}
579
614
580
- explicit array_t (const buffer_info& info) : array (info) { }
615
+ explicit array_t (const buffer_info& info) : base_type (info) { }
581
616
582
617
array_t (const std::vector<size_t > &shape,
583
618
const std::vector<size_t > &strides, const T *ptr = nullptr ,
584
619
handle base = handle())
585
- : array (shape, strides, ptr, base) { }
620
+ : base_type (shape, strides, ptr, base) { }
586
621
587
622
explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
588
623
handle base = handle())
589
- : array (shape, ptr, base) { }
624
+ : base_type (shape, ptr, base) { }
590
625
591
626
explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
592
- : array (count, ptr, base) { }
627
+ : base_type (count, ptr, base) { }
593
628
594
629
constexpr size_t itemsize () const {
595
630
return sizeof (T);
596
631
}
597
632
598
633
template <typename ... Ix> size_t index_at (Ix... index) const {
599
- return offset_at (index ...) / itemsize ();
634
+ return base_type:: offset_at (index ...) / itemsize ();
600
635
}
601
636
602
637
template <typename ... Ix> const T* data (Ix... index) const {
603
- return static_cast <const T*>(array ::data (index ...));
638
+ return static_cast <const T*>(base_type ::data (index ...));
604
639
}
605
640
606
641
template <typename ... Ix> T* mutable_data (Ix... index) {
607
- return static_cast <T*>(array ::mutable_data (index ...));
642
+ return static_cast <T*>(base_type ::mutable_data (index ...));
608
643
}
609
644
610
645
// Reference to element at a given index
611
646
template <typename ... Ix> const T& at (Ix... index) const {
612
- if (sizeof ...(index ) != ndim ())
613
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
614
- return *(static_cast <const T*>(array ::data ()) + byte_offset (size_t (index )...) / itemsize ());
647
+ if (sizeof ...(index ) != base_type:: ndim ())
648
+ detail:: fail_dim_check (sizeof ...(index ), base_type::ndim ( ), " index dimension mismatch" );
649
+ return *(static_cast <const T*>(base_type ::data ()) + base_type:: byte_offset (size_t (index )...) / itemsize ());
615
650
}
616
651
617
652
// Mutable reference to element at a given index
618
653
template <typename ... Ix> T& mutable_at (Ix... index) {
619
- if (sizeof ...(index ) != ndim ())
620
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
621
- return *(static_cast <T*>(array ::mutable_data ()) + byte_offset (size_t (index )...) / itemsize ());
654
+ if (sizeof ...(index ) != base_type:: ndim ())
655
+ detail:: fail_dim_check (sizeof ...(index ), base_type::ndim ( ), " index dimension mismatch" );
656
+ return *(static_cast <T*>(base_type ::mutable_data ()) + base_type:: byte_offset (size_t (index )...) / itemsize ());
622
657
}
623
658
624
659
// / Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +846,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811
846
812
847
// Sanity check: verify that NumPy properly parses our buffer format string
813
848
auto & api = npy_api::get ();
814
- auto arr = array (buffer_info (nullptr , itemsize, format_str, 1 ));
849
+ auto arr = array_base<> (buffer_info (nullptr , itemsize, format_str, 1 ));
815
850
if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
816
851
pybind11_fail (" NumPy: invalid buffer descriptor!" );
817
852
@@ -1076,11 +1111,11 @@ struct vectorize_helper {
1076
1111
template <typename T>
1077
1112
explicit vectorize_helper (T&&f) : f(std::forward<T>(f)) { }
1078
1113
1079
- object operator ()(array_t <Args, array ::c_style | array ::forcecast>... args) {
1114
+ object operator ()(array_t <Args, array_base<> ::c_style | array_base<> ::forcecast>... args) {
1080
1115
return run (args..., make_index_sequence<sizeof ...(Args)>());
1081
1116
}
1082
1117
1083
- template <size_t ... Index> object run (array_t <Args, array ::c_style | array ::forcecast>&... args, index_sequence<Index...> index) {
1118
+ template <size_t ... Index> object run (array_t <Args, array_base<> ::c_style | array_base<> ::forcecast>&... args, index_sequence<Index...> index) {
1084
1119
/* Request buffers from all parameters */
1085
1120
const size_t N = sizeof ...(Args);
1086
1121
0 commit comments