@@ -305,14 +305,16 @@ class dtype : public object {
305
305
306
306
class array : public buffer {
307
307
public:
308
- PYBIND11_OBJECT_DEFAULT (array, buffer, detail::npy_api::get().PyArray_Check_)
308
+ PYBIND11_OBJECT_CVT (array, buffer, detail::npy_api::get().PyArray_Check_, raw_array )
309
309
310
310
enum {
311
311
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
312
312
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
313
313
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
314
314
};
315
315
316
+ array () : array(0 , static_cast <const double *>(nullptr )) {}
317
+
316
318
array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
317
319
const std::vector<size_t > &strides, const void *ptr = nullptr ,
318
320
handle base = handle()) {
@@ -478,10 +480,12 @@ class array : public buffer {
478
480
}
479
481
480
482
// / Ensure that the argument is a NumPy array
481
- static array ensure (object input, int ExtraFlags = 0 ) {
482
- auto & api = detail::npy_api::get ();
483
- return reinterpret_steal<array>(api.PyArray_FromAny_ (
484
- input.release ().ptr (), nullptr , 0 , 0 , detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr ));
483
+ // / In case of an error, nullptr is returned and the Python error is cleared.
484
+ static array ensure (handle h, int ExtraFlags = 0 ) {
485
+ auto result = reinterpret_steal<array>(raw_array (h.ptr (), ExtraFlags));
486
+ if (!result)
487
+ PyErr_Clear ();
488
+ return result;
485
489
}
486
490
487
491
protected:
@@ -521,15 +525,31 @@ class array : public buffer {
521
525
}
522
526
return strides;
523
527
}
528
+
529
+ // / Create array from any object -- always returns a new reference
530
+ static PyObject *raw_array (PyObject *ptr, int ExtraFlags = 0 ) {
531
+ if (ptr == nullptr )
532
+ return nullptr ;
533
+ return detail::npy_api::get ().PyArray_FromAny_ (
534
+ ptr, nullptr , 0 , 0 , detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
535
+ }
524
536
};
525
537
526
538
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
527
539
public:
528
- array_t () : array() { }
540
+ array_t () : array(0 , static_cast <const T *>(nullptr )) {}
541
+ array_t (handle h, borrowed_t ) : array(h, borrowed) { }
542
+ array_t (handle h, stolen_t ) : array(h, stolen) { }
529
543
530
- array_t (handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_ (m_ptr); }
544
+ PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
545
+ array_t (handle h, bool is_borrowed) : array(raw_array_t (h.ptr()), stolen) {
546
+ if (!m_ptr) PyErr_Clear ();
547
+ if (!is_borrowed) Py_XDECREF (h.ptr ());
548
+ }
531
549
532
- array_t (const object &o) : array(o) { m_ptr = ensure_ (m_ptr); }
550
+ array_t (const object &o) : array(raw_array_t (o.ptr()), stolen) {
551
+ if (!m_ptr) throw error_already_set ();
552
+ }
533
553
534
554
explicit array_t (const buffer_info& info) : array(info) { }
535
555
@@ -577,17 +597,30 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
577
597
return *(static_cast <T*>(array::mutable_data ()) + get_byte_offset (index ...) / itemsize ());
578
598
}
579
599
580
- static PyObject *ensure_ (PyObject *ptr) {
581
- if (ptr == nullptr )
582
- return nullptr ;
583
- auto & api = detail::npy_api::get ();
584
- PyObject *result = api.PyArray_FromAny_ (ptr, pybind11::dtype::of<T>().release ().ptr (), 0 , 0 ,
585
- detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
600
+ // / Ensure that the argument is a NumPy array of the correct dtype.
601
+ // / In case of an error, nullptr is returned and the Python error is cleared.
602
+ static array_t ensure (handle h) {
603
+ auto result = reinterpret_steal<array_t >(raw_array_t (h.ptr ()));
586
604
if (!result)
587
605
PyErr_Clear ();
588
- Py_DECREF (ptr);
589
606
return result;
590
607
}
608
+
609
+ static bool _check (handle h) {
610
+ const auto &api = detail::npy_api::get ();
611
+ return api.PyArray_Check_ (h.ptr ())
612
+ && api.PyArray_EquivTypes_ (PyArray_GET_ (h.ptr (), descr), dtype::of<T>().ptr ());
613
+ }
614
+
615
+ protected:
616
+ // / Create array from any object -- always returns a new reference
617
+ static PyObject *raw_array_t (PyObject *ptr) {
618
+ if (ptr == nullptr )
619
+ return nullptr ;
620
+ return detail::npy_api::get ().PyArray_FromAny_ (
621
+ ptr, dtype::of<T>().release ().ptr (), 0 , 0 ,
622
+ detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
623
+ }
591
624
};
592
625
593
626
template <typename T>
@@ -618,7 +651,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
618
651
using type = array_t <T, ExtraFlags>;
619
652
620
653
bool load (handle src, bool /* convert */ ) {
621
- value = type (src, true );
654
+ value = type::ensure (src);
622
655
return static_cast <bool >(value);
623
656
}
624
657
0 commit comments