Skip to content

Commit 1d416ce

Browse files
committed
Make array_t’s converting constructor consistent with other pytypes
* `array` gets a converting constructor * `array_t(const object &)` throws on error * `array_t::ensure()` is intended for casters * `py::isinstance<array_T<T>>()` checks the type (but not flags)
1 parent 8afdce3 commit 1d416ce

File tree

4 files changed

+101
-17
lines changed

4 files changed

+101
-17
lines changed

include/pybind11/eigen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct type_caster<Type, enable_if_t<is_eigen_dense<Type>::value && !is_eigen_re
5454
static constexpr bool isVector = Type::IsVectorAtCompileTime;
5555

5656
bool load(handle src, bool) {
57-
array_t<Scalar> buf(src, true);
57+
auto buf = array_t<Scalar>::ensure(src);
5858
if (!buf)
5959
return false;
6060

include/pybind11/numpy.h

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -305,14 +305,16 @@ class dtype : public object {
305305

306306
class array : public buffer {
307307
public:
308-
PYBIND11_OBJECT_DEFAULT(array, buffer, detail::npy_api::get().PyArray_Check_)
308+
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
309309

310310
enum {
311311
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
312312
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
313313
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
314314
};
315315

316+
array() : array(0, static_cast<const double *>(nullptr)) {}
317+
316318
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
317319
const std::vector<size_t> &strides, const void *ptr = nullptr,
318320
handle base = handle()) {
@@ -478,10 +480,12 @@ class array : public buffer {
478480
}
479481

480482
/// Ensure that the argument is a NumPy array
481-
static array ensure(object input, int ExtraFlags = 0) {
482-
auto& api = detail::npy_api::get();
483-
return reinterpret_steal<array>(api.PyArray_FromAny_(
484-
input.release().ptr(), nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr));
483+
/// In case of an error, nullptr is returned and the Python error is cleared.
484+
static array ensure(handle h, int ExtraFlags = 0) {
485+
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
486+
if (!result)
487+
PyErr_Clear();
488+
return result;
485489
}
486490

487491
protected:
@@ -521,15 +525,31 @@ class array : public buffer {
521525
}
522526
return strides;
523527
}
528+
529+
/// Create array from any object -- always returns a new reference
530+
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
531+
if (ptr == nullptr)
532+
return nullptr;
533+
return detail::npy_api::get().PyArray_FromAny_(
534+
ptr, nullptr, 0, 0, detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
535+
}
524536
};
525537

526538
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
527539
public:
528-
array_t() : array() { }
540+
array_t() : array(0, static_cast<const T *>(nullptr)) {}
541+
array_t(handle h, borrowed_t) : array(h, borrowed) { }
542+
array_t(handle h, stolen_t) : array(h, stolen) { }
529543

530-
array_t(handle h, bool is_borrowed) : array(h, is_borrowed) { m_ptr = ensure_(m_ptr); }
544+
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
545+
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
546+
if (!m_ptr) PyErr_Clear();
547+
if (!is_borrowed) Py_XDECREF(h.ptr());
548+
}
531549

532-
array_t(const object &o) : array(o) { m_ptr = ensure_(m_ptr); }
550+
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
551+
if (!m_ptr) throw error_already_set();
552+
}
533553

534554
explicit array_t(const buffer_info& info) : array(info) { }
535555

@@ -577,17 +597,30 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
577597
return *(static_cast<T*>(array::mutable_data()) + get_byte_offset(index...) / itemsize());
578598
}
579599

580-
static PyObject *ensure_(PyObject *ptr) {
581-
if (ptr == nullptr)
582-
return nullptr;
583-
auto& api = detail::npy_api::get();
584-
PyObject *result = api.PyArray_FromAny_(ptr, pybind11::dtype::of<T>().release().ptr(), 0, 0,
585-
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
600+
/// Ensure that the argument is a NumPy array of the correct dtype.
601+
/// In case of an error, nullptr is returned and the Python error is cleared.
602+
static array_t ensure(handle h) {
603+
auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
586604
if (!result)
587605
PyErr_Clear();
588-
Py_DECREF(ptr);
589606
return result;
590607
}
608+
609+
static bool _check(handle h) {
610+
const auto &api = detail::npy_api::get();
611+
return api.PyArray_Check_(h.ptr())
612+
&& api.PyArray_EquivTypes_(PyArray_GET_(h.ptr(), descr), dtype::of<T>().ptr());
613+
}
614+
615+
protected:
616+
/// Create array from any object -- always returns a new reference
617+
static PyObject *raw_array_t(PyObject *ptr) {
618+
if (ptr == nullptr)
619+
return nullptr;
620+
return detail::npy_api::get().PyArray_FromAny_(
621+
ptr, dtype::of<T>().release().ptr(), 0, 0,
622+
detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr);
623+
}
591624
};
592625

593626
template <typename T>
@@ -618,7 +651,7 @@ struct pyobject_caster<array_t<T, ExtraFlags>> {
618651
using type = array_t<T, ExtraFlags>;
619652

620653
bool load(handle src, bool /* convert */) {
621-
value = type(src, true);
654+
value = type::ensure(src);
622655
return static_cast<bool>(value);
623656
}
624657

tests/test_numpy_array.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,4 +126,28 @@ test_initializer numpy_array([](py::module &m) {
126126
);
127127

128128
sm.def("function_taking_uint64", [](uint64_t) { });
129+
130+
sm.def("isinstance_untyped", [](py::object yes, py::object no) {
131+
return py::isinstance<py::array>(yes) && !py::isinstance<py::array>(no);
132+
});
133+
134+
sm.def("isinstance_typed", [](py::object o) {
135+
return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o);
136+
});
137+
138+
sm.def("default_constructors", []() {
139+
return py::dict(
140+
"array"_a=py::array(),
141+
"array_t<int32>"_a=py::array_t<std::int32_t>(),
142+
"array_t<double>"_a=py::array_t<double>()
143+
);
144+
});
145+
146+
sm.def("converting_constructors", [](py::object o) {
147+
return py::dict(
148+
"array"_a=py::array(o),
149+
"array_t<int32>"_a=py::array_t<std::int32_t>(o),
150+
"array_t<double>"_a=py::array_t<double>(o)
151+
);
152+
});
129153
});

tests/test_numpy_array.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,30 @@ def test_cast_numpy_int64_to_uint64():
245245
from pybind11_tests.array import function_taking_uint64
246246
function_taking_uint64(123)
247247
function_taking_uint64(np.uint64(123))
248+
249+
250+
@pytest.requires_numpy
251+
def test_isinstance():
252+
from pybind11_tests.array import isinstance_untyped, isinstance_typed
253+
254+
assert isinstance_untyped(np.array([1, 2, 3]), "not an array")
255+
assert isinstance_typed(np.array([1.0, 2.0, 3.0]))
256+
257+
258+
@pytest.requires_numpy
259+
def test_constructors():
260+
from pybind11_tests.array import default_constructors, converting_constructors
261+
262+
defaults = default_constructors()
263+
for a in defaults.values():
264+
assert a.size == 0
265+
assert defaults["array"].dtype == np.array([]).dtype
266+
assert defaults["array_t<int32>"].dtype == np.int32
267+
assert defaults["array_t<double>"].dtype == np.float64
268+
269+
results = converting_constructors([1, 2, 3])
270+
for a in results.values():
271+
np.testing.assert_array_equal(a, [1, 2, 3])
272+
assert results["array"].dtype == np.int_
273+
assert results["array_t<int32>"].dtype == np.int32
274+
assert results["array_t<double>"].dtype == np.float64

0 commit comments

Comments
 (0)