Skip to content

Commit aefe2c0

Browse files
committed
roundtrip test via reference passed to aliased class method
Probably the test is failing, because it passes the arguments by value instead of by reference.
1 parent 612a597 commit aefe2c0

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

tests/test_class_sh_with_alias.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,74 @@ void wrap(py::module_ m, const char *py_class_name) {
7373
m.def("AddInCppUniquePtr", AddInCppUniquePtr<SerNo>, py::arg("obj"), py::arg("other_val"));
7474
}
7575

76+
struct Passenger {
77+
std::string mtxt;
78+
Passenger(const std::string &txt = "DefaultCtor") : mtxt(txt) {}
79+
Passenger(const Passenger &other) { mtxt = other.mtxt + "_CpCtor"; }
80+
Passenger(Passenger &&other) { mtxt = other.mtxt + "_MvCtor"; }
81+
};
82+
struct ConsumerBase {
83+
ConsumerBase() = default;
84+
ConsumerBase(const ConsumerBase &) = default;
85+
ConsumerBase(ConsumerBase &&) = default;
86+
virtual ~ConsumerBase() = default;
87+
virtual void pass_uq_cref(const std::unique_ptr<Passenger> &obj) { obj->mtxt += "_base"; };
88+
virtual void pass_lref(Passenger &obj) { obj.mtxt += "_base"; };
89+
virtual void pass_cref(const Passenger &obj) { const_cast<Passenger &>(obj).mtxt += "_base"; };
90+
};
91+
struct ConsumerBaseAlias : ConsumerBase {
92+
using ConsumerBase::ConsumerBase;
93+
void pass_uq_cref(const std::unique_ptr<Passenger> &obj) override {
94+
PYBIND11_OVERRIDE(void, ConsumerBase, pass_uq_cref, obj);
95+
}
96+
void pass_lref(Passenger &obj) override {
97+
PYBIND11_OVERRIDE(void, ConsumerBase, pass_lref, obj);
98+
}
99+
void pass_cref(const Passenger &obj) override {
100+
PYBIND11_OVERRIDE(void, ConsumerBase, pass_cref, obj);
101+
}
102+
};
103+
104+
// check roundtrip of Passenger send to ConsumerBaseAlias
105+
// TODO: Find template magic to avoid code duplication
106+
std::string check_roundtrip_uq_cref(ConsumerBase &consumer) {
107+
std::unique_ptr<Passenger> obj(new Passenger(""));
108+
consumer.pass_uq_cref(obj);
109+
return obj->mtxt;
110+
}
111+
std::string check_roundtrip_lref(ConsumerBase &consumer) {
112+
Passenger obj("");
113+
consumer.pass_lref(obj);
114+
return obj.mtxt;
115+
}
116+
std::string check_roundtrip_cref(ConsumerBase &consumer) {
117+
Passenger obj("");
118+
consumer.pass_cref(obj);
119+
return obj.mtxt;
120+
}
121+
76122
} // namespace test_class_sh_with_alias
77123
} // namespace pybind11_tests
78124

79125
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Abase<0>)
80126
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Abase<1>)
127+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::Passenger)
128+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::test_class_sh_with_alias::ConsumerBase)
81129

82130
TEST_SUBMODULE(class_sh_with_alias, m) {
83131
using namespace pybind11_tests::test_class_sh_with_alias;
84132
wrap<0>(m, "Abase0");
85133
wrap<1>(m, "Abase1");
134+
135+
py::classh<Passenger>(m, "Passenger").def_readwrite("mtxt", &Passenger::mtxt);
136+
137+
py::classh<ConsumerBase, ConsumerBaseAlias>(m, "ConsumerBase")
138+
.def(py::init<>())
139+
.def("pass_uq_cref", &ConsumerBase::pass_uq_cref)
140+
.def("pass_lref", &ConsumerBase::pass_lref)
141+
.def("pass_cref", &ConsumerBase::pass_cref);
142+
143+
m.def("check_roundtrip_uq_cref", check_roundtrip_uq_cref);
144+
m.def("check_roundtrip_lref", check_roundtrip_lref);
145+
m.def("check_roundtrip_cref", check_roundtrip_cref);
86146
}

tests/test_class_sh_with_alias.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# -*- coding: utf-8 -*-
22
import pytest
3+
import env # noqa: F401
34

45
from pybind11_tests import class_sh_with_alias as m
56

@@ -56,3 +57,75 @@ def test_drvd1_add_in_cpp_unique_ptr():
5657
drvd = PyDrvd1(25)
5758
assert m.AddInCppUniquePtr(drvd, 83) == ((25 * 10 + 3) * 200 + 83) * 100 + 13
5859
return # Comment out for manual leak checking (use `top` command).
60+
61+
62+
class PyConsumer1(m.ConsumerBase):
63+
def __init__(self):
64+
m.ConsumerBase.__init__(self)
65+
66+
def pass_uq_cref(self, obj):
67+
obj.mtxt = obj.mtxt + "pass_uq_cref"
68+
69+
def pass_lref(self, obj):
70+
obj.mtxt = obj.mtxt + "pass_lref"
71+
72+
def pass_cref(self, obj):
73+
obj.mtxt = obj.mtxt + "pass_cref"
74+
75+
76+
class PyConsumer2(m.ConsumerBase):
77+
"""This one, additionally to PyConsumer1 calls its base methods"""
78+
79+
def __init__(self):
80+
m.ConsumerBase.__init__(self)
81+
82+
def pass_uq_cref(self, obj):
83+
obj.mtxt = obj.mtxt + "pass_uq_cref"
84+
m.ConsumerBase.pass_uq_cref(self, obj)
85+
86+
def pass_lref(self, obj):
87+
obj.mtxt = obj.mtxt + "pass_lref"
88+
m.ConsumerBase.pass_lref(self, obj)
89+
90+
def pass_cref(self, obj):
91+
obj.mtxt = obj.mtxt + "pass_cref"
92+
m.ConsumerBase.pass_cref(self, obj)
93+
94+
95+
# roundtrip tests, creating an object in C++ that is passed by reference
96+
# to a virtual method of a class derived in Python. Thus:
97+
# C++ -> Python -> C++
98+
@pytest.mark.parametrize(
99+
"f, expected",
100+
[
101+
(m.check_roundtrip_uq_cref, "pass_uq_cref"),
102+
(m.check_roundtrip_lref, "pass_lref"), # modification passed through 1:1
103+
pytest.param(
104+
m.check_roundtrip_cref,
105+
"", # modification lost (forbidden due to constness)
106+
marks=pytest.mark.skipif("env.PYPY"),
107+
),
108+
],
109+
)
110+
def test_unique_ptr_consumer1_roundtrip(f, expected):
111+
c = PyConsumer1()
112+
assert f(c) == expected
113+
114+
115+
@pytest.mark.parametrize(
116+
"f, expected",
117+
[
118+
pytest.param( # cannot (yet) pass unowned const unique_ptr&
119+
m.check_roundtrip_uq_cref, "pass_uq_cref_base", marks=pytest.mark.xfail
120+
),
121+
(m.check_roundtrip_lref, "pass_lref_base"), # modification passed through 1:1
122+
pytest.param( # PYPY always copies the argument instead of passing the reference
123+
m.check_roundtrip_cref,
124+
"", # modification lost (forbidden due to constness)
125+
marks=pytest.mark.skipif("env.PYPY"),
126+
),
127+
],
128+
)
129+
def test_unique_ptr_consumer2_roundtrip(f, expected):
130+
c = PyConsumer2()
131+
assert f(c) == expected

0 commit comments

Comments
 (0)