@@ -316,19 +316,74 @@ class dtype : public object {
316
316
}
317
317
};
318
318
319
- class array : public buffer {
319
+ NAMESPACE_BEGIN (detail)
320
+ [[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
321
+ throw index_error (msg + " : " + std::to_string (dim) + " (ndim = " + std::to_string (ndim) + " )" );
322
+ }
323
+ NAMESPACE_END (detail)
324
+
325
+ class safe_access_policy {
320
326
public:
321
- PYBIND11_OBJECT_CVT (array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
327
+ void check_axis (size_t dim, size_t ndim) const {
328
+ if (dim >= ndim) {
329
+ detail::fail_dim_check (dim, ndim, " invalid axis" );
330
+ }
331
+ }
332
+
333
+ template <typename ... Ix>
334
+ void check_indices (size_t ndim, Ix...) const {
335
+ if (sizeof ...(Ix) > ndim) {
336
+ detail::fail_dim_check (sizeof ...(Ix), ndim, " too many indices for an array" );
337
+ }
338
+ }
339
+
340
+ template <typename ... Ix>
341
+ void check_dimensions (const size_t * shape, Ix... index ) const {
342
+ check_dimensions_impl (size_t (0 ), shape, size_t (index )...);
343
+ }
344
+
345
+ private:
346
+ void check_dimensions_impl (size_t , const size_t *) const { }
347
+
348
+ template <typename ... Ix>
349
+ void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index ) const {
350
+ if (i >= *shape) {
351
+ throw index_error (std::string (" index " ) + std::to_string (i) +
352
+ " is out of bounds for axis " + std::to_string (axis) +
353
+ " with size " + std::to_string (*shape));
354
+ }
355
+ check_dimensions_impl (axis + 1 , shape + 1 , index ...);
356
+ }
357
+ };
358
+
359
+ class unsafe_access_policy {
360
+ public:
361
+ void check_axis (size_t , size_t ) const {
362
+ }
363
+
364
+ template <typename ... Ix>
365
+ void check_indices (size_t , Ix...) const {
366
+ }
367
+
368
+ template <typename ... Ix>
369
+ void check_dimensions (const size_t *, Ix...) const {
370
+ }
371
+ };
372
+
373
+ template <class access_policy = safe_access_policy>
374
+ class array_base : public buffer , private access_policy {
375
+ public:
376
+ PYBIND11_OBJECT_CVT (array_base, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322
377
323
378
enum {
324
379
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
325
380
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
326
381
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
327
382
};
328
383
329
- array () : array (0 , static_cast <const double *>(nullptr )) {}
384
+ array_base () : array_base (0 , static_cast <const double *>(nullptr )) {}
330
385
331
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
386
+ array_base (const pybind11::dtype &dt, const std::vector<size_t > &shape,
332
387
const std::vector<size_t > &strides, const void *ptr = nullptr ,
333
388
handle base = handle()) {
334
389
auto & api = detail::npy_api::get ();
@@ -339,9 +394,9 @@ class array : public buffer {
339
394
340
395
int flags = 0 ;
341
396
if (base && ptr) {
342
- if (isinstance<array >(base))
397
+ if (isinstance<array_base >(base))
343
398
/* Copy flags from base (except baseship bit) */
344
- flags = reinterpret_borrow<array >(base).flags () & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
399
+ flags = reinterpret_borrow<array_base >(base).flags () & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
345
400
else
346
401
/* Writable by default, easy to downgrade later on if needed */
347
402
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -362,30 +417,30 @@ class array : public buffer {
362
417
m_ptr = tmp.release ().ptr ();
363
418
}
364
419
365
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
420
+ array_base (const pybind11::dtype &dt, const std::vector<size_t > &shape,
366
421
const void *ptr = nullptr , handle base = handle())
367
- : array (dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
422
+ : array_base (dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368
423
369
- array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
424
+ array_base (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
370
425
handle base = handle())
371
- : array (dt, std::vector<size_t >{ count }, ptr, base) { }
426
+ : array_base (dt, std::vector<size_t >{ count }, ptr, base) { }
372
427
373
- template <typename T> array (const std::vector<size_t >& shape,
428
+ template <typename T> array_base (const std::vector<size_t >& shape,
374
429
const std::vector<size_t >& strides,
375
430
const T* ptr, handle base = handle())
376
- : array (pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
431
+ : array_base (pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
377
432
378
433
template <typename T>
379
- array (const std::vector<size_t > &shape, const T *ptr,
434
+ array_base (const std::vector<size_t > &shape, const T *ptr,
380
435
handle base = handle())
381
- : array (shape, default_strides(shape, sizeof (T)), ptr, base) { }
436
+ : array_base (shape, default_strides(shape, sizeof (T)), ptr, base) { }
382
437
383
438
template <typename T>
384
- array (size_t count, const T *ptr, handle base = handle())
385
- : array (std::vector<size_t >{ count }, ptr, base) { }
439
+ array_base (size_t count, const T *ptr, handle base = handle())
440
+ : array_base (std::vector<size_t >{ count }, ptr, base) { }
386
441
387
- explicit array (const buffer_info &info)
388
- : array (pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
442
+ explicit array_base (const buffer_info &info)
443
+ : array_base (pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
389
444
390
445
// / Array descriptor (dtype)
391
446
pybind11::dtype dtype () const {
@@ -424,8 +479,7 @@ class array : public buffer {
424
479
425
480
// / Dimension along a given axis
426
481
size_t shape (size_t dim) const {
427
- if (dim >= ndim ())
428
- fail_dim_check (dim, " invalid axis" );
482
+ access_policy::check_axis (dim, ndim ());
429
483
return shape ()[dim];
430
484
}
431
485
@@ -436,8 +490,7 @@ class array : public buffer {
436
490
437
491
// / Stride along a given axis
438
492
size_t strides (size_t dim) const {
439
- if (dim >= ndim ())
440
- fail_dim_check (dim, " invalid axis" );
493
+ access_policy::check_axis (dim, ndim ());
441
494
return strides ()[dim];
442
495
}
443
496
@@ -473,8 +526,7 @@ class array : public buffer {
473
526
// / Byte offset from beginning of the array to a given index (full or partial).
474
527
// / May throw if the index would lead to out of bounds access.
475
528
template <typename ... Ix> size_t offset_at (Ix... index) const {
476
- if (sizeof ...(index ) > ndim ())
477
- fail_dim_check (sizeof ...(index ), " too many indices for an array" );
529
+ access_policy::check_indices (ndim (), index ...);
478
530
return byte_offset (size_t (index )...);
479
531
}
480
532
@@ -487,15 +539,15 @@ class array : public buffer {
487
539
}
488
540
489
541
// / Return a new view with all of the dimensions of length 1 removed
490
- array squeeze () {
542
+ array_base squeeze () {
491
543
auto & api = detail::npy_api::get ();
492
- return reinterpret_steal<array >(api.PyArray_Squeeze_ (m_ptr));
544
+ return reinterpret_steal<array_base >(api.PyArray_Squeeze_ (m_ptr));
493
545
}
494
546
495
547
// / Ensure that the argument is a NumPy array
496
548
// / In case of an error, nullptr is returned and the Python error is cleared.
497
- static array ensure (handle h, int ExtraFlags = 0 ) {
498
- auto result = reinterpret_steal<array >(raw_array (h.ptr (), ExtraFlags));
549
+ static array_base ensure (handle h, int ExtraFlags = 0 ) {
550
+ auto result = reinterpret_steal<array_base >(raw_array (h.ptr (), ExtraFlags));
499
551
if (!result)
500
552
PyErr_Clear ();
501
553
return result;
@@ -504,13 +556,8 @@ class array : public buffer {
504
556
protected:
505
557
template <typename , typename > friend struct detail ::npy_format_descriptor;
506
558
507
- void fail_dim_check (size_t dim, const std::string& msg) const {
508
- throw index_error (msg + " : " + std::to_string (dim) +
509
- " (ndim = " + std::to_string (ndim ()) + " )" );
510
- }
511
-
512
559
template <typename ... Ix> size_t byte_offset (Ix... index) const {
513
- check_dimensions (index ...);
560
+ access_policy:: check_dimensions (shape (), index ...);
514
561
return byte_offset_unsafe (index ...);
515
562
}
516
563
@@ -537,21 +584,6 @@ class array : public buffer {
537
584
return strides;
538
585
}
539
586
540
- template <typename ... Ix> void check_dimensions (Ix... index) const {
541
- check_dimensions_impl (size_t (0 ), shape (), size_t (index )...);
542
- }
543
-
544
- void check_dimensions_impl (size_t , const size_t *) const { }
545
-
546
- template <typename ... Ix> void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index) const {
547
- if (i >= *shape) {
548
- throw index_error (std::string (" index " ) + std::to_string (i) +
549
- " is out of bounds for axis " + std::to_string (axis) +
550
- " with size " + std::to_string (*shape));
551
- }
552
- check_dimensions_impl (axis + 1 , shape + 1 , index ...);
553
- }
554
-
555
587
// / Create array from any object -- always returns a new reference
556
588
static PyObject *raw_array (PyObject *ptr, int ExtraFlags = 0 ) {
557
589
if (ptr == nullptr )
@@ -561,64 +593,69 @@ class array : public buffer {
561
593
}
562
594
};
563
595
564
- template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
596
+ using array = array_base<safe_access_policy>;
597
+ using array_unchecked = array_base<unsafe_access_policy>;
598
+
599
+ template <typename T, int ExtraFlags = array_base<>::forcecast, class access_policy = safe_access_policy>
600
+ class array_t : public array_base <access_policy> {
565
601
public:
566
- array_t () : array(0 , static_cast <const T *>(nullptr )) {}
567
- array_t (handle h, borrowed_t ) : array(h, borrowed) { }
568
- array_t (handle h, stolen_t ) : array(h, stolen) { }
602
+ using base_type = array_base<access_policy>;
603
+ array_t () : base_type(0 , static_cast <const T *>(nullptr )) {}
604
+ array_t (handle h, object::borrowed_t ) : base_type(h, object::borrowed) { }
605
+ array_t (handle h, object::stolen_t ) : base_type(h, object::stolen) { }
569
606
570
607
PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
571
- array_t (handle h, bool is_borrowed) : array (raw_array_t (h.ptr()), stolen) {
572
- if (!m_ptr) PyErr_Clear ();
608
+ array_t (handle h, bool is_borrowed) : base_type (raw_array_t (h.ptr()), object:: stolen) {
609
+ if (!this -> m_ptr ) PyErr_Clear ();
573
610
if (!is_borrowed) Py_XDECREF (h.ptr ());
574
611
}
575
612
576
- array_t (const object &o) : array (raw_array_t (o.ptr()), stolen) {
577
- if (!m_ptr) throw error_already_set ();
613
+ array_t (const object &o) : base_type (raw_array_t (o.ptr()), object:: stolen) {
614
+ if (!this -> m_ptr ) throw error_already_set ();
578
615
}
579
616
580
- explicit array_t (const buffer_info& info) : array (info) { }
617
+ explicit array_t (const buffer_info& info) : base_type (info) { }
581
618
582
619
array_t (const std::vector<size_t > &shape,
583
620
const std::vector<size_t > &strides, const T *ptr = nullptr ,
584
621
handle base = handle())
585
- : array (shape, strides, ptr, base) { }
622
+ : base_type (shape, strides, ptr, base) { }
586
623
587
624
explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
588
625
handle base = handle())
589
- : array (shape, ptr, base) { }
626
+ : base_type (shape, ptr, base) { }
590
627
591
628
explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
592
- : array (count, ptr, base) { }
629
+ : base_type (count, ptr, base) { }
593
630
594
631
constexpr size_t itemsize () const {
595
632
return sizeof (T);
596
633
}
597
634
598
635
template <typename ... Ix> size_t index_at (Ix... index) const {
599
- return offset_at (index ...) / itemsize ();
636
+ return base_type:: offset_at (index ...) / itemsize ();
600
637
}
601
638
602
639
template <typename ... Ix> const T* data (Ix... index) const {
603
- return static_cast <const T*>(array ::data (index ...));
640
+ return static_cast <const T*>(base_type ::data (index ...));
604
641
}
605
642
606
643
template <typename ... Ix> T* mutable_data (Ix... index) {
607
- return static_cast <T*>(array ::mutable_data (index ...));
644
+ return static_cast <T*>(base_type ::mutable_data (index ...));
608
645
}
609
646
610
647
// Reference to element at a given index
611
648
template <typename ... Ix> const T& at (Ix... index) const {
612
- if (sizeof ...(index ) != ndim ())
613
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
614
- return *(static_cast <const T*>(array ::data ()) + byte_offset (size_t (index )...) / itemsize ());
649
+ if (sizeof ...(index ) != base_type:: ndim ())
650
+ detail:: fail_dim_check (sizeof ...(index ), base_type::ndim ( ), " index dimension mismatch" );
651
+ return *(static_cast <const T*>(base_type ::data ()) + base_type:: byte_offset (size_t (index )...) / itemsize ());
615
652
}
616
653
617
654
// Mutable reference to element at a given index
618
655
template <typename ... Ix> T& mutable_at (Ix... index) {
619
- if (sizeof ...(index ) != ndim ())
620
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
621
- return *(static_cast <T*>(array ::mutable_data ()) + byte_offset (size_t (index )...) / itemsize ());
656
+ if (sizeof ...(index ) != base_type:: ndim ())
657
+ detail:: fail_dim_check (sizeof ...(index ), base_type::ndim ( ), " index dimension mismatch" );
658
+ return *(static_cast <T*>(base_type ::mutable_data ()) + base_type:: byte_offset (size_t (index )...) / itemsize ());
622
659
}
623
660
624
661
// / Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +848,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811
848
812
849
// Sanity check: verify that NumPy properly parses our buffer format string
813
850
auto & api = npy_api::get ();
814
- auto arr = array (buffer_info (nullptr , itemsize, format_str, 1 ));
851
+ auto arr = array_base<> (buffer_info (nullptr , itemsize, format_str, 1 ));
815
852
if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
816
853
pybind11_fail (" NumPy: invalid buffer descriptor!" );
817
854
@@ -1076,11 +1113,11 @@ struct vectorize_helper {
1076
1113
template <typename T>
1077
1114
explicit vectorize_helper (T&&f) : f(std::forward<T>(f)) { }
1078
1115
1079
- object operator ()(array_t <Args, array ::c_style | array ::forcecast>... args) {
1116
+ object operator ()(array_t <Args, array_base<> ::c_style | array_base<> ::forcecast>... args) {
1080
1117
return run (args..., make_index_sequence<sizeof ...(Args)>());
1081
1118
}
1082
1119
1083
- template <size_t ... Index> object run (array_t <Args, array ::c_style | array ::forcecast>&... args, index_sequence<Index...> index) {
1120
+ template <size_t ... Index> object run (array_t <Args, array_base<> ::c_style | array_base<> ::forcecast>&... args, index_sequence<Index...> index) {
1084
1121
/* Request buffers from all parameters */
1085
1122
const size_t N = sizeof ...(Args);
1086
1123
0 commit comments