Skip to content

Commit 0da7ac7

Browse files
Mmanu ChaturvediEricCousineau-TRI
Mmanu Chaturvedi
authored andcommitted
Add ability to create object matrices
1 parent e763f04 commit 0da7ac7

File tree

4 files changed

+251
-22
lines changed

4 files changed

+251
-22
lines changed

include/pybind11/eigen.h

Lines changed: 144 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>>
113113
template <typename PlainObjectType, int Options, typename StrideType>
114114
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
115115

116+
template <typename Scalar> bool is_pyobject_() {
117+
return static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
118+
}
119+
116120
// Helper struct for extracting information from an Eigen type
117121
template <typename Type_> struct EigenProps {
118122
using Type = Type_;
@@ -145,14 +149,19 @@ template <typename Type_> struct EigenProps {
145149
const auto dims = a.ndim();
146150
if (dims < 1 || dims > 2)
147151
return false;
148-
152+
bool is_pyobject = false;
153+
if (is_pyobject_<Scalar>())
154+
is_pyobject = true;
155+
ssize_t scalar_size = (is_pyobject ? static_cast<ssize_t>(sizeof(PyObject*)) :
156+
static_cast<ssize_t>(sizeof(Scalar)));
149157
if (dims == 2) { // Matrix type: require exact match (or dynamic)
150158

151159
EigenIndex
152160
np_rows = a.shape(0),
153161
np_cols = a.shape(1),
154-
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
155-
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
162+
np_rstride = a.strides(0) / scalar_size,
163+
np_cstride = a.strides(1) / scalar_size;
164+
156165
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
157166
return false;
158167

@@ -162,7 +171,7 @@ template <typename Type_> struct EigenProps {
162171
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
163172
// is used, we want the (single) numpy stride value.
164173
const EigenIndex n = a.shape(0),
165-
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
174+
stride = a.strides(0) / scalar_size;
166175

167176
if (vector) { // Eigen type is a compile-time vector
168177
if (fixed && size != n)
@@ -213,11 +222,51 @@ template <typename Type_> struct EigenProps {
213222
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
214223
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
215224
array a;
216-
if (props::vector)
217-
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
218-
else
219-
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
220-
src.data(), base);
225+
using Scalar = typename props::Type::Scalar;
226+
bool is_pyoject = static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
227+
228+
if (!is_pyoject) {
229+
if (props::vector)
230+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
231+
else
232+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
233+
src.data(), base);
234+
}
235+
else {
236+
if (props::vector) {
237+
a = array(
238+
npy_format_descriptor<Scalar>::dtype(),
239+
{ (size_t) src.size() },
240+
nullptr,
241+
base
242+
);
243+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
244+
for (ssize_t i = 0; i < src.size(); ++i) {
245+
const Scalar src_val = props::fixed_rows ? src(0, i) : src(i, 0);
246+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src_val, policy, base));
247+
if (!value_)
248+
return handle();
249+
a.attr("itemset")(i, value_);
250+
}
251+
}
252+
else {
253+
a = array(
254+
npy_format_descriptor<Scalar>::dtype(),
255+
{(size_t) src.rows(), (size_t) src.cols()},
256+
nullptr,
257+
base
258+
);
259+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
260+
for (ssize_t i = 0; i < src.rows(); ++i) {
261+
for (ssize_t j = 0; j < src.cols(); ++j) {
262+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, base));
263+
if (!value_)
264+
return handle();
265+
a.attr("itemset")(i, j, value_);
266+
}
267+
}
268+
}
269+
}
221270

222271
if (!writeable)
223272
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -271,14 +320,46 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
271320
auto fits = props::conformable(buf);
272321
if (!fits)
273322
return false;
274-
323+
int result = 0;
275324
// Allocate the new type, then build a numpy reference into it
276325
value = Type(fits.rows, fits.cols);
277-
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
278-
if (dims == 1) ref = ref.squeeze();
279-
else if (ref.ndim() == 1) buf = buf.squeeze();
280-
281-
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
326+
bool is_pyobject = is_pyobject_<Scalar>();
327+
328+
if (!is_pyobject) {
329+
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
330+
if (dims == 1) ref = ref.squeeze();
331+
else if (ref.ndim() == 1) buf = buf.squeeze();
332+
result =
333+
detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
334+
}
335+
else {
336+
if (dims == 1) {
337+
if (Type::RowsAtCompileTime == Eigen::Dynamic)
338+
value.resize(buf.shape(0), 1);
339+
if (Type::ColsAtCompileTime == Eigen::Dynamic)
340+
value.resize(1, buf.shape(0));
341+
342+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
343+
make_caster <Scalar> conv_val;
344+
if (!conv_val.load(buf.attr("item")(i).cast<pybind11::object>(), convert))
345+
return false;
346+
value(i) = cast_op<Scalar>(conv_val);
347+
}
348+
} else {
349+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
350+
value.resize(buf.shape(0), buf.shape(1));
351+
}
352+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
353+
for (ssize_t j = 0; j < buf.shape(1); ++j) {
354+
// p is the const void pointer to the item
355+
make_caster<Scalar> conv_val;
356+
if (!conv_val.load(buf.attr("item")(i,j).cast<pybind11::object>(), convert))
357+
return false;
358+
value(i,j) = cast_op<Scalar>(conv_val);
359+
}
360+
}
361+
}
362+
}
282363

283364
if (result < 0) { // Copy failed!
284365
PyErr_Clear();
@@ -430,13 +511,19 @@ struct type_caster<
430511
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
431512
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
432513
Array copy_or_ref;
514+
typename std::remove_cv<PlainObjectType>::type val;
433515
public:
434516
bool load(handle src, bool convert) {
435517
// First check whether what we have is already an array of the right type. If not, we can't
436518
// avoid a copy (because the copy is also going to do type conversion).
437519
bool need_copy = !isinstance<Array>(src);
438520

439521
EigenConformable<props::row_major> fits;
522+
bool is_pyobject = false;
523+
if (is_pyobject_<Scalar>()) {
524+
is_pyobject = true;
525+
need_copy = true;
526+
}
440527
if (!need_copy) {
441528
// We don't need a converting copy, but we also need to check whether the strides are
442529
// compatible with the Ref's stride requirements
@@ -459,15 +546,53 @@ struct type_caster<
459546
// We need to copy: If we need a mutable reference, or we're not supposed to convert
460547
// (either because we're in the no-convert overload pass, or because we're explicitly
461548
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
462-
if (!convert || need_writeable) return false;
549+
if (!is_pyobject && (!convert || need_writeable)) {
550+
return false;
551+
}
463552

464553
Array copy = Array::ensure(src);
465554
if (!copy) return false;
466555
fits = props::conformable(copy);
467-
if (!fits || !fits.template stride_compatible<props>())
556+
if (!fits || !fits.template stride_compatible<props>()) {
468557
return false;
469-
copy_or_ref = std::move(copy);
470-
loader_life_support::add_patient(copy_or_ref);
558+
}
559+
560+
if (!is_pyobject) {
561+
copy_or_ref = std::move(copy);
562+
loader_life_support::add_patient(copy_or_ref);
563+
}
564+
else {
565+
auto dims = copy.ndim();
566+
if (dims == 1) {
567+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
568+
val.resize(copy.shape(0), 1);
569+
}
570+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
571+
make_caster <Scalar> conv_val;
572+
if (!conv_val.load(copy.attr("item")(i).template cast<pybind11::object>(),
573+
convert))
574+
return false;
575+
val(i) = cast_op<Scalar>(conv_val);
576+
577+
}
578+
} else {
579+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
580+
val.resize(copy.shape(0), copy.shape(1));
581+
}
582+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
583+
for (ssize_t j = 0; j < copy.shape(1); ++j) {
584+
// p is the const void pointer to the item
585+
make_caster <Scalar> conv_val;
586+
if (!conv_val.load(copy.attr("item")(i, j).template cast<pybind11::object>(),
587+
convert))
588+
return false;
589+
val(i, j) = cast_op<Scalar>(conv_val);
590+
}
591+
}
592+
}
593+
ref.reset(new Type(val));
594+
return true;
595+
}
471596
}
472597

473598
ref.reset();

include/pybind11/numpy.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,6 +1230,21 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
12301230
(::std::vector<::pybind11::detail::field_descriptor> \
12311231
{PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
12321232

1233+
#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \
1234+
namespace pybind11 { namespace detail { \
1235+
template <> struct npy_format_descriptor<Type> { \
1236+
public: \
1237+
enum { value = npy_api::NPY_OBJECT_ }; \
1238+
static pybind11::dtype dtype() { \
1239+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \
1240+
return reinterpret_borrow<pybind11::dtype>(ptr); \
1241+
} \
1242+
pybind11_fail("Unsupported buffer format!"); \
1243+
} \
1244+
static constexpr auto name = _("object"); \
1245+
}; \
1246+
}}
1247+
12331248
#endif // __CLION_IDE__
12341249

12351250
template <class T>

tests/test_eigen.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
#endif
1818

1919
#include <Eigen/Cholesky>
20+
#include <unsupported/Eigen/AutoDiff>
21+
#include "Eigen/src/Core/util/DisableStupidWarnings.h"
2022

2123
using MatrixXdR = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
22-
23-
24+
typedef Eigen::AutoDiffScalar<Eigen::VectorXd> ADScalar;
25+
typedef Eigen::Matrix<ADScalar, Eigen::Dynamic, 1> VectorXADScalar;
26+
typedef Eigen::Matrix<ADScalar, 1, Eigen::Dynamic> VectorXADScalarR;
27+
PYBIND11_NUMPY_OBJECT_DTYPE(ADScalar);
2428

2529
// Sets/resets a testing reference matrix to have values of 10*r + c, where r and c are the
2630
// (1-based) row/column number.
@@ -79,7 +83,9 @@ TEST_SUBMODULE(eigen, m) {
7983
using FixedMatrixR = Eigen::Matrix<float, 5, 6, Eigen::RowMajor>;
8084
using FixedMatrixC = Eigen::Matrix<float, 5, 6>;
8185
using DenseMatrixR = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
86+
using DenseADScalarMatrixR = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
8287
using DenseMatrixC = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
88+
using DenseADScalarMatrixC = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic>;
8389
using FourRowMatrixC = Eigen::Matrix<float, 4, Eigen::Dynamic>;
8490
using FourColMatrixC = Eigen::Matrix<float, Eigen::Dynamic, 4>;
8591
using FourRowMatrixR = Eigen::Matrix<float, 4, Eigen::Dynamic>;
@@ -91,10 +97,14 @@ TEST_SUBMODULE(eigen, m) {
9197

9298
// various tests
9399
m.def("double_col", [](const Eigen::VectorXf &x) -> Eigen::VectorXf { return 2.0f * x; });
100+
m.def("double_adscalar_col", [](const VectorXADScalar &x) -> VectorXADScalar { return 2.0f * x; });
94101
m.def("double_row", [](const Eigen::RowVectorXf &x) -> Eigen::RowVectorXf { return 2.0f * x; });
102+
m.def("double_adscalar_row", [](const VectorXADScalarR &x) -> VectorXADScalarR { return 2.0f * x; });
95103
m.def("double_complex", [](const Eigen::VectorXcf &x) -> Eigen::VectorXcf { return 2.0f * x; });
96104
m.def("double_threec", [](py::EigenDRef<Eigen::Vector3f> x) { x *= 2; });
105+
m.def("double_adscalarc", [](py::EigenDRef<VectorXADScalar> x) { x *= 2; });
97106
m.def("double_threer", [](py::EigenDRef<Eigen::RowVector3f> x) { x *= 2; });
107+
m.def("double_adscalarr", [](py::EigenDRef<VectorXADScalarR> x) { x *= 2; });
98108
m.def("double_mat_cm", [](Eigen::MatrixXf x) -> Eigen::MatrixXf { return 2.0f * x; });
99109
m.def("double_mat_rm", [](DenseMatrixR x) -> DenseMatrixR { return 2.0f * x; });
100110

@@ -139,6 +149,12 @@ TEST_SUBMODULE(eigen, m) {
139149
return m;
140150
}, py::return_value_policy::reference);
141151

152+
// Increments ADScalar Matrix
153+
m.def("incr_adscalar_matrix", [](Eigen::Ref<DenseADScalarMatrixC> m, double v) {
154+
m += DenseADScalarMatrixC::Constant(m.rows(), m.cols(), v);
155+
return m;
156+
}, py::return_value_policy::reference);
157+
142158
// Same, but accepts a matrix of any strides
143159
m.def("incr_matrix_any", [](py::EigenDRef<Eigen::MatrixXd> m, double v) {
144160
m += Eigen::MatrixXd::Constant(m.rows(), m.cols(), v);
@@ -173,12 +189,16 @@ TEST_SUBMODULE(eigen, m) {
173189
// return value referencing/copying tests:
174190
class ReturnTester {
175191
Eigen::MatrixXd mat = create();
192+
DenseADScalarMatrixR ad_mat = create_ADScalar_mat();
176193
public:
177194
ReturnTester() { print_created(this); }
178195
~ReturnTester() { print_destroyed(this); }
179-
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
196+
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
197+
static DenseADScalarMatrixR create_ADScalar_mat() { DenseADScalarMatrixR ad_mat(2, 2);
198+
ad_mat << 1, 2, 3, 7; return ad_mat; }
180199
static const Eigen::MatrixXd createConst() { return Eigen::MatrixXd::Ones(10, 10); }
181200
Eigen::MatrixXd &get() { return mat; }
201+
DenseADScalarMatrixR& get_ADScalarMat() {return ad_mat;}
182202
Eigen::MatrixXd *getPtr() { return &mat; }
183203
const Eigen::MatrixXd &view() { return mat; }
184204
const Eigen::MatrixXd *viewPtr() { return &mat; }
@@ -197,6 +217,7 @@ TEST_SUBMODULE(eigen, m) {
197217
.def_static("create", &ReturnTester::create)
198218
.def_static("create_const", &ReturnTester::createConst)
199219
.def("get", &ReturnTester::get, rvp::reference_internal)
220+
.def("get_ADScalarMat", &ReturnTester::get_ADScalarMat, rvp::reference_internal)
200221
.def("get_ptr", &ReturnTester::getPtr, rvp::reference_internal)
201222
.def("view", &ReturnTester::view, rvp::reference_internal)
202223
.def("view_ptr", &ReturnTester::view, rvp::reference_internal)
@@ -216,6 +237,18 @@ TEST_SUBMODULE(eigen, m) {
216237
.def("corners_const", &ReturnTester::cornersConst, rvp::reference_internal)
217238
;
218239

240+
py::class_<ADScalar>(m, "AutoDiffXd")
241+
.def("__init__",
242+
[](ADScalar & self,
243+
double value,
244+
const Eigen::VectorXd& derivatives) {
245+
new (&self) ADScalar(value, derivatives);
246+
})
247+
.def("value", [](const ADScalar & self) {
248+
return self.value();
249+
})
250+
;
251+
219252
// test_special_matrix_objects
220253
// Returns a DiagonalMatrix with diagonal (1,2,3,...)
221254
m.def("incr_diag", [](int k) {
@@ -300,6 +333,9 @@ TEST_SUBMODULE(eigen, m) {
300333
m.def("iss1105_col", [](Eigen::VectorXd) { return true; });
301334
m.def("iss1105_row", [](Eigen::RowVectorXd) { return true; });
302335

336+
m.def("iss1105_col_obj", [](VectorXADScalar) { return true; });
337+
m.def("iss1105_row_obj", [](VectorXADScalarR) { return true; });
338+
303339
// test_named_arguments
304340
// Make sure named arguments are working properly:
305341
m.def("matrix_multiply", [](const py::EigenDRef<const Eigen::MatrixXd> A, const py::EigenDRef<const Eigen::MatrixXd> B)

0 commit comments

Comments
 (0)