Skip to content

Commit 56784c4

Browse files
Add unchecked_reference::operator() and operator[] to overload resolution of unchecked_mutable_reference (#2514)
1 parent 2b6b98e commit 56784c4

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

include/pybind11/numpy.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,10 @@ class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
419419
using ConstBase::ConstBase;
420420
using ConstBase::Dynamic;
421421
public:
422+
// Bring in const-qualified versions from base class
423+
using ConstBase::operator();
424+
using ConstBase::operator[];
425+
422426
/// Mutable, unchecked access to data at the given indices.
423427
template <typename... Ix> T& operator()(Ix... index) {
424428
static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic,

tests/test_numpy_array.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,18 @@ TEST_SUBMODULE(numpy_array, sm) {
318318
return auxiliaries(r, r2);
319319
});
320320

321+
sm.def("proxy_auxiliaries1_const_ref", [](py::array_t<double> a) {
322+
const auto &r = a.unchecked<1>();
323+
const auto &r2 = a.mutable_unchecked<1>();
324+
return r(0) == r2(0) && r[0] == r2[0];
325+
});
326+
327+
sm.def("proxy_auxiliaries2_const_ref", [](py::array_t<double> a) {
328+
const auto &r = a.unchecked<2>();
329+
const auto &r2 = a.mutable_unchecked<2>();
330+
return r(0, 0) == r2(0, 0);
331+
});
332+
321333
// test_array_unchecked_dyn_dims
322334
// Same as the above, but without a compile-time dimensions specification:
323335
sm.def("proxy_add2_dyn", [](py::array_t<double> a, double v) {

tests/test_numpy_array.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ def test_array_unchecked_fixed_dims(msg):
364364
assert m.proxy_auxiliaries2(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32]
365365
assert m.proxy_auxiliaries2(z1) == m.array_auxiliaries2(z1)
366366

367+
assert m.proxy_auxiliaries1_const_ref(z1[0, :])
368+
assert m.proxy_auxiliaries2_const_ref(z1)
369+
367370

368371
def test_array_unchecked_dyn_dims(msg):
369372
z1 = np.array([[1, 2], [3, 4]], dtype='float64')

0 commit comments

Comments
 (0)