Skip to content

Passing reference arguments to trampoline methods #2915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -2179,7 +2179,7 @@ template <class T> function get_override(const T *this_ptr, const char *name) {
pybind11::gil_scoped_acquire gil; \
pybind11::function override = pybind11::get_override(static_cast<const cname *>(this), name); \
if (override) { \
auto o = override(__VA_ARGS__); \
auto o = override.operator()<pybind11::return_value_policy::reference>(__VA_ARGS__); \
if (pybind11::detail::cast_is_temporary_value_reference<ret_type>::value) { \
static pybind11::detail::override_caster_t<ret_type> caster; \
return pybind11::detail::cast_ref<ret_type>(std::move(o), caster); \
Expand Down
91 changes: 91 additions & 0 deletions tests/test_virtual_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,79 @@ template <class Base = D_Tpl> class PyD_Tpl : public PyC_Tpl<Base> {
};
*/

/* Tests passing objects by reference to Python-derived class methods */
struct Passenger { // object class passed around, recording its copy and move constructions
std::string mtxt;
Passenger() = default;
// on copy or move: keep old mtxt and augment operation as well as new pointer id
Passenger(const Passenger &other) { mtxt = other.mtxt + "Copy->" + std::to_string(id()); }
Passenger(Passenger &&other) { mtxt = other.mtxt + "Move->" + std::to_string(id()); }
uintptr_t id() const { return reinterpret_cast<uintptr_t>(this); }
};
struct ReferencePassingTest { // virtual base class used to test reference passing
ReferencePassingTest() = default;
ReferencePassingTest(const ReferencePassingTest &) = default;
ReferencePassingTest(ReferencePassingTest &&) = default;
virtual ~ReferencePassingTest() = default;
// NOLINTNEXTLINE(clang-analyzer-core.StackAddrEscapeBase)
virtual uintptr_t pass_valu(Passenger obj) { return modify(obj); };
virtual uintptr_t pass_mref(Passenger &obj) { return modify(obj); };
virtual uintptr_t pass_mptr(Passenger *obj) { return modify(*obj); };
virtual uintptr_t pass_cref(const Passenger &obj) { return modify(obj); };
virtual uintptr_t pass_cptr(const Passenger *obj) { return modify(*obj); };
uintptr_t modify(const Passenger &obj) { return obj.id(); }
uintptr_t modify(Passenger &obj) {
obj.mtxt.append("_MODIFIED");
return obj.id();
}
};
struct PyReferencePassingTest : ReferencePassingTest {
using ReferencePassingTest::ReferencePassingTest;
uintptr_t pass_valu(Passenger obj) override {
PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_valu, obj);
}
uintptr_t pass_mref(Passenger &obj) override {
PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_mref, obj);
}
uintptr_t pass_cref(const Passenger &obj) override {
PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_cref, obj);
}
uintptr_t pass_mptr(Passenger *obj) override {
PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_mptr, obj);
}
uintptr_t pass_cptr(const Passenger *obj) override {
PYBIND11_OVERRIDE(uintptr_t, ReferencePassingTest, pass_cptr, obj);
}
};

std::string evaluate(const Passenger &orig, uintptr_t cycled) {
return orig.mtxt + (orig.id() == cycled ? "_REF" : "_COPY");
}
// Functions triggering virtual-method calls from python-derived class (caller)
// Goal: modifications to Passenger happening in Python-code methods
// overriding the C++ virtual methods, should remain visible in C++.
// TODO: Find template magic to avoid this code duplication
std::string check_roundtrip_valu(ReferencePassingTest &caller) {
Passenger obj;
return evaluate(obj, caller.pass_valu(obj));
}
std::string check_roundtrip_mref(ReferencePassingTest &caller) {
Passenger obj;
return evaluate(obj, caller.pass_mref(obj));
}
std::string check_roundtrip_cref(ReferencePassingTest &caller) {
Passenger obj;
return evaluate(obj, caller.pass_cref(obj));
}
std::string check_roundtrip_mptr(ReferencePassingTest &caller) {
Passenger obj;
return evaluate(obj, caller.pass_mptr(&obj));
}
std::string check_roundtrip_cptr(ReferencePassingTest &caller) {
Passenger obj;
return evaluate(obj, caller.pass_cptr(&obj));
}

void initialize_inherited_virtuals(py::module_ &m) {
// test_inherited_virtuals

Expand Down Expand Up @@ -495,4 +568,22 @@ void initialize_inherited_virtuals(py::module_ &m) {
// Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7)
m.def("test_gil", &test_gil);
m.def("test_gil_from_thread", &test_gil_from_thread);

py::class_<Passenger>(m, "Passenger")
.def_property_readonly("id", &Passenger::id)
.def_readwrite("mtxt", &Passenger::mtxt);

py::class_<ReferencePassingTest, PyReferencePassingTest>(m, "ReferencePassingTest")
.def(py::init<>())
.def("pass_valu", &ReferencePassingTest::pass_valu)
.def("pass_mref", &ReferencePassingTest::pass_mref)
.def("pass_cref", &ReferencePassingTest::pass_cref)
.def("pass_mptr", &ReferencePassingTest::pass_mptr)
.def("pass_cptr", &ReferencePassingTest::pass_cptr);

m.def("check_roundtrip_valu", check_roundtrip_valu);
m.def("check_roundtrip_mref", check_roundtrip_mref);
m.def("check_roundtrip_cref", check_roundtrip_cref);
m.def("check_roundtrip_mptr", check_roundtrip_mptr);
m.def("check_roundtrip_cptr", check_roundtrip_cptr);
};
132 changes: 132 additions & 0 deletions tests/test_virtual_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,135 @@ def test_issue_1454():
# Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7)
m.test_gil()
m.test_gil_from_thread()


# Python class inheriting from C++ class ReferencePassingTest
# virtual methods modify the obj's mtxt, which should become visible in C++
# To ensure that the original object instance was passed through,
# the pointer id of the received obj is returned by all pass_*() functions
# (and compared by the C++ caller with the originally passed obj id).
class PyReferencePassingTest1(m.ReferencePassingTest):
def __init__(self):
m.ReferencePassingTest.__init__(self)

def pass_valu(self, obj):
obj.mtxt = obj.mtxt + "pass_valu"
return obj.id

def pass_mref(self, obj):
obj.mtxt = obj.mtxt + "pass_mref"
return obj.id

def pass_mptr(self, obj):
obj.mtxt = obj.mtxt + "pass_mptr"
return obj.id

def pass_cref(self, obj):
with pytest.raises(Exception): # should be forbidden
obj.mtxt = obj.mtxt + "pass_cref"
return obj.id

def pass_cptr(self, obj):
with pytest.raises(Exception): # should be forbidden
obj.mtxt = obj.mtxt + "pass_cptr"
return obj.id


# This class, in contrast to PyReferencePassingTest1, calls the base class methods as well,
# which will augment mtxt with a _MODIFIED stamp.
# These calls to the base class methods actually result in a 2nd call to the
# trampoline override dispatcher, requiring argument loading, which should pass
# references through as well, to make these tests succeed.
# argument is passed like this: C++ -> Python (call #1) -> C++ (call #2).
class PyReferencePassingTest2(m.ReferencePassingTest):
def __init__(self):
m.ReferencePassingTest.__init__(self)

def pass_valu(self, obj):
obj.mtxt = obj.mtxt + "pass_valu"
return m.ReferencePassingTest.pass_valu(self, obj)

def pass_mref(self, obj):
obj.mtxt = obj.mtxt + "pass_mref"
return m.ReferencePassingTest.pass_mref(self, obj)

def pass_mptr(self, obj):
obj.mtxt = obj.mtxt + "pass_mptr"
return m.ReferencePassingTest.pass_mptr(self, obj)

def pass_cref(self, obj):
with pytest.raises(Exception): # should be forbidden
obj.mtxt = obj.mtxt + "pass_cref"
return m.ReferencePassingTest.pass_cref(self, obj)

def pass_cptr(self, obj):
with pytest.raises(Exception): # should be forbidden
obj.mtxt = obj.mtxt + "pass_cptr"
return m.ReferencePassingTest.pass_cptr(self, obj)


# roundtrip tests, creating a Passenger object in C++ that is passed by reference
# to a virtual method of a class derived in Python (PyReferencePassingTest1).
# If the object is correctly passed by reference, modifications should be visible
# by the C++ caller. The final obj's mtxt is returned by the check_* functions
# and validated here. Expected scheme: <func name>_[REF|COPY]
@pytest.mark.parametrize(
"f, expected",
[
(m.check_roundtrip_valu, "_COPY"), # modification not passed back to C++
(m.check_roundtrip_mref, "pass_mref_REF"),
(m.check_roundtrip_mptr, "pass_mptr_REF"),
],
)
def test_refpassing1_roundtrip_modifyable(f, expected):
c = PyReferencePassingTest1()
assert f(c) == expected


@pytest.mark.parametrize(
"f, expected",
[
# object passed as reference, but not modified
(m.check_roundtrip_cref, "_REF"),
(m.check_roundtrip_cptr, "_REF"),
],
)
# PYPY always copies the argument (to ensure constness?)
@pytest.mark.skipif("env.PYPY")
@pytest.mark.xfail # maintaining constness isn't implemented yet
def test_refpassing1_roundtrip_const(f, expected):
c = PyReferencePassingTest1()
assert f(c) == expected


# Similar test as above, but now using PyReferencePassingTest2, calling
# to the C++ base class methods as well.
# Expected mtxt scheme: <func name>_MODIFIED_[REF|COPY]
@pytest.mark.parametrize(
"f, expected",
[
# object copied, modification not passed back to C++
(m.check_roundtrip_valu, "_COPY"),
(m.check_roundtrip_mref, "pass_mref_MODIFIED_REF"),
(m.check_roundtrip_mptr, "pass_mptr_MODIFIED_REF"),
],
)
def test_refpassing2_roundtrip_modifyable(f, expected):
c = PyReferencePassingTest2()
assert f(c) == expected


@pytest.mark.parametrize(
"f, expected",
[
# object passed as reference, but not modified
(m.check_roundtrip_cref, "_REF"),
(m.check_roundtrip_cptr, "_REF"),
],
)
# PYPY always copies the argument (to ensure constness?)
@pytest.mark.skipif("env.PYPY")
@pytest.mark.xfail # maintaining constness isn't implemented yet
def test_refpassing2_roundtrip_const(f, expected):
c = PyReferencePassingTest2()
assert f(c) == expected